Repository: allenai/tango Branch: main Commit: 6aaa8ff0f203 Files: 329 Total size: 1.2 MB Directory structure: gitextract_p3vof5f6/ ├── .dockerignore ├── .github/ │ ├── CONTRIBUTING.md │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.yml │ │ ├── documentation.yml │ │ └── feature_request.yml │ ├── dependabot.yml │ └── workflows/ │ ├── changelog.yml │ ├── docker.yml │ ├── docker_testing.yml │ ├── integration_tests.yml │ ├── main.yml │ └── update_dependency_pr.yml ├── .gitignore ├── .readthedocs.yaml ├── CHANGELOG.md ├── CITATION.cff ├── Dockerfile ├── Dockerfile.test ├── LICENSE ├── Makefile ├── README.md ├── RELEASE_PROCESS.md ├── docs/ │ ├── .gitignore │ ├── Makefile │ ├── make.bat │ └── source/ │ ├── _static/ │ │ └── css/ │ │ └── custom.css │ ├── api/ │ │ ├── commands.rst │ │ ├── components/ │ │ │ ├── executor.rst │ │ │ ├── format.rst │ │ │ ├── index.rst │ │ │ ├── step.rst │ │ │ ├── step_cache.rst │ │ │ ├── step_graph.rst │ │ │ ├── step_info.rst │ │ │ └── workspace.rst │ │ ├── det_hash.rst │ │ ├── exceptions.rst │ │ ├── integrations/ │ │ │ ├── beaker.rst │ │ │ ├── datasets.rst │ │ │ ├── fairscale.rst │ │ │ ├── flax.rst │ │ │ ├── gs.rst │ │ │ ├── index.rst │ │ │ ├── torch.rst │ │ │ ├── transformers.rst │ │ │ └── wandb.rst │ │ ├── logging.rst │ │ ├── sequences.rst │ │ ├── settings.rst │ │ └── utilities.rst │ ├── conf.py │ ├── examples/ │ │ ├── euler.md │ │ ├── eval_p3.md │ │ ├── index.rst │ │ └── train_lm.md │ ├── faq.md │ ├── first_steps.md │ ├── index.md │ └── installation.md ├── examples/ │ ├── euler/ │ │ ├── README.md │ │ ├── complex_arithmetic.py │ │ ├── euler.jsonnet │ │ ├── euler_general.jsonnet │ │ └── run.sh │ ├── eval_p3/ │ │ ├── README.md │ │ ├── config.jsonnet │ │ └── eval.py │ ├── finetune/ │ │ ├── __init__.py │ │ ├── config.jsonnet │ │ ├── snli_steps.py │ │ └── test.py │ ├── finetune_resnet/ │ │ ├── .gitignore │ │ ├── config.jsonnet │ │ └── resnet_steps.py │ ├── flax/ │ │ ├── config.jsonnet │ │ ├── run.sh │ │ └── xsum.py │ └── train_lm/ │ ├── .gitignore │ ├── 
README.md │ ├── config.jsonnet │ ├── test.py │ └── tokenize_step.py ├── integration_tests/ │ ├── README.md │ └── fairscale_benchmarks/ │ ├── README.md │ ├── config.jsonnet │ └── run.sh ├── pyproject.toml ├── scripts/ │ ├── entrypoint.sh │ ├── hash_extras.py │ ├── prepare_changelog.py │ ├── prepare_citation_cff.py │ ├── release.sh │ └── release_notes.py ├── tango/ │ ├── __init__.py │ ├── __main__.py │ ├── cli.py │ ├── common/ │ │ ├── __init__.py │ │ ├── aliases.py │ │ ├── dataset_dict.py │ │ ├── det_hash.py │ │ ├── exceptions.py │ │ ├── file_lock.py │ │ ├── from_params.py │ │ ├── lazy.py │ │ ├── logging.py │ │ ├── params.py │ │ ├── registrable.py │ │ ├── remote_utils.py │ │ ├── sequences.py │ │ ├── testing/ │ │ │ ├── __init__.py │ │ │ └── steps.py │ │ ├── tqdm.py │ │ └── util.py │ ├── executor.py │ ├── executors/ │ │ ├── __init__.py │ │ └── multicore_executor.py │ ├── format.py │ ├── integrations/ │ │ ├── __init__.py │ │ ├── beaker/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── entrypoint.sh │ │ │ ├── executor.py │ │ │ ├── step_cache.py │ │ │ └── workspace.py │ │ ├── datasets/ │ │ │ └── __init__.py │ │ ├── fairscale/ │ │ │ ├── __init__.py │ │ │ ├── fsdp_config.py │ │ │ ├── module_wrapper.py │ │ │ └── training_engine.py │ │ ├── flax/ │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── eval.py │ │ │ ├── eval_callback.py │ │ │ ├── format.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ ├── train.py │ │ │ ├── train_callback.py │ │ │ ├── train_config.py │ │ │ ├── util.py │ │ │ └── wrapper.py │ │ ├── gs/ │ │ │ ├── __init__.py │ │ │ ├── common.py │ │ │ ├── step_cache.py │ │ │ └── workspace.py │ │ ├── torch/ │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── eval.py │ │ │ ├── eval_callback.py │ │ │ ├── exceptions.py │ │ │ ├── format.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ ├── train.py │ │ │ ├── train_callback.py │ │ │ ├── train_config.py │ │ │ ├── training_engine.py │ │ │ └── util.py │ │ ├── transformers/ │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── data.py │ 
│ │ ├── finetune.py │ │ │ ├── ia3.py │ │ │ ├── model.py │ │ │ ├── optim.py │ │ │ ├── run_generation.py │ │ │ ├── soft_prompt.py │ │ │ └── tokenizer.py │ │ └── wandb/ │ │ ├── __init__.py │ │ ├── flax_train_callback.py │ │ ├── step_cache.py │ │ ├── torch_train_callback.py │ │ ├── util.py │ │ └── workspace.py │ ├── py.typed │ ├── settings.py │ ├── step.py │ ├── step_cache.py │ ├── step_caches/ │ │ ├── __init__.py │ │ ├── local_step_cache.py │ │ ├── memory_step_cache.py │ │ └── remote_step_cache.py │ ├── step_graph.py │ ├── step_info.py │ ├── steps/ │ │ ├── __init__.py │ │ ├── dataset_remix.py │ │ ├── print.py │ │ └── shell_step.py │ ├── version.py │ ├── workspace.py │ └── workspaces/ │ ├── __init__.py │ ├── local_workspace.py │ ├── memory_workspace.py │ └── remote_workspace.py ├── test_fixtures/ │ ├── __init__.py │ ├── beaker/ │ │ └── nvidia_smi.yml │ ├── common/ │ │ ├── params_example.jsonnet │ │ └── params_example.yaml │ ├── experiment/ │ │ ├── hello_world.jsonnet │ │ ├── logging_check.jsonnet │ │ ├── multiprocessing.jsonnet │ │ ├── noisy.jsonnet │ │ └── random.jsonnet │ ├── integrations/ │ │ ├── __init__.py │ │ ├── common/ │ │ │ └── __init__.py │ │ ├── datasets/ │ │ │ └── config.json │ │ ├── fairscale/ │ │ │ ├── __init__.py │ │ │ ├── components.py │ │ │ └── config.jsonnet │ │ ├── flax/ │ │ │ ├── __init__.py │ │ │ ├── config.jsonnet │ │ │ └── xsum.py │ │ └── torch/ │ │ ├── __init__.py │ │ ├── eval.jsonnet │ │ ├── train.jsonnet │ │ ├── train_dist.jsonnet │ │ └── train_streaming.jsonnet │ └── v1_local_workspace/ │ └── cache/ │ ├── AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ ├── CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ ├── 
ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ ├── MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ ├── MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ ├── SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/ │ │ ├── cache-metadata.json │ │ ├── conda-environment.yaml │ │ ├── executor-metadata.json │ │ ├── lock │ │ ├── requirements.txt │ │ └── stepinfo.dill │ └── SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/ │ ├── cache-metadata.json │ ├── conda-environment.yaml │ ├── executor-metadata.json │ ├── lock │ ├── requirements.txt │ └── stepinfo.dill └── tests/ ├── __init__.py ├── common/ │ ├── __init__.py │ ├── dataset_dict_test.py │ ├── det_hash_test.py │ ├── from_params_pep_563_test.py │ ├── from_params_test.py │ ├── params_test.py │ ├── registrable_test.py │ ├── sequences_test.py │ └── util_test.py ├── end_to_end/ │ ├── test_dataset_dict_from_separate_steps.py │ ├── test_lazy_input_with_another_step.py │ ├── test_multicore_cli.py │ ├── test_non_cacheable_into_cacheable_multiple_runs.py │ ├── test_registered_runs.py │ ├── test_run_single_step.py │ ├── test_step_indexing.py │ ├── test_steps_that_fail.py │ └── test_uncacheable_leaf_steps.py ├── executor_test.py ├── executors/ │ ├── __init__.py │ └── multicore_executor_test.py ├── format_test.py ├── integrations/ │ ├── __init__.py │ ├── beaker/ │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── executor_test.py │ │ ├── step_cache_test.py │ │ └── workspace_test.py │ ├── datasets/ │ │ ├── __init__.py │ │ └── dataset_test.py │ ├── fairscale/ │ │ ├── __init__.py │ │ └── train_test.py │ 
├── flax/ │ │ ├── __init__.py │ │ ├── data_test.py │ │ ├── format_test.py │ │ ├── optim_test.py │ │ └── train_test.py │ ├── gs/ │ │ ├── __init__.py │ │ ├── step_cache_test.py │ │ └── workspace_test.py │ ├── torch/ │ │ ├── __init__.py │ │ ├── data_test.py │ │ ├── det_hash_test.py │ │ ├── eval_test.py │ │ ├── format_test.py │ │ ├── optim_test.py │ │ ├── train_callback_test.py │ │ ├── train_test.py │ │ └── training_engine_test.py │ ├── transformers/ │ │ ├── data_test.py │ │ ├── finetune_test.py │ │ ├── ia3_test.py │ │ ├── run_generation_test.py │ │ └── soft_prompt_test.py │ └── wandb/ │ ├── __init__.py │ ├── step_cache_test.py │ └── workspace_test.py ├── main_test.py ├── step_caches/ │ ├── __init__.py │ └── local_step_cache_test.py ├── step_graph_test.py ├── step_info_test.py ├── step_test.py ├── steps/ │ ├── __init__.py │ ├── dataset_remix_test.py │ └── shell_step_test.py └── workspaces/ ├── __init__.py ├── local_workspace_test.py └── memory_workspace_test.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .dockerignore ================================================ .dockerignore **.pyc **/__pycache__ .gitignore .git .coverage .mypy_cache docs examples tests test_fixtures integration_tests dist *.egg-info ================================================ FILE: .github/CONTRIBUTING.md ================================================ # Contributing Thanks for considering contributing! Please read this document to learn the various ways you can contribute to this project and how to go about doing it. ## Bug reports and feature requests ### Did you find a bug? First, do [a quick search](https://github.com/allenai/tango/issues) to see whether your issue has already been reported. If your issue has already been reported, please comment on the existing issue. Otherwise, open [a new GitHub issue](https://github.com/allenai/tango/issues). 
Be sure to include a clear title and description. The description should include as much relevant information as possible. The description should explain how to reproduce the erroneous behavior as well as the behavior you expect to see. Ideally you would include a code sample or an executable test case demonstrating the expected behavior. ### Do you have a suggestion for an enhancement or new feature? We use GitHub issues to track feature requests. Before you create a feature request: - Make sure you have a clear idea of the enhancement you would like. If you have a vague idea, consider discussing it first on a GitHub issue. - Check the documentation to make sure your feature does not already exist. - Do [a quick search](https://github.com/allenai/tango/issues) to see whether your feature has already been suggested. When creating your request, please: - Provide a clear title and description. - Explain why the enhancement would be useful. It may be helpful to highlight the feature in other libraries. - Include code examples to demonstrate how the enhancement would be used. ## Making a pull request When you're ready to contribute code to address an open issue, please follow these guidelines to help us be able to review your pull request (PR) quickly. 1. **Initial setup** (only do this once)
Expand details 👇
If you haven't already done so, please [fork](https://help.github.com/en/enterprise/2.13/user/articles/fork-a-repo) this repository on GitHub. Then clone your fork locally with git clone https://github.com/USERNAME/tango.git or git clone git@github.com:USERNAME/tango.git At this point the local clone of your fork only knows that it came from _your_ repo, github.com/USERNAME/tango.git, but doesn't know anything about the _main_ repo, [https://github.com/allenai/tango.git](https://github.com/allenai/tango). You can see this by running git remote -v which will output something like this: origin https://github.com/USERNAME/tango.git (fetch) origin https://github.com/USERNAME/tango.git (push) This means that your local clone can only track changes from your fork, but not from the main repo, and so you won't be able to keep your fork up-to-date with the main repo over time. Therefore you'll need to add another "remote" to your clone that points to [https://github.com/allenai/tango.git](https://github.com/allenai/tango). To do this, run the following: git remote add upstream https://github.com/allenai/tango.git Now if you do `git remote -v` again, you'll see origin https://github.com/USERNAME/tango.git (fetch) origin https://github.com/USERNAME/tango.git (push) upstream https://github.com/allenai/tango.git (fetch) upstream https://github.com/allenai/tango.git (push) Finally, you'll need to create a Python 3 virtual environment suitable for working on this project. There are a number of tools out there that make working with virtual environments easier. The most direct way is with the [`venv` module](https://docs.python.org/3.8/library/venv.html) in the standard library, but if you're new to Python or you don't already have a recent Python 3 version installed on your machine, we recommend [Miniconda](https://docs.conda.io/en/latest/miniconda.html). 
On Mac, for example, you can install Miniconda with [Homebrew](https://brew.sh/): brew install miniconda Then you can create and activate a new Python environment by running: conda create -n tango python=3.9 conda activate tango Once your virtual environment is activated, you can install your local clone in "editable mode" with pip install -U pip setuptools wheel pip install -e '.[dev,all]' The "editable mode" comes from the `-e` argument to `pip`, and essentially just creates a symbolic link from the site-packages directory of your virtual environment to the source code in your local clone. That way any changes you make will be immediately reflected in your virtual environment. To test your installation, just run tango info
2. **Ensure your fork is up-to-date**
Expand details 👇
Once you've added an "upstream" remote pointing to [https://github.com/allenai/tango.git](https://github.com/allenai/tango), keeping your fork up-to-date is easy: git checkout main # if not already on main git pull --rebase upstream main git push
3. **Create a new branch to work on your fix or enhancement**
Expand details 👇
Committing directly to the main branch of your fork is not recommended. It will be easier to keep your fork clean if you work on a separate branch for each contribution you intend to make. You can create a new branch with # replace BRANCH with whatever name you want to give it git checkout -b BRANCH git push -u origin BRANCH
4. **Test your changes**
Expand details 👇
Our continuous integration (CI) testing runs [a number of checks](https://github.com/allenai/tango/actions) for each pull request on [GitHub Actions](https://github.com/features/actions). You can run most of these tests locally, which is something you should do _before_ opening a PR to help speed up the review process and make it easier for us. First, you should run [`isort`](https://github.com/PyCQA/isort) and [`black`](https://github.com/psf/black) to make sure your code is formatted consistently. Many IDEs support code formatters as plugins, so you may be able to set up isort and black to run automatically every time you save. For example, [`black.vim`](https://github.com/psf/black/tree/master/plugin) will give you this functionality in Vim. But both `isort` and `black` are also easy to run directly from the command line. Just run this from the root of your clone: isort . black . Our CI also uses [`ruff`](https://github.com/charliermarsh/ruff) to lint the code base and [`mypy`](http://mypy-lang.org/) for type-checking. You should run both of these next with ruff check . and mypy . We also strive to maintain high test coverage, so most contributions should include additions to [the unit tests](https://github.com/allenai/tango/tree/main/tests). These tests are run with [`pytest`](https://docs.pytest.org/en/latest/), which you can use to locally run any test modules that you've added or changed. For example, if you've fixed a bug in `tango/a/b.py`, you can run the tests specific to that module with pytest -v tests/a/b_test.py If your contribution involves additions to any public part of the API, we require that you write docstrings for each function, method, class, or module that you add. See the [Writing docstrings](#writing-docstrings) section below for details on the syntax. You should test to make sure the API documentation can build without errors by running make docs If the build fails, it's most likely due to small formatting issues. 
If the error message isn't clear, feel free to comment on this in your pull request. And finally, please update the [CHANGELOG](https://github.com/allenai/tango/blob/main/CHANGELOG.md) with notes on your contribution in the "Unreleased" section at the top. After all of the above checks have passed, you can now open [a new GitHub pull request](https://github.com/allenai/tango/pulls). Make sure you have a clear description of the problem and the solution, and include a link to relevant issues. We look forward to reviewing your PR!
### Writing docstrings We use [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to build our API docs, which automatically parses all docstrings of public classes and methods. All docstrings should adhere to the [Numpy styling convention](https://www.sphinx-doc.org/en/master/usage/extensions/example_numpy.html). ## Adding a new integration In order to add a new integration, there are several additional steps and guidelines you should follow in addition to everything listed in [Making a pull request](#making-a-pull-request). 1. First start by creating a new submodule `tango.integrations.name_of_integration` and put all of the code for your integration in there. 2. Then you must add a module docstring to the `__init__.py` file of the submodule which imports all of the public components of the integration, and defines the [`__all__`](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) special variable to include all of those components. This ensures all of the public components will show up in the documentation. 3. Next, you should add unit tests of your code to `tests/integrations/name_of_integration/`. 4. Then add a new file `docs/source/api/integrations/name_of_integration.rst`, and include the directive: ``` .. automodule:: tango.integrations.name_of_integration :members: ``` Take a look at any of the other files in that folder to see how it should look exactly. 5. And then add `name_of_integration` to the `toctree` in `docs/source/api/integrations/index.rst`. 6. After that, add any additional requirements that your integration depends on to `requirements.txt`. Be sure to put those under the "Extra dependencies for integrations" section, and add the special inline comment `# needed by: name_of_integration`. 7. And finally, in the `checks` job definition in `.github/workflows/main.yml`, add a new object to the matrix for your integration following the other examples there. 
================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.yml ================================================ name: 🐛 Bug Report description: Create a report to help us reproduce and fix the bug labels: 'bug' body: - type: markdown attributes: value: > #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/allenai/tango/issues?q=is%3Aissue+sort%3Acreated-desc+). - type: textarea attributes: label: 🐛 Describe the bug description: | Please provide a clear and concise description of what the bug is. If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example: ```python # All necessary imports at the beginning import tango # A succinct reproducing example trimmed down to the essential parts: assert False is True, "Oh no!" ``` If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com. Please also paste or describe the results you observe along with the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````. placeholder: | A clear and concise description of what the bug is. validations: required: true - type: textarea attributes: label: Versions description: | Please run the following and paste the output below. ```sh python --version && pip freeze ``` validations: required: true - type: markdown attributes: value: > Thanks for contributing 🎉! 
================================================ FILE: .github/ISSUE_TEMPLATE/documentation.yml ================================================ name: 📚 Documentation description: Report an issue related to https://ai2-tango.readthedocs.io/latest labels: 'documentation' body: - type: textarea attributes: label: 📚 The doc issue description: > A clear and concise description of what content in https://ai2-tango.readthedocs.io/latest is an issue. validations: required: true - type: textarea attributes: label: Suggest a potential alternative/fix description: > Tell us how we could improve the documentation in this regard. - type: markdown attributes: value: > Thanks for contributing 🎉! ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.yml ================================================ name: 🚀 Feature request description: Submit a proposal/request for a new feature labels: 'feature request' body: - type: textarea attributes: label: 🚀 The feature, motivation and pitch description: > A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. validations: required: true - type: textarea attributes: label: Alternatives description: > A description of any alternative solutions or features you've considered, if any. - type: textarea attributes: label: Additional context description: > Add any other context or screenshots about the feature request. - type: markdown attributes: value: > Thanks for contributing 🎉! 
================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: pip directory: "/" schedule: interval: "daily" open-pull-requests-limit: 10 - package-ecosystem: "github-actions" directory: "/" schedule: interval: "daily" ================================================ FILE: .github/workflows/changelog.yml ================================================ name: Changelog concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true on: pull_request: branches: - main paths: - 'tango/**' jobs: changelog: name: CHANGELOG runs-on: ubuntu-latest if: github.event_name == 'pull_request' steps: - uses: actions/checkout@v3 with: fetch-depth: 0 - name: Check that CHANGELOG has been updated run: | # If this step fails, this means you haven't updated the CHANGELOG.md # file with notes on your contribution. git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo "Thanks for helping keep our CHANGELOG up-to-date!" ================================================ FILE: .github/workflows/docker.yml ================================================ name: Docker concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true on: pull_request: branches: - main paths: - "Dockerfile" - ".dockerignore" - "pyproject.toml" push: tags: - "v*.*.*" jobs: build: name: Build (${{ matrix.build.tag }}) runs-on: ubuntu-latest strategy: fail-fast: false matrix: build: - base_image: ghcr.io/allenai/pytorch:1.12.1-cuda11.3-python3.9 tag: cuda11.3 env: IMAGE_NAME: ghcr.io/allenai/tango steps: - uses: actions/checkout@v3 - name: Build Docker image run: | docker build --build-arg BASE_IMAGE=${{ matrix.build.base_image }} -t "${IMAGE_NAME}:${{ matrix.build.tag }}" . 
- name: Test Docker image run: | docker run --rm "${IMAGE_NAME}:${{ matrix.build.tag }}" info - name: Log in to ghcr.io if: github.event_name != 'pull_request' run: | echo "${{ secrets.GITHUB_TOKEN }}" | docker login ghcr.io -u ${{ github.actor }} --password-stdin - name: Push latest to ghcr.io if: github.event_name != 'pull_request' run: | docker push "${IMAGE_NAME}:${{ matrix.build.tag }}" - name: Push release version to ghcr.io if: startsWith(github.ref, 'refs/tags/') run: | GITHUB_TAG=${GITHUB_REF#refs/tags/} docker tag "${IMAGE_NAME}:${{ matrix.build.tag }}" "${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}" docker push "${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}" ================================================ FILE: .github/workflows/docker_testing.yml ================================================ # This workflow is just for building our Docker image for GPU testing on Beaker, # and pushing it to Beaker. We only run it when the relevant Dockerfile (or .dockerignore) changes. name: Docker testing concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true on: pull_request: branches: - main paths: - 'Dockerfile.test' - '.dockerignore' - 'scripts/entrypoint.sh' push: branches: - main paths: - 'Dockerfile.test' - '.dockerignore' - 'scripts/entrypoint.sh' jobs: build: name: Build runs-on: ubuntu-latest env: BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_WORKSPACE: ai2/tango-testing IMAGE_NAME: tango-testing steps: - uses: actions/checkout@v3 - uses: allenai/setup-beaker@v2 with: token: ${{ secrets.BEAKER_TOKEN }} workspace: ${{ env.BEAKER_WORKSPACE }} - name: Build Docker image run: | docker build -t "$IMAGE_NAME" -f Dockerfile.test . 
- name: Determine current commit SHA (pull request) if: github.event_name == 'pull_request' run: | echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV - name: Determine current commit SHA (push) if: github.event_name != 'pull_request' run: | echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV - name: Test Docker image run: | docker run --rm --env COMMIT_SHA="$COMMIT_SHA" "$IMAGE_NAME" tango info # In order to push a new version of an image to beaker, we have to delete the old version first. # This doesn't actually delete the backing Docker image, so we'll still benefit from layer # caching when we push new versions. But we have to be careful to minimize the amount # of time between deletion and creation, because during that time any Beaker job trying to start # that depends on that image will fail. So to minimize this downtime, we first push a # "temp" version of the image, then delete the current one and quickly rename the "temp" one to take its place. # The image might not exist yet though, so it's okay if the delete fails. 
- name: Delete existing commit image continue-on-error: true run: | beaker image delete petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} - name: Upload new commit image run: | beaker image create --workspace ${{ env.BEAKER_WORKSPACE }} --name ${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }} - name: Delete existing image if: github.event_name != 'pull_request' continue-on-error: true run: | beaker image delete petew/${{ env.IMAGE_NAME }} - name: Rename new commit image to final image if: github.event_name != 'pull_request' run: | beaker image rename petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }} ================================================ FILE: .github/workflows/integration_tests.yml ================================================ name: Integration tests on: workflow_dispatch: inputs: test: description: the integration test to run default: fairscale_benchmarks required: true type: choice options: - fairscale_benchmarks cluster: description: the beaker cluster to run the test on default: ai2/tango-integration-tests required: true type: choice options: - ai2/tango-integration-tests - ai2/allennlp-cirrascale # Uncomment this trigger to test changes on a pull request. 
# You also have to uncomment the lines below that mention 'for pull request checks' # pull_request: # branches: # - '*' jobs: run_test: name: ${{ github.event.inputs.test }} # name: fairscale_benchmarks # for pull request checks runs-on: [ubuntu-latest] timeout-minutes: 60 env: TEST_NAME: ${{ github.event.inputs.test }} # TEST_NAME: fairscale_benchmarks # for pull request checks BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_WORKSPACE: ai2/tango-integration-tests BEAKER_CLUSTER: ${{ github.event.inputs.cluster }} # BEAKER_CLUSTER: ai2/allennlp-cirrascale # for pull request checks IMAGE_NAME: petew/tango-testing steps: - uses: actions/checkout@v3 - name: Validate inputs run: | # The 'test' input should be a directory in `integration_tests/` test -d "integration_tests/${TEST_NAME}" - name: Determine current commit SHA (pull request) if: github.event_name == 'pull_request' run: | echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV - name: Determine current commit SHA (push) if: github.event_name != 'pull_request' run: | echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV - name: Install beaker client shell: bash run: | mkdir -p "$HOME/bin" # Download and install from latest GitHub release. curl -s https://api.github.com/repos/allenai/beaker/releases/latest \ | grep 'browser_download_url.*linux' \ | cut -d '"' -f 4 \ | wget -qi - \ && tar -xvzf beaker_linux.tar.gz -C "$HOME/bin" # Add to path. 
echo "$HOME/bin" >> "$GITHUB_PATH" - name: Verify beaker install run: | beaker account whoami - name: Create beaker experiment config run: | cat >beaker_config.yml << EOL version: v2-alpha description: ${{ env.TEST_NAME }} tasks: - name: test image: beaker: ${{ env.IMAGE_NAME }} command: ["/entrypoint.sh", "integration_tests/${{ env.TEST_NAME }}/run.sh"] envVars: - name: COMMIT_SHA value: $COMMIT_SHA - name: WANDB_API_KEY secret: WANDB_API_KEY - name: FILE_FRIENDLY_LOGGING value: "true" - name: TOKENIZERS_PARALLELISM # set this to avoid warnings value: "true" - name: PYTHONUNBUFFERED value: "true" result: path: '/results' resources: gpuCount: 4 context: cluster: ${{ env.BEAKER_CLUSTER }} priority: normal EOL cat beaker_config.yml - name: Submit beaker job run: | TIMESTAMP=$(date +%H%M%S) EXPERIMENT=$(beaker experiment create beaker_config.yml --workspace $BEAKER_WORKSPACE --name "${TEST_NAME}-${{ github.run_number }}-${TIMESTAMP}" | awk '{print $2}') if [ -z "$EXPERIMENT" ]; then exit 1 else echo "EXPERIMENT=$EXPERIMENT" >> $GITHUB_ENV echo "Experiment $EXPERIMENT submitted. See progress at https://beaker.org/ex/$EXPERIMENT" fi - name: Wait for job to finish run: | beaker experiment await $EXPERIMENT test finalized --timeout 60m # Check the job's exit code. test $(beaker experiment get $EXPERIMENT --format=json | jq '.[0].jobs[0].status.exitCode') -eq 0 - name: Get logs if: always() run: | # EXPERIMENT could be empty if the submission step failed. # We'll exit right away if that's the case. if [ -z "$EXPERIMENT" ]; then echo "No logs to show" exit 0 fi # Download logs from beaker. beaker experiment results $EXPERIMENT --prefix out.log --output results # If the experiment failed during startup, there might not be any logs. if [ -f results/test/out.log ]; then echo "" echo ">>> Logs:" echo "" cat results/test/out.log else echo "No logs to show" fi - name: Stop job if: cancelled() run: | if [ ! 
-z "$EXPERIMENT" ]; then beaker experiment stop $EXPERIMENT fi ================================================ FILE: .github/workflows/main.yml ================================================ name: Main concurrency: group: ${{ github.workflow }}-${{ github.ref }} on: pull_request: branches: - "*" push: branches: - main tags: - "v*.*.*" env: CACHE_PREFIX: v5 # Change this to invalidate existing cache. PYTHON_PATH: ./ DEFAULT_PYTHON: 3.9 WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }} BEAKER_WORKSPACE: ai2/tango-testing BEAKER_IMAGE: petew/tango-testing GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} jobs: checks: name: python ${{ matrix.python }} - ${{ matrix.task.name }} runs-on: [ubuntu-latest] timeout-minutes: 30 permissions: contents: "read" id-token: "write" strategy: fail-fast: false matrix: python: ["3.9"] task: - name: Lint extras: dev,all requires_torch: true run: | ruff check . - name: Type check extras: dev,all requires_torch: true run: | mypy --check-untyped-defs . - name: Build extras: dev,all requires_torch: true run: | tango --version python -m build - name: Style extras: dev requires_torch: false run: | isort --check . black --check . 
- name: Docs extras: dev,all requires_torch: true run: | cd docs && make html SPHINXOPTS="-W --keep-going" - name: Test extras: dev requires_torch: false run: | pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/ - name: Datasets integration extras: dev,datasets requires_torch: false run: | pytest -v --color=yes --doctest-modules tango/integrations/datasets tests/integrations/datasets - name: PyTorch integration extras: dev,torch requires_torch: true run: | pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch - name: Transformers integration extras: dev,flax,transformers requires_torch: true run: | pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers - name: FairScale integration extras: dev,fairscale requires_torch: true run: | pytest -v --color=yes --doctest-modules tango/integrations/fairscale tests/integrations/fairscale - name: W&B integration extras: dev,torch,flax,wandb requires_torch: true run: | pytest -v --color=yes --doctest-modules tango/integrations/wandb tests/integrations/wandb - name: Beaker integration extras: dev,beaker requires_torch: false run: | pytest -v --color=yes --doctest-modules tango/integrations/beaker tests/integrations/beaker - name: Flax integration extras: dev,flax,transformers requires_torch: false run: | pytest -v --color=yes --doctest-modules tango/integrations/flax tests/integrations/flax - name: GS integration extras: dev,gs requires_torch: false run: | pytest -v --color=yes --doctest-modules tango/integrations/gs tests/integrations/gs - name: Example - train_lm extras: dev,all requires_torch: true run: | cd examples/train_lm pytest -v --color=yes test.py include: # Run the core tests on other Python versions as well. 
- task: name: Test extras: dev requires_torch: false run: | pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/ python: "3.8" - task: name: Test extras: dev requires_torch: false run: | pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/ python: "3.10" steps: - uses: "actions/checkout@v3" - name: Checkout if: github.event_name != 'pull_request' uses: actions/checkout@v3 # For pull requests we need to checkout the HEAD commit instead of the merge # commit since some tests depend on having an existing commit. - name: Checkout (pull request) if: github.event_name == 'pull_request' uses: actions/checkout@v3 with: ref: ${{ github.event.pull_request.head.sha }} - name: Setup Python uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} - name: Install prerequisites run: | pip install --upgrade pip setuptools wheel virtualenv - name: Set build variables shell: bash run: | set -e # Get the exact Python version to use in the cache key. echo "PYTHON_VERSION=$(python --version)" >> $GITHUB_ENV echo "RUNNER_ARCH=$(uname -m)" >> $GITHUB_ENV # Use week number in cache key so we can refresh the cache weekly. 
echo "WEEK_NUMBER=$(date +%V)" >> $GITHUB_ENV echo "EXTRAS_HASH=$(python scripts/hash_extras.py ${{ matrix.task.extras }})" >> $GITHUB_ENV - uses: actions/cache@v3 id: virtualenv-cache with: path: .venv key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ env.EXTRAS_HASH }}-${{ hashFiles('pyproject.toml') }} - name: Setup virtual environment (no cache hit) if: steps.virtualenv-cache.outputs.cache-hit != 'true' run: | test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv # Reference: https://github.com/marketplace/actions/authenticate-to-google-cloud#setup - name: Authenticate to Google Cloud if: matrix.task.name == 'GS integration' uses: "google-github-actions/auth@v1" with: workload_identity_provider: "projects/10554368204/locations/global/workloadIdentityPools/tango-ci-pool/providers/tango-ci-provider" service_account: "tango-service@ai2-allennlp.iam.gserviceaccount.com" - name: Pre-install torch if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'torch') || contains(matrix.task.extras, 'all') || matrix.task.requires_torch) run: | . .venv/bin/activate pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu - name: Pre-install flax if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all')) run: | . .venv/bin/activate pip install flax jax jaxlib "tensorflow-cpu>=2.9.1" optax - name: Install editable (no cache hit) if: steps.virtualenv-cache.outputs.cache-hit != 'true' run: | . .venv/bin/activate pip install -e .[${{ matrix.task.extras }}] - name: Install editable (cache hit) if: steps.virtualenv-cache.outputs.cache-hit == 'true' run: | . .venv/bin/activate pip install --no-deps -e .[${{ matrix.task.extras }}] - name: Show environment info run: | . 
.venv/bin/activate echo "========= Python location ===========" which python echo "========= Python version ============" python --version echo "========= Python packages ===========" pip freeze echo "========= Tango installation ========" tango info - name: ${{ matrix.task.name }} run: | . .venv/bin/activate ${{ matrix.task.run }} - name: Upload package distribution files if: matrix.task.name == 'Build' && matrix.python == env.DEFAULT_PYTHON uses: actions/upload-artifact@v3 with: name: package path: dist - name: Upload docs build if: matrix.task.name == 'Docs' && matrix.python == env.DEFAULT_PYTHON uses: actions/upload-artifact@v3 with: name: docs path: docs/build - name: Clean up if: always() run: | . .venv/bin/activate pip uninstall -y ai2-tango gpu_tests: name: GPU Tests runs-on: ubuntu-latest steps: - name: Determine current commit SHA (pull request) if: github.event_name == 'pull_request' run: | echo "COMMIT_SHA=${{ github.event.pull_request.head.sha }}" >> $GITHUB_ENV - name: Determine current commit SHA (push) if: github.event_name != 'pull_request' run: | echo "COMMIT_SHA=$GITHUB_SHA" >> $GITHUB_ENV - name: GPU Tests uses: allenai/beaker-run-action@v1.2 with: spec: | version: v2 description: GPU Tests budget: ai2/oe-training tasks: - name: tests image: beaker: ${{ env.BEAKER_IMAGE }} context: preemptible: true resources: gpuCount: 2 envVars: - name: COMMIT_SHA value: ${{ env.COMMIT_SHA }} command: ["/entrypoint.sh", "pytest", "-v", "-m", "gpu", "tests/"] result: path: /unused token: ${{ secrets.BEAKER_TOKEN }} workspace: ${{ env.BEAKER_WORKSPACE }} clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale release: name: Release runs-on: ubuntu-latest needs: [gpu_tests, checks] if: startsWith(github.ref, 'refs/tags/') steps: - uses: actions/checkout@v3 with: fetch-depth: 0 - name: Setup Python uses: actions/setup-python@v4 with: python-version: ${{ env.DEFAULT_PYTHON }} - name: Install 
requirements run: | pip install -e .[dev] - name: Prepare environment run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/tags/v}" >> $GITHUB_ENV echo "TAG=${GITHUB_REF#refs/tags/}" >> $GITHUB_ENV - name: Download package distribution files uses: actions/download-artifact@v3 with: name: package path: dist - name: Generate release notes run: | python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md - name: Publish package to PyPI run: | twine upload -u __token__ -p ${{ secrets.PYPI_PASSWORD }} dist/* - name: Publish GitHub release uses: softprops/action-gh-release@v1 with: body_path: ${{ github.workspace }}-RELEASE_NOTES.md prerelease: ${{ contains(env.TAG, 'rc') }} files: | dist/* ================================================ FILE: .github/workflows/update_dependency_pr.yml ================================================ name: Update dependency PR on: pull_request: types: - opened paths: - "pyproject.toml" permissions: pull-requests: write jobs: torch: name: torch runs-on: ubuntu-latest if: startsWith(github.head_ref, 'dependabot/pip/torch-') steps: - uses: actions/github-script@v6 with: script: | github.rest.issues.createComment({ issue_number: context.issue.number, owner: context.repo.owner, repo: context.repo.repo, body: 'Hello! 
This is a [PyTorch](https://pytorch.org/) upgrade, which means you will also need to update:\n- [ ] The base image in `Dockerfile`\n- [ ] The base image in `Dockerfile.test`\n- [ ] The torch version hard-coded in `.github/workflows/main.yml`' }) ================================================ FILE: .gitignore ================================================ # build artifacts .eggs/ .mypy_cache ai2_tango.egg-info/ build/ dist/ pip-wheel-metadata/ runs/ workspace/ # dev tools .envrc .python-version .idea .venv/ .vscode/ /*.iml # jupyter notebooks .ipynb_checkpoints # miscellaneous .cache/ doc/_build/ *.swp .DS_Store # python *.pyc *.pyo __pycache__ # testing and continuous integration .coverage .pytest_cache/ .benchmarks # documentation build artifacts docs/build site/ # internal experiment configs *-internal.jsonnet *-internal.json *-internal.yaml *-internal.yml ================================================ FILE: .readthedocs.yaml ================================================ version: 2 sphinx: configuration: docs/source/conf.py fail_on_warning: true build: os: ubuntu-22.04 tools: python: "3.10" python: install: - method: pip path: . extra_requirements: - dev - all ================================================ FILE: CHANGELOG.md ================================================ # Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased ### Fixed - Fixed a bunch of dependencies - Upgraded to new version of wandb ## [v1.3.2](https://github.com/allenai/tango/releases/tag/v1.3.2) - 2023-10-27 ### Fixed - Fix issues with gcloud auth in beaker executor. ## [v1.3.1](https://github.com/allenai/tango/releases/tag/v1.3.1) - 2023-10-25 ### Fixed - Minor bugs in the `GSWorkspace()`. ### Changed - Added CLI-style execution functions for experiments defined in Python. 
- Added `display()` to `ExecutorOutput` for producing a table that summarizes the run. ## [v1.3.0](https://github.com/allenai/tango/releases/tag/v1.3.0) - 2023-10-13 ### Added - Added the `Workspace.remove_step()` method to safely remove steps. - The `GSWorkspace()` can now be initialized with google cloud bucket subfolders. ### Changed - The `BeakerExecutor` now uses the HEAD commit at the time the executor is instantiated to execute a step instead of the HEAD commit at the time the step is run. ### Fixed - Removed unnecessary code coverage dev requirements. - Fixed issue where new version of torch caused no LR schedulers to be registered. - Updated pinned versions of jax, jaxlib, and flax. ## [v1.2.1](https://github.com/allenai/tango/releases/tag/v1.2.1) - 2023-04-06 ### Added - Added the following workspace methods to support the Tango viz UI: `Workspace.search_registered_runs()`, `Workspace.search_step_info()`, `Workspace.num_registered_runs()`, and `Workspace.num_steps()`. ### Fixed - Fixes a bug where `FromParams` would fail to parse when an object takes a `Step` argument directly. - Changed a name so we don't override the built-in name `set`. - Fixed a bug that would cause O(n^2) memory consumption in dense step graphs. ## [v1.2.0](https://github.com/allenai/tango/releases/tag/v1.2.0) - 2023-02-10 ### Added - You can now add arguments to steps without invalidating the cache. See `Step.SKIP_DEFAULT_ARGUMENTS`. - Fixed integration status messages in `tango info` command. - Added abstractions for `RemoteClient`, `RemoteStepCache`, and `RemoteWorkspace`. - Added a GS integration that comes with `GSWorkspace`, a remote `Workspace` implementation that uses google cloud storage. - You can now bind functional steps to the underlying `Step` instance with `@step(bind=True)`, meaning the first argument to the function will be a `Step`. - Added `ShellStep` for running arbitrary shell commands. 
- Added `@make_registrable` decorator to make arbitrary functions registrable, to make it easier to refer to them in tango configurations. ### Fixed - Jsonnet parsing is now much faster and works on Windows. - Warnings about locks are now reliably printed every 30 seconds. - We now make sure Beaker jobs have the latest version of beaker-py, so that we're compatible with the latest API changes. - Stopping early now works when the metric doesn't change at all. - Fixed bug with `FromParams` which didn't handle variable length tuples correctly. ### Changed - The default log level for Tango is now `warning`. - You can specify multiple steps with `-s` from the `tango run` command. ## [v1.1.0](https://github.com/allenai/tango/releases/tag/v1.1.0) - 2022-12-01 ### Added - Added `gpu_type` field to `StepResources`. The `BeakerExecutor` can use this to determine which clusters to submit a step to. - Added `machine` field to `StepResources`. You can set this to "local" when using the `BeakerExecutor` to force it to run the step locally. - Added `--ext-var` argument to `tango run` for setting JSONNET external variables when loading the experiment config. - Added `@step()` decorator to create `Step` classes from functions. - Added the `transformers::with_soft_prompt` integration, to make soft-prompted prefix transformers easy. ### Removed - Removed PyTorch Lightning integration. - Removed `tango server` command and `--serve/--no-serve` option for `tango run`. - Removed `source_release.py`, which was checked in by accident. ### Fixed - Fixed issue where Executor `parallelism` option in a Tango settings file would be ignored. - Fixed a bug where the unique ID of a step that depends on a key-value of the result of another step could change if the name of the other step changes. - Fixed a bug where importing certain libraries (like torchmetrics) would mess with our exception handling because they set `sys.excepthook` for some reason. 
Now we always reset `sys.excepthook` after importing. - The type hints for the flax trainer suggested that the training split is optional when in fact it's mandatory. - Made `BeakerWorkspace` / `BeakerStepLock` more robust when a job is preempted. - Minor performance improvements for the Beaker executor and workspace. - Fixed bug with `step_extra_dependencies` where uncacheable dependencies wouldn't be run. ## [v1.0.2](https://github.com/allenai/tango/releases/tag/v1.0.2) - 2022-11-14 ### Changed - `BeakerScheduler` can now return a list of clusters. ## [v1.0.1](https://github.com/allenai/tango/releases/tag/v1.0.1) - 2022-10-20 ### Fixed - `LightningTrainStep` now can take a `Lazy` model object which results in a guaranteed deterministic hash. - Fixed issue where remote `Workspace` implementations like `WandbWorkspace` and `BeakerWorkspace` would use the same local cache regardless of the W&B / Beaker workspace being used. - Fixed bug with `TorchEvalStep` when constructing callbacks. - Fixed some import error issues caused when an integration is not installed. - Fix incorrect reporting of final results in `MulticoreExecutor`. ### Changed - Wandb step cache retries api call in case of timeout - `beaker-py >= 1.11` required. ## [v1.0.0](https://github.com/allenai/tango/releases/tag/v1.0.0) - 2022-10-05 ### Added - Added `step_extra_dependencies` input field to `Step` class that can be used to force a dependency on another step even if the current step doesn't directly depend on the output of the other step. See [#418](https://github.com/allenai/tango/issues/418) for more context. ### Changed - `beaker-py >= 1.10` required. ### Fixed - Long log lines will be soft-wrapped to ensure that links are clickable. - Fixed a bug where some workspaces could be left in a bad state if a step's `Format` failed to serialize the step's result in `Workspace.step_finished()`. - Sometimes functions and methods end up as arguments to steps, which means we have to hash them. 
Instead of taking a hash of the function, we now take a hash of the function's module and name. - Fixed a bug with the Beaker executor where it would hang at the end of a run if a step failed that is a dependency of another step. - Fixed tests to work with new version of transformers. - Fixed `Executor.execute_sub_graph_for_step()` to be able to run the step's dependencies in parallel. ## [v0.14.0](https://github.com/allenai/tango/releases/tag/v0.14.0) - 2022-09-20 ### Added - Adds a function to modify a Hugging Face transformer with IA3 adaptors - Added a `BeakerScheduler` registrable class, specified as the argument `scheduler` to `BeakerExecutor`, which controls the resources assigned to steps run on Beaker. Users can implement their own `BeakerScheduler` subclasses to customize the resource assignment behavior. ### Changed - In the `tango run` command, `--no-server` is now the default. Use `--server` to start the server. ### Fixed - Made `BeakerExecutor` more robust to connection, timeout, SSL, and other recoverable HTTP errors. - Made the `BeakerStepLock` more robust, and as a result `BeakerWorkspace` is more robust and should require less manual intervention for locks in a bad state. - Fixed a bug with the internal scheduling logic of the `BeakerExecutor` which could delay submitting some steps in parallel. - Fixed a bug where creating a `StepInfo` object from params might result in unnecessary imports. - Fixed a bug where canceling the Beaker executor might not work properly. - Fixed a bug where the trainer trains too much when `train_epochs` is set and you're using gradient accumulation. - Fixed a bug where included modules might not be found when using multiprocessing when they're not on `sys.path` / `PYTHONPATH`. - Fixed how the results of uncacheable steps are displayed by `tango run`. - Beaker executor won't run duplicate cacheable steps at the same time. 
## [v0.13.0](https://github.com/allenai/tango/releases/tag/v0.13.0) - 2022-09-07 ### Added - You can now reference into a particular index of the result of another step in a config. For example: `{type: "ref", ref: "some_previous_step", key: 0}`. The key field can be an integer if the result of the referenced step is a list or tuple, or a string if the result of the referenced step is a dictionary. - Added `priority` parameter to Beaker executor for setting the default task priority for Beaker jobs. - Added `Workspace.step_result()` method for getting a step's result from the latest run. - `tango run` will now display a URL to the logs for failed steps when you use the `BeakerExecutor`. ### Changed - The `TorchTrainStep` now enables monitoring arbitrary model outputs during training. `TorchTrainEngine.forward_train` now returns a tuple `loss, model_outputs` for each micro batch and the list of model outputs for all micro batches in a batch is passed to the `TrainCallback.log_batch` and `TrainCallback.post_batch`. - Tango will now automatically search Python modules in the current working directory for registered classes so that you don't always need to use the `--include-package` setting. - The minimum supported Python version is now 3.8. - Added support for PyTorch Lightning 1.7.x - The Beaker Executor will no longer live-stream logs from Beaker jobs, but logs will be viewable on Beaker and more readable. - Only the Beaker executor requires a clean working directory ### Fixed - Fixed a bug that did not allow a wandb artifact's type to be set from a step's metadata dictionary. - Fixed a bug with how the Beaker executor streams log lines from Beaker which sometimes resulted in messages missing some starting characters, and tqdm lines being duplicated. - Fixed a bug in the Beaker workspace where the lock dataset wouldn't be removed if the step was found to be in an invalid state. 
- Improved cluster choice logic in `BeakerExecutor` to ensure greater diversity of clusters when submitting many steps at once. - Fixed bug where sub-processes of the multicore executor would use the wrong executor if `executor` was defined in a `tango.yml` file. - Deterministic hashes for numpy and torch tensors were not deterministic. Now they are. ## [v0.12.0](https://github.com/allenai/tango/releases/tag/v0.12.0) - 2022-08-23 ### Added - **Step resources:** - Added a `step_resources` parameter to the `Step` class which should be used to describe the computational resources required to run a step. `Executor` implementations can use this information. For example, if your step needs 2 GPUs, you should set `step_resources=StepResources(gpu_count=2)` (`"step_resources": {"gpu_count": 2}` in the configuration language). - Added a `Step.resources()` property method. By default this returns the value specified by the `step_resources` parameter. If your step implementation always requires the same resources, you can just override this method so you don't have to provide the `step_resources` parameter. - **Step execution:** - Added an `executor` field to the `tango.yml` settings. You can use this to define the executor you want to use by default. - Added a Beaker `Executor` to the Beaker integration, registered as an `Executor` with the name "beaker". To use this executor, add these lines to your `tango.yml` file: ```yaml executor: type: beaker beaker_workspace: ai2/my-workspace clusters: - ai2/general-cirrascale ``` See the docs for the `BeakerExecutor` for more information on the input parameters. - **Step class:** - Added a metadata field to the step class API. This can be set through the class variable `METADATA` or through the constructor argument `step_metadata`. - **Weights & Biases integration:** - You can now change the artifact kind for step result artifacts by adding a field called "artifact_kind" to a step's metadata. 
For models, setting "artifact_kind" to "model" will add the corresponding artifact to W&B's new model zoo. ### Changed - **CLI:** - The `tango run` command will throw an error if you have uncommitted changes in your repository, unless you use the `--allow-dirty` flag. - The `tango run` command will use the lightweight base executor (single process) by default. To use the multi-process executor, set `-j/--parallelism` to 1 or higher or -1 to use all available CPU cores. ### Fixed - Fixed bug where `StepInfo` environment and platform metadata could be out-of-date if a step is run again due to failure. - Fixed a bug where an unfortunate combination of early stopping and decreasing model performance could result in a crash in the torch trainer. ## [v0.11.0](https://github.com/allenai/tango/releases/tag/v0.11.0) - 2022-08-04 ### Added - Added a [Flax](https://flax.readthedocs.io/en/latest/) integration along with an example config. ## [v0.10.1](https://github.com/allenai/tango/releases/tag/v0.10.1) - 2022-07-26 ### Fixed - Fixed issue where the StepInfo config argument could be parsed into a Step. - Restored capability to run tests out-of-tree. ## [v0.10.0](https://github.com/allenai/tango/releases/tag/v0.10.0) - 2022-07-07 ### Changed - Renamed `workspace` parameter of `BeakerWorkspace` class to `beaker_workspace`. - `Executor` class is now a `Registrable` base class. `MulticoreExecutor` is registered as "multicore". ### Removed - Removed `StepExecutionMetadata`. Its fields have been absorbed into `StepInfo`. ### Fixed - Improved `Step.ensure_result()` such that the step's result doesn't have to be read from the cache. - Fixed an issue with the output from `MulticoreExecutor` such that it's now consistent with the default `Executor` for steps that were found in the cache. - One of our error messages referred to a configuration file that no longer exists. - Improved performance of `BeakerWorkspace`. 
### Added - Added the ability to train straight `Model` instead of just `Lazy[Model]` ## [v0.9.1](https://github.com/allenai/tango/releases/tag/v0.9.1) - 2022-06-24 ### Fixed - Fixed non-deterministic behavior in `TorchTrainStep`. - Fixed bug in `BeakerWorkspace` where `.step_info(step)` would raise a `KeyError` if the step hasn't been registered as part of a run yet. - Fixed a bug in `BeakerWorkspace` where it would send too many requests to the beaker service. - Fixed a bug where `WandbWorkspace.step_finished()` or `.step_failed()` would crash if called from a different process than `.step_starting()`. - Fixed a bug in `WandbWorkspace.step_finished()` which led to a `RuntimeError` sometimes while caching the result of a step. ## [v0.9.0](https://github.com/allenai/tango/releases/tag/v0.9.0) - 2022-06-01 ### Added - Added a [Beaker](https://beaker.org) integration that comes with `BeakerWorkspace`, a remote `Workspace` implementation that uses Beaker Datasets under the hood. - Added a `datasets::dataset_remix` step that provides the split remixing functionality of `tango.steps.dataset_remix.DatasetRemixStep` now for Huggingface `DatasetDict`. - Added a config and code example of Registrable to the First Step docs with edits for clarity. ### Changed - If you try to import something from a tango integration that is not fully installed due to missing dependencies, an `IntegrationMissingError` will be raised instead of `ModuleNotFound`. - You can now set `-j 0` in `tango run` to disable multicore execution altogether. ### Fixed - Improved how steps and workspaces handle race conditions when different processes are competing to execute the same step. This would result in a `RuntimeError` before with most workspaces, but now it's handled gracefully. - Fixed bug which caused GradScaler state to not be saved and loaded with checkpoints. 
## [v0.8.0](https://github.com/allenai/tango/releases/tag/v0.8.0) - 2022-05-19 ### Added - Added a Weights & Biases remote `Workspace` implementation: `WandbWorkspace`, registered as "wandb". This can be instantiated from a workspace URL in the form "wandb://entity/project". - Added a method `Workspace.step_result_for_run` which gives the result of a step given the run name and step name within that run. - Added property `Workspace.url`, which returns a URL for the workspace that can be used to instantiate the exact same workspace using `Workspace.from_url()`. Subclasses must implement this. ### Changed - `StepInfo` start and end times will always be in UTC now. - `WandbTrainCallback` now logs system metrics from each worker process in distributed training. - `StepCache.__contains__()` and `StepCache.__getitem__()` now accept either a `Step` or `StepInfo` as an argument (`Union[Step, StepInfo]`). - Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`. - `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures. ### Fixed - Fixed bug with `LocalWorkspace.from_parsed_url()` ([#278](https://github.com/allenai/tango/issues/278)). - Deprecation warnings will now be logged from `tango` CLI. - Fixed the text format in the case of serializing an iterator of string. - Added missing default value of `None` to `TangoGlobalSettings.find_or_default()`. - Mypy has become incompatible with transformers and datasets, so we have to disable the checks in some places. - The `VERSION` member of step arguments that were wrapped in `Lazy` were not respected. Now they are. ## [v0.7.0](https://github.com/allenai/tango/releases/tag/v0.7.0) - 2022-04-19 ### Added - Added the "-n/--name" option to `tango run`. This option allows the user to give the run an arbitrary name. 
- Added a convenience property `.workspace` to `Step` class that can be called from a step's `.run()` method to get the current `Workspace` being used. - Gave `FromParams` objects (which includes all `Registrable` objects) the ability to version themselves. - Added CLI option to run a single step in a config using `--step-name` or `-s`. - Added a `MultiCoreExecutor` that executes steps in parallel. - Added an `ExecutorOutput` dataclass that is returned by `Executor.execute_step_graph()`. - `StepGraph` now prints itself in a readable way. - Tango now automatically detects when it's running under a debugger, and disables multicore support accordingly. Many debuggers can't properly follow sub-processes, so this is a convenience for people who love debuggers. - Added more models to the stuff we can import from the transformers library. - Added new example for finetuning text-to-text models. ### Changed - Renamed `click_logger` to `cli_logger`, and we now use [rich](https://github.com/Textualize/rich)'s logging `Handler` as the default handler, which means prettier output, better tracebacks, and you can use rich's markup syntax with the `cli_logger` to easily add style to text. - Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`. - `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures. - Upgraded PyTorch version in `tango` Docker image to latest `v1.11.0+cu113`. - `RunGeneration` now allows model object as input. ### Fixed - Fixed bug that mistakenly disallowed fully-qualified names containing `"_"` (underscores) in the config. - Fixed bug where `TorchTrainStep` working directory would be left in an unrecoverable state if training failed after saving the final model weights. - Fixed bug in `FromParams` where `**kwargs` might be passed down to the constructors of arguments. - Fixed bug in the way dependencies are tracked between steps. 
- Fixed bug that caused `MulticoreExecutor` to hang in case of a failing step that was required recursively (not directly) downstream. - Fixed bug in the way dependencies are tracked between steps - Compatibility with PyTorch Lightning 1.6 ## [v0.6.0](https://github.com/allenai/tango/releases/tag/v0.6.0) - 2022-02-25 ### Added - New example that finetunes a pre-trained ResNet model on the Cats & Dogs dataset. - Added a '@requires_gpus' decorator for marking tests as needing GPUs. Tests marked with this will be run in the "GPU Tests" workflow on dual k80 GPUs via Beaker. - Added the "-w/--workspace" option to `tango run` and `tango server` commands. This option takes a path or URL, and instantiates the workspace from the URL using the newly added `Workspace.from_url()` method. - Added the "workspace" field to `TangoGlobalSettings`. - Added the "environment" field to `TangoGlobalSettings` for setting environment variables each time `tango` is run. - Added a utility function to get a `StepGraph` directly from a file. - Added `tango.settings` module and `tango settings` group of commands. - A format for storing sequences as `SqliteSparseSequence` - A way to massage kwargs before they determine the unique ID of a `Step` ### Changed - `local_workspace.ExecutorMetadata` renamed to `StepExecutionMetadata` and now saved as `execution-metadata.json`. - `tango run` without the option "-w/--workspace" or "-d/--workspace-dir" will now use a `MemoryWorkspace` instead of a `LocalWorkspace` in a temp directory, unless you've specified a default workspace in a `TangoGlobalSettings` file. - Moved `tango.workspace.MemoryWorkspace` and `tango.local_workspace.LocalWorkspace` to `tango.workspaces.*`. - Moved `tango.step_cache.MemoryStepCache` and `tango.step_cache.LocalStepCache` to `tango.step_caches.*`. - Deprecated the `-d/--workspace-dir` command-line option. Please use `-w/--workspace` instead. 
### Fixed - Fixed a small bug where `LocalWorkspace` would fail to capture the conda environment in our Docker image. - Fixed activation of `FILE_FRIENDLY_LOGGING` when set from the corresponding environment variable. - Fixed setting log level via the environment variable `TANGO_LOG_LEVEL`. - Use relative paths within the `work_dir` for symbolic links to the latest and the best checkpoints in `TorchTrainStep`. - Fixed some scenarios where Tango can hang after finishing all steps. - `distributed_port` and `log_every` parameters won't factor into `TorchTrainStep`'s unique ID. - `MappedSequence` now works with slicing. - `MappedSequence` now works with Huggingface `Dataset`. - Uncacheable steps are now visible in Tango UI. - Fixed bug in `Registrable.list_available()` where an error might be raised if the default implementation hadn't been explicitly imported. - Fixed issue where having a default argument to the `run()` method wasn't getting applied to the step's unique ID. ## [v0.5.0](https://github.com/allenai/tango/releases/tag/v0.5.0) - 2022-02-09 ### Added - Added `TrainingEngine` abstraction to torch integration. - Added [FairScale](https://fairscale.readthedocs.io/en/latest/) with a `FairScaleTrainingEngine` that leverages FairScale's `FullyShardedDataParallel`. This is meant to be used within the `TorchTrainStep`. - All PyTorch components (such as learning rate schedulers, optimizers, data collators, etc) from the transformers library are now registered under the corresponding class in the torch integration. For example, transformers `Adafactor` optimizer is registered as an `Optimizer` under the name "transformers::Adafactor". More details can be found in the documentation for the transformers integration. ### Changed - Various changes to the parameters of the `TorchTrainStep` due to the introduction of the `TrainingEngine` class. - Params logged as `DEBUG` level instead of `INFO` to reduce noise in logs. 
- The waiting message for `FileLock` is now clear about which file it's waiting for. - Added an easier way to get the default Tango global config - Most methods to `TorchTrainCallback` also take an `epoch` parameter now. - `WandbTrainCallback` now logs peak GPU memory occupied by PyTorch tensors per worker. This is useful because W&B's system metrics only display the total GPU memory reserved by PyTorch, which is always higher than the actual amount of GPU memory occupied by tensors. So these new metrics give a more accurate view into how much memory your training job is actually using. - Plain old Python functions can now be used in `Lazy` objects. - `LocalWorkspace` now creates a symlink to the outputs of the latest run. - Tango is now better at guessing when a step has died and should be re-run. - Tango is now more lenient about registering the same class under the same name twice. - When you use `dict` instead of `Dict` in your type annotations, you now get a legible error message. Same for `List`, `Tuple`, and `Set`. ### Fixed - Fixed a bug in `Registrable` and `FromParams` where registered function constructors would not properly construct arguments that were classes. - Fixed a bug in `FromParams` that would cause a crash when an argument to the constructor had the name `params`. - Made `FromParams` more efficient by only trying to parse the params as a `Step` when it looks like it actually could be a step. - Fixed bug where `Executor` would crash if `git` command could not be found. - Fixed bug where validation settings were not interpreted the right way by the torch trainer. - When you register the same name twice using `Registrable`, you get an error message. That error message now contains the correct class name. ## [v0.4.0](https://github.com/allenai/tango/releases/tag/v0.4.0) - 2022-01-27 ### Changed - Default log level is `WARNING` instead of `ERROR`. - The web UI now renders the step graph left-to-right. 
- The web UI now shows runs by date, with the most recent run at the top. - The web UI now shows steps in a color-coded way. - The `tango run` command now prints user-friendly paths if possible. - The `--include-package` flag now also accepts paths instead of module names. - `tango.common.sqlite_sparse_sequence.SqliteSparseSequence` now lives at `tango.common.sequences.SqliteSparseSequence`. ### Fixed - Ensure tqdm log lines always make it into the log file `out.log` even when log level is `WARNING` or `ERROR`. - Numerous parts of Tango now have documentation when they didn't before. ## [v0.4.0rc5](https://github.com/allenai/tango/releases/tag/v0.4.0rc5) - 2022-01-19 ### Added - Added `TorchEvalStep` to torch integration, registered as "torch::eval". ### Changed - Renamed `aggregate_val_metric` to `auto_aggregate_val_metric` in `TorchTrainStep`. - `devices` parameter to `TorchTrainStep` replaced with `device_count: int`. - Run name printed at the end of a run so it's easier to find. - Type information added to package data. See [PEP 561](https://www.python.org/dev/peps/pep-0561) for more information. - A new integration, `transformers`, with two new steps for running seq2seq models. - Added `logging_tqdm`, if you don't want a progress bar, but you still want to see progress in the logs. - Added `threaded_generator()`, for wrapping generators so that they run in a separate thread from the generator's consumer. - Added a new example for evaluating the T0 model on XSum, a summarization task. - Added `MappedSequence` for functionally wrapping sequences. - Added `TextFormat`, in case you want to store the output of your steps in raw text instead of JSON. - Steps can now list arguments in `SKIP_ID_ARGUMENTS` to indicate that the argument should not affect a step's unique id. This is useful for arguments that affect the execution of a step, but not the output. - `Step` now implements `__str__`, so steps look pretty in the debugger. 
- Added `DatasetCombineStep`, a step that combines multiple datasets into one. - Added `common.logging.initialize_worker_logging()` function for configuring logging from worker processes/threads. - Logs from `tango run ...` will be written to a file called `out.log` in the run directory. ### Fixed - Fixed torch `StopEarlyCallback` state not being recovered properly on restarts. - Fixed file friendly logging by removing special styling characters. - Ensured exceptions captured in logs. - `LocalWorkspace` now works properly with uncacheable steps. - When a Tango run got killed hard, with `kill -9`, or because the machine lost power, `LocalWorkspace` would sometimes keep a step marked as "running", preventing further executions. This still happens sometimes, but it is now much less likely (and Tango gives you instructions for how to fix it). - To make all this happen, `LocalWorkspace` now saves step info in a Sqlite database. Unfortunately that means that the workspace format changes and existing workspace directories won't work properly with it. - Fixed premature cleanup of temporary directories when using `MemoryWorkspace` ## [v0.4.0rc4](https://github.com/allenai/tango/releases/tag/v0.4.0rc4) - 2021-12-20 ### Fixed - Fixed a bug where `StepInfo` fails to deserialize when `error` is an exception that can't be pickled. ## [v0.4.0rc3](https://github.com/allenai/tango/releases/tag/v0.4.0rc3) - 2021-12-15 ### Added - Added `DatasetsFormat` format and `LoadStreamingDataset` step to `datasets` integration. - `SqliteDictFormat` for datasets. - Added `pre_epoch()` and `post_epoch()` callback methods to PyTorch `TrainCallback`. ### Changed - `LoadDataset` step from `datasets` integration is now cacheable, using the `DatasetsFormat` format by default. But this only works with non-streaming datasets. For streaming datasets, you should use the `LoadStreamingDataset` step instead. 
### Fixed - Fixed bug where `KeyboardInterrupt` exceptions were not handled properly by steps and workspaces. - `WandbTrainCallback` now will use part of the step's unique ID as the name for the W&B run by default, to make it easier to identify which tango step corresponds to each run in W&B. - `WandbTrainCallback` will save the entire `TrainConfig` object to the W&B config. ## [v0.4.0rc2](https://github.com/allenai/tango/releases/tag/v0.4.0rc2) - 2021-12-13 ### Added - Sample experiment configurations that prove Euler's identity ### Changed - Loosened `Click` dependency to include v7.0. - Loosened `datasets` dependency. - Tightened `petname` dependency to exclude next major release for safety. ### Fixed - `Workspace`, `MemoryWorkspace`, and `LocalWorkspace` can now be imported directly from the `tango` base module. - Uncacheable leaf steps would never get executed. This is now fixed. - We were treating failed steps as if they were completed by accident. - The visualization had a problem with showing steps that never executed because a dependency failed. - Fixed a bug where `Lazy` inputs to a `Step` would fail to resolve arguments that come from the result of another step. - Fixed a bug in `TorchTrainStep` where some arguments for distributed training (`devices`, `distributed_port`) weren't being set properly. ## [v0.4.0rc1](https://github.com/allenai/tango/releases/tag/v0.4.0rc1) - 2021-11-30 ### Added - Introduced the concept of the `Workspace`, with `LocalWorkspace` and `MemoryWorkspace` as initial implementations. - Added a stub of a webserver that will be able to visualize runs as they happen. - Added separate classes for `LightningTrainingTypePlugin`, `LightningPrecisionPlugin`, `LightningClusterEnvironmentPlugin`, `LightningCheckpointPlugin` for compatibility with `pytorch-lightning>=1.5.0`. - Added a visualization of workspaces that can show step graphs while they're executing.
### Removed - Removed old `LightningPlugin` class - Removed requirement of the `overrides` package ### Changed - Made it possible to construct a step graph out of `Step` objects, instead of constructing it out of `StepStub` objects. - Removed dataset fingerprinting code, since we can now use `Step` to make sure things are cached. - Made steps deterministic by default. - Brought back `MemoryStepCache`, so we can run steps without configuring anything. - W&B `torch::TrainCallback` logs with `step=step+1` now so that training curves in the W&B dashboard match up with checkpoints saved locally and are easier to read (e.g. step 10000 instead of 9999). - `filelock >= 3.4` required, parameter `poll_intervall` to `tango.common.file_lock.FileLock.acquire` renamed to `poll_interval`. ### Fixed - Fixed bug in `FromParams` where a parameter to a `FromParams` class may not be instantiated correctly if it's a class with a generic type parameter. ## [v0.3.6](https://github.com/allenai/tango/releases/tag/v0.3.6) - 2021-11-12 ### Added - Added a `.log_batch()` method on `torch::TrainCallback` which is given the average loss across distributed workers, but only called every `log_every` steps. ### Removed - Removed `.pre_log_batch()` method on `torch::TrainCallback`. ### Fixed - Fixed typo in parameter name `remove_stale_checkpoints` in `TorchTrainStep` (previously was `remove_state_checkpoints`). - Fixed bug in `FromParams` that would cause failures when `from __future__ import annotations` was used with Python older than 3.10. See [PEP 563](https://www.python.org/dev/peps/pep-0563/) for details. ## [v0.3.5](https://github.com/allenai/tango/releases/tag/v0.3.5) - 2021-11-05 ### Fixed - Fixed a bug in `FromParams` where the "type" parameter was ignored in some cases where the `Registrable` base class did not directly inherit from `Registrable`. 
## [v0.3.4](https://github.com/allenai/tango/releases/tag/v0.3.4) - 2021-11-04 ### Added - Added `StopEarlyCallback`, a `torch::TrainCallback` for early stopping. - Added parameter `remove_stale_checkpoints` to `TorchTrainStep`. ### Changed - Minor changes to `torch::TrainCallback` interface. - Weights & Biases `torch::TrainCallback` now logs best validation metric score. ## [v0.3.3](https://github.com/allenai/tango/releases/tag/v0.3.3) - 2021-11-04 ### Added - Added support for PEP 604 in `FromParams`, i.e. writing union types as "X | Y" instead of "Union[X, Y]". - [internals] Added a spot for miscellaneous end-to-end integration tests (not to be confused with "tests of integrations") in `tests/end_to_end/`. - [internals] Core tests now run on all officially supported Python versions. ### Fixed - Fixed a bug in `FromParams` where non-`FromParams` class parameters were not instantiated properly (or at all). - Fixed a bug in `FromParams` where kwargs were not passed on from a wrapper class to the wrapped class. - Fixed small bug where some errors from git would be printed when executor metadata is created outside of a git repository. ## [v0.3.2](https://github.com/allenai/tango/releases/tag/v0.3.2) - 2021-11-01 ### Fixed - Fixed a bug with `FromParams` that caused `.from_params()` to fail when the params contained an object that was already instantiated. - tango command no longer installs a SIGTERM handler, which fixes some bugs with integrations that use multiprocessing. ## [v0.3.1](https://github.com/allenai/tango/releases/tag/v0.3.1) - 2021-10-29 ### Changed - Updated the `LightningTrainStep` to optionally take in a `LightningDataModule` as input. ## [v0.3.0](https://github.com/allenai/tango/releases/tag/v0.3.0) - 2021-10-28 ### Added - Added `IterableDatasetDict`, a version of `DatasetDict` for streaming-like datasets. - Added a [PyTorch Lightning](https://www.pytorchlightning.ai) integration with `LightningTrainStep`. 
### Fixed - Fixed bug with `FromParams` and `Lazy` where extra arguments would sometimes be passed down through to a `Lazy` class when they shouldn't. ## [v0.2.4](https://github.com/allenai/tango/releases/tag/v0.2.4) - 2021-10-22 ### Added - Added support for [torch 1.10.0](https://github.com/pytorch/pytorch/releases). ### Changed - `--file-friendly-logging` flag is now an option to the main `tango` command, so needs to be passed before `run`, e.g. `tango --file-friendly-logging run ...`. ### Fixed - Fixed bug with `Step.from_params`. - Ensure logging is initialized in spawn processes during distributed training with `TorchTrainStep`. ## [v0.2.3](https://github.com/allenai/tango/releases/tag/v0.2.3) - 2021-10-21 ### Added - Added support for global settings file, `tango.yml`. - Added 'include_package' (array of string) param to config spec. - Added a custom error `StopEarly` that a `TrainCallback` can raise within the `TorchTrainStep` to stop training early without crashing. - Added step config, tango command, and tango version to executor metadata. - Executor now also saves pip dependencies and conda environment files to the run directory for each step. ### Fixed - Ensured `**kwargs` arguments are logged in `FromParams`. ## [v0.2.2](https://github.com/allenai/tango/releases/tag/v0.2.2) - 2021-10-19 ### Added - Added new steps to `datasets` integration: `ConcatenateDatasets` ("datasets::concatenate") and `InterleaveDatasets` ("datasets::interleave"). - Added `__contains__` and `__iter__` methods on `DatasetDict` so that it is now a `Mapping` class. - Added `tango info` command that - among other things - displays which integrations are installed. ## [v0.2.1](https://github.com/allenai/tango/releases/tag/v0.2.1) - 2021-10-18 ### Added - Added `convert_to_tango_dataset_dict()` function in the `datasets` integration.
It's important for step caching purposes to use this to convert a HF `DatasetDict` to a native Tango `DatasetDict` when that `DatasetDict` is part of the input to another step. Otherwise the HF `DatasetDict` will have to be pickled to determine its hash. ### Changed - `Format.checksum()` is now an abstract method. Subclasses should only compute checksum on the serialized artifact and nothing else in the directory. - [internals] Changed the relationship between `Executor`, `StepCache`, and `Step.` `Executor` now owns the `StepCache`, and `Step` never interacts with `StepCache` directly. ## [v0.2.0](https://github.com/allenai/tango/releases/tag/v0.2.0) - 2021-10-15 ### Added - Added a [Weights & Biases](https://wandb.ai) integration with a training callback ("wandb::log") for `TorchTrainStep` ("torch::train") that logs training and validation metrics to W&B. ### Fixed - Fixed `Format.checksum()` when there is a symlink to a directory in the cache folder. ## [v0.1.3](https://github.com/allenai/tango/releases/tag/v0.1.3) - 2021-10-15 ### Added - Added the ability to track a metric other than "loss" for validation in `TorchTrainStep` ("torch::train"). ### Fixed - Final model returned from `TorchTrainStep` ("torch::train") will have best weights loaded. - Checkpoints are saved from `TorchTrainStep` ("torch::train") even when there is no validation loop. - Fixed `TorchTrainStep` ("torch::train") when `validation_split` is `None`. - Fixed distributed training with `TorchTrainStep` ("torch::train") on GPU devices. ## [v0.1.2](https://github.com/allenai/tango/releases/tag/v0.1.2) - 2021-10-13 ### Added - Added support for YAML configuration files. ## [v0.1.1](https://github.com/allenai/tango/releases/tag/v0.1.1) - 2021-10-12 ### Added - `TorchTrainStep` now displays a progress bar while saving a checkpoint to file. - The default executor now saves a "executor-metadata.json" file to the directory for each step. 
### Changed - Renamed `DirectoryStepCache` to `LocalStepCache` (registered as "local"). - `LocalStepCache` saves metadata to `cache-metadata.json` instead of `metadata.json`. ### Fixed - Fixed bug with `TorchTrainStep` during distributed training. - `FromParams` will automatically convert strings into `Path` types now when the annotation is `Path`. ## [v0.1.0](https://github.com/allenai/tango/releases/tag/v0.1.0) - 2021-10-11 ### Added - Added `StepGraph` and `Executor` abstractions. - Added a basic PyTorch training step registered as `"torch::train"`, along with other registrable components, such as `Model`, `DataLoader`, `Sampler`, `DataCollator`, `Optimizer`, and `LRScheduler`. - Added `DatasetRemixStep` in `tango.steps`. - Added module `tango.common.sequences`. - Added `DatasetDict` class in `tango.common.dataset_dict`. - Added [🤗 Datasets](https://github.com/huggingface/datasets) integration. - Added command-line options to set log level or disable logging completely. ### Changed - `Step.work_dir`, `Step.unique_id`, `Step.dependencies`, and `Step.recursive_dependencies` are now properties instead of methods. - `tango run` command will acquire a lock on the directory to avoid race conditions. - Integrations can now be installed with `pip install tango[INTEGRATION_NAME]`. For example, `pip install tango[torch]`. - Added method `Registrable.search_modules()` for automatically finding and importing the modules where a given ``name`` might be registered. - `FromParams.from_params()` and `Registrable.resolve_class_name` will now call `Registrable.search_modules()` to automatically import modules where the type might be defined. Thus for classes that are defined and registered within any `tango.*` submodules it is not necessary to explicitly import them. ### Fixed - `Step` implementations can now take arbitrary `**kwargs` in their `run()` methods. ## [v0.0.3](https://github.com/allenai/tango/releases/tag/v0.0.3) - 2021-09-27 ### Added - Added `tango` command.
## [v0.0.2](https://github.com/allenai/tango/releases/tag/v0.0.2) - 2021-09-27 ### Added - Ported over core tango components from AllenNLP. ## [v0.0.1](https://github.com/allenai/tango/releases/tag/v0.0.1) - 2021-09-22 ### Added - Added initial project boilerplate. ================================================ FILE: CITATION.cff ================================================ cff-version: 1.2.0 message: "If you use this software, please cite it as below." authors: - family-names: "Groeneveld" given-names: "Dirk" affiliation: "Allen Institute for Artificial Intelligence" - family-names: "Bhagia" given-names: "Akshita" affiliation: "Allen Institute for Artificial Intelligence" - family-names: "Walsh" given-names: "Pete" affiliation: "Allen Institute for Artificial Intelligence" title: "AI2 Tango" abstract: "Organize your experiments into discrete steps that can be cached and reused throughout the lifetime of your research project." version: "1.3.2" repository-code: "https://github.com/allenai/tango" license: "Apache-2.0" date-released: "2023-10-27" repository-code: "https://github.com/allenai/tango" ================================================ FILE: Dockerfile ================================================ # This Dockerfile can be used to build a Docker image suitable for tango projects. ARG BASE_IMAGE=ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10 FROM ${BASE_IMAGE} WORKDIR /stage COPY . . RUN /opt/conda/bin/pip install --no-cache-dir .[all] WORKDIR /workspace RUN rm -rf /stage/ ENTRYPOINT ["/opt/conda/bin/tango"] ================================================ FILE: Dockerfile.test ================================================ # This Dockerfile is for building an image suitable for running tango's GPU tests and integration tests. # There are no instruction lines in this Dockerfile that install tango. Instead, the entrypoint # script handles installing tango from a particular commit at runtime, based on the environment # variable "COMMIT_SHA". 
That way we don't need to rebuild and push the image each time we run # tests, and we can be sure the dependencies are always up-to-date. FROM ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10 COPY scripts/entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh WORKDIR /testing ENTRYPOINT ["/entrypoint.sh"] ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: Makefile ================================================ .PHONY : docs docs : rm -rf docs/build/ sphinx-autobuild -b html --watch tango/ --watch examples/ docs/source/ docs/build/ .PHONY : run-checks run-checks : isort --check . black --check . ruff check . mypy --check-untyped-defs . CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/ CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch CUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers ================================================ FILE: README.md ================================================



AI2 Tango replaces messy directories and spreadsheets full of file versions by organizing experiments into discrete steps that can be cached and reused throughout the lifetime of a research project.


CI PyPI Documentation Status License
## Quick links - [Documentation](https://ai2-tango.readthedocs.io/) - [PyPI Package](https://pypi.org/project/ai2-tango/) - [Contributing](https://github.com/allenai/tango/blob/main/CONTRIBUTING.md) - [License](https://github.com/allenai/tango/blob/main/LICENSE) ## In this README - [Quick start](#quick-start) - [Installation](#installation) - [Installing with PIP](#installing-with-pip) - [Installing with Conda](#installing-with-conda) - [Installing from source](#installing-from-source) - [Checking your installation](#checking-your-installation) - [Docker image](#docker-image) - [FAQ](#faq) - [Team](#team) - [License](#license) ## Quick start Create a Tango step: ```python # hello.py from tango import step @step() def hello(name: str) -> str: message = f"Hello, {name}!" print(message) return message ``` And create a corresponding experiment configuration file: ```jsonnet // hello.jsonnet { steps: { hello: { type: "hello", name: "World", } } } ``` Then run the experiment using a local workspace to cache the result: ```bash tango run hello.jsonnet -w /tmp/workspace ``` You'll see something like this in the output: ``` Starting new run expert-llama ● Starting step "hello"... Hello, World! ✓ Finished step "hello" ✓ Finished run expert-llama ``` If you run this a second time the output will now look like this: ``` Starting new run open-crab ✓ Found output for step "hello" in cache... ✓ Finished run open-crab ``` You won't see "Hello, World!" this time because the result of the step was found in the cache, so it wasn't run again. For a more detailed introduction check out the [First Steps](https://ai2-tango.readthedocs.io/en/latest/first_steps.html) walk-through. ## Installation **ai2-tango** requires Python 3.8 or later. ### Installing with `pip` **ai2-tango** is available [on PyPI](https://pypi.org/project/ai2-tango/). 
Just run ```bash pip install ai2-tango ``` To install with a specific integration, such as `torch` for example, run ```bash pip install 'ai2-tango[torch]' ``` To install with all integrations, run ```bash pip install 'ai2-tango[all]' ``` ### Installing with `conda` **ai2-tango** is available on conda-forge. You can install just the base package with ```bash conda install tango -c conda-forge ``` You can pick and choose from the integrations with one of these: ```bash conda install tango-datasets -c conda-forge conda install tango-torch -c conda-forge conda install tango-wandb -c conda-forge ``` You can also install everything: ```bash conda install tango-all -c conda-forge ``` Even though **ai2-tango** itself is quite small, installing everything will pull in a lot of dependencies. Don't be surprised if this takes a while! ### Installing from source To install **ai2-tango** from source, first clone [the repository](https://github.com/allenai/tango): ```bash git clone https://github.com/allenai/tango.git cd tango ``` Then run ```bash pip install -e '.[all]' ``` To install with only a specific integration, such as `torch` for example, run ```bash pip install -e '.[torch]' ``` Or to install just the base tango library, you can run ```bash pip install -e . ``` ### Checking your installation Run ```bash tango info ``` to check your installation. ### Docker image You can build a Docker image suitable for tango projects by using [the official Dockerfile](https://github.com/allenai/tango/blob/main/Dockerfile) as a starting point for your own Dockerfile, or you can simply use one of our [prebuilt images](https://github.com/allenai/tango/pkgs/container/tango) as a base image in your Dockerfile. For example: ```Dockerfile # Start from a prebuilt tango base image. # You can choose the right tag from the available options here: # https://github.com/allenai/tango/pkgs/container/tango/versions FROM ghcr.io/allenai/tango:cuda11.3 # Install your project's additional requirements. 
COPY requirements.txt . RUN /opt/conda/bin/pip install --no-cache-dir -r requirements.txt # Install source code. # This instruction copies EVERYTHING in the current directory (build context), # which may not be what you want. Consider using a ".dockerignore" file to # exclude files and directories that you don't want on the image. COPY . . ``` Make sure to choose the right base image for your use case depending on the version of tango you're using and the CUDA version that your host machine supports. You can see a list of all available image tags [on GitHub](https://github.com/allenai/tango/pkgs/container/tango/versions). ## FAQ ### Why is the library named Tango? The motivation behind this library is that we can make research easier by composing it into well-defined steps. What happens when you choreograph a number of steps together? Well, you get a dance. And since our [team's leader](https://nasmith.github.io/) is part of a tango band, "AI2 Tango" was an obvious choice! ### How can I debug my steps through the Tango CLI? You can run the `tango` command through [pdb](https://docs.python.org/3/library/pdb.html). For example: ```bash python -m pdb -m tango run config.jsonnet ``` ### How is Tango different from [Metaflow](https://metaflow.org), [Airflow](https://airflow.apache.org), or [redun](https://github.com/insitro/redun)? We've found that existing DAG execution engines like these tools are great for production workflows but not as well suited for messy, collaborative research projects where code is changing constantly. AI2 Tango was built *specifically* for these kinds of research projects. ### How does Tango's caching mechanism work? AI2 Tango caches the results of steps based on the `unique_id` of the step. The `unique_id` is essentially a hash of all of the inputs to the step along with: 1. the step class's fully qualified name, and 2. the step class's `VERSION` class variable (an arbitrary string). 
Unlike other workflow engines like [redun](https://github.com/insitro/redun), Tango does *not* take into account the source code of the class itself (other than its fully qualified name) because we've found that using a hash of the source code bytes is way too sensitive and less transparent for users. When you change the source code of your step in a meaningful way you can just manually change the `VERSION` class variable to indicate to Tango that the step has been updated. ## Team **ai2-tango** is developed and maintained by the AllenNLP team, backed by [the Allen Institute for Artificial Intelligence (AI2)](https://allenai.org/). AI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering. To learn more about who specifically contributed to this codebase, see [our contributors](https://github.com/allenai/tango/graphs/contributors) page. ## License **ai2-tango** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0). A full copy of the license can be found [on GitHub](https://github.com/allenai/tango/blob/main/LICENSE). ================================================ FILE: RELEASE_PROCESS.md ================================================ # GitHub Release Process ## Steps 1. Update the version in `tango/version.py`. 2. Run the release script: ```bash ./scripts/release.sh ``` This will automatically update the CHANGELOG, commit the changes to the CHANGELOG and `version.py` (and any other files you might have changed), and then create a new tag in git which will trigger a workflow on GitHub Actions that handles the rest. ## Fixing a failed release If for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete both the tag and corresponding release from GitHub. After you've pushed a fix, delete the tag from your local clone with ```bash git tag -l | xargs git tag -d && git fetch -t ``` Then repeat the steps above. 
================================================ FILE: docs/.gitignore ================================================ build ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line, and also # from the environment for the first two. SPHINXOPTS ?= SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=source set BUILDDIR=build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. 
echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: docs/source/_static/css/custom.css ================================================ ================================================ FILE: docs/source/api/commands.rst ================================================ Commands ======== .. automodule:: tango.__main__ ================================================ FILE: docs/source/api/components/executor.rst ================================================ Executor ======== Base class ---------- .. autoclass:: tango.executor.Executor :members: .. autoclass:: tango.executor.ExecutorOutput :members: .. autoclass:: tango.executor.ExecutionMetadata :members: ================================================ FILE: docs/source/api/components/format.rst ================================================ Format ====== Base class ---------- .. autoclass:: tango.format.Format :members: :private-members: Implementations --------------- .. automodule:: tango.format :members: :exclude-members: Format,read,write,checksum ================================================ FILE: docs/source/api/components/index.rst ================================================ Components ========== The core components of **AI2 Tango**. .. toctree:: :maxdepth: 2 :caption: Components step step_info step_graph workspace step_cache format executor ================================================ FILE: docs/source/api/components/step.rst ================================================ Step ==== Base class ---------- .. autoclass:: tango.step.Step :members: :special-members: :exclude-members: from_params .. autofunction:: tango.step.step .. autoclass:: tango.step.WithUnresolvedSteps :members: .. 
autoclass:: tango.step.StepResources :members: Implementations --------------- .. automodule:: tango.steps :members: ================================================ FILE: docs/source/api/components/step_cache.rst ================================================ StepCache ========= Base class ---------- .. autoclass:: tango.step_cache.StepCache :members: :special-members: Implementations --------------- .. autoclass:: tango.step_caches.LocalStepCache :members: .. autoclass:: tango.step_caches.MemoryStepCache Metadata -------- .. autoclass:: tango.step_cache.CacheMetadata :members: ================================================ FILE: docs/source/api/components/step_graph.rst ================================================ StepGraph ========= .. autoclass:: tango.step_graph.StepGraph :members: ================================================ FILE: docs/source/api/components/step_info.rst ================================================ StepInfo ======== .. autoclass:: tango.step_info.StepInfo :member-order: bysource :members: .. autoclass:: tango.step_info.StepState :member-order: bysource :members: .. autoclass:: tango.step_info.PlatformMetadata :member-order: bysource :members: .. autoclass:: tango.step_info.EnvironmentMetadata :member-order: bysource :members: .. autoclass:: tango.step_info.GitMetadata :member-order: bysource :members: .. autoclass:: tango.step_info.TangoMetadata :member-order: bysource :members: ================================================ FILE: docs/source/api/components/workspace.rst ================================================ Workspace ========= Base class ---------- .. autoclass:: tango.workspace.Workspace :members: Implementations --------------- .. autoclass:: tango.workspaces.LocalWorkspace .. autoclass:: tango.workspaces.MemoryWorkspace Metadata -------- .. autoclass:: tango.workspace.Run :members: .. autoclass:: tango.workspace.RunInfo :members: Miscellaneous ------------- .. autoclass:: tango.workspace.RunSort :members: .. 
autoclass:: tango.workspace.StepInfoSort :members: ================================================ FILE: docs/source/api/det_hash.rst ================================================ Deterministic Hashing ===================== In order to detect whether a :class:`~tango.step.Step` has to be re-run or not, Tango relies on some tools to compute deterministic hashes from the inputs to the :class:`~tango.step.Step`. The center-piece of this module is the :func:`~tango.common.det_hash.det_hash` function, which computes a deterministic hash of an arbitrary Python object. The other things in this module influence how that works in various ways. .. automodule:: tango.common.det_hash :members: ================================================ FILE: docs/source/api/exceptions.rst ================================================ Exceptions ========== .. autoexception:: tango.common.exceptions.TangoError :members: .. automodule:: tango.common.exceptions :members: :exclude-members: TangoError ================================================ FILE: docs/source/api/integrations/beaker.rst ================================================ 🧪 Beaker ========= .. automodule:: tango.integrations.beaker Reference --------- .. autoclass:: tango.integrations.beaker.BeakerWorkspace .. autoclass:: tango.integrations.beaker.BeakerStepCache .. autoclass:: tango.integrations.beaker.BeakerExecutor :members: DEFAULT_BEAKER_IMAGE .. autoclass:: tango.integrations.beaker.BeakerScheduler :members: .. autoclass:: tango.integrations.beaker.SimpleBeakerScheduler .. autoclass:: tango.integrations.beaker.ResourceAssignment :members: .. autoclass:: tango.integrations.beaker.ResourceAssignmentError ================================================ FILE: docs/source/api/integrations/datasets.rst ================================================ 🤗 Datasets =========== .. automodule:: tango.integrations.datasets Reference --------- .. autofunction:: tango.integrations.datasets.convert_to_tango_dataset_dict .. 
autoclass:: tango.integrations.datasets.DatasetsFormat .. autoclass:: tango.integrations.datasets.LoadDataset :members: .. autoclass:: tango.integrations.datasets.LoadStreamingDataset :members: .. autoclass:: tango.integrations.datasets.InterleaveDatasets :members: .. autoclass:: tango.integrations.datasets.ConcatenateDatasets :members: .. autoclass:: tango.integrations.datasets.DatasetRemixStep :members: ================================================ FILE: docs/source/api/integrations/fairscale.rst ================================================ 🔥 FairScale ============ .. automodule:: tango.integrations.fairscale Reference --------- .. autoclass:: tango.integrations.fairscale.FairScaleTrainingEngine .. autoclass:: tango.integrations.fairscale.FSDPConfig :members: .. autofunction:: tango.integrations.fairscale.with_wrapped_modules ================================================ FILE: docs/source/api/integrations/flax.rst ================================================ Flax ======= .. automodule:: tango.integrations.flax Reference --------- Train step ~~~~~~~~~~ .. autoclass:: tango.integrations.flax.FlaxTrainStep :members: .. autoclass:: tango.integrations.flax.TrainConfig :members: Eval step ~~~~~~~~~ .. autoclass:: tango.integrations.flax.FlaxEvalStep :members: Flax format ~~~~~~~~~~~~ .. autoclass:: tango.integrations.flax.FlaxFormat Model ~~~~~ .. autoclass:: tango.integrations.flax.Model :members: Optim ~~~~~ .. autoclass:: tango.integrations.flax.Optimizer :members: .. autoclass:: tango.integrations.flax.LRScheduler :members: Data ~~~~ .. autoclass:: tango.integrations.flax.DataLoader :members: .. autoclass:: tango.integrations.flax.FlaxDataLoader :members: Callbacks ~~~~~~~~~ .. autoclass:: tango.integrations.flax.TrainCallback :members: :member-order: bysource .. 
autoclass:: tango.integrations.flax.EvalCallback :members: :member-order: bysource ================================================ FILE: docs/source/api/integrations/gs.rst ================================================ ☁️ Google Cloud Storage ======================= .. automodule:: tango.integrations.gs Reference --------- .. autoclass:: tango.integrations.gs.GSWorkspace .. autoclass:: tango.integrations.gs.GSStepCache ================================================ FILE: docs/source/api/integrations/index.rst ================================================ Integrations ============ .. automodule:: tango.integrations .. toctree:: :maxdepth: 2 :caption: Integrations torch fairscale datasets transformers wandb beaker flax gs ================================================ FILE: docs/source/api/integrations/torch.rst ================================================ 🔥 PyTorch ========== .. automodule:: tango.integrations.torch Reference --------- Train step ~~~~~~~~~~ .. autoclass:: tango.integrations.torch.TorchTrainStep :members: .. autoclass:: tango.integrations.torch.TrainConfig :members: Eval step ~~~~~~~~~ .. autoclass:: tango.integrations.torch.TorchEvalStep :members: Torch format ~~~~~~~~~~~~ .. autoclass:: tango.integrations.torch.TorchFormat Model ~~~~~ .. autoclass:: tango.integrations.torch.Model :members: TrainingEngine ~~~~~~~~~~~~~~ .. autoclass:: tango.integrations.torch.TrainingEngine :members: .. autoclass:: tango.integrations.torch.TorchTrainingEngine Optim ~~~~~ .. autoclass:: tango.integrations.torch.Optimizer :members: .. autoclass:: tango.integrations.torch.LRScheduler :members: Data ~~~~ .. autoclass:: tango.integrations.torch.DataLoader :members: .. autoclass:: tango.integrations.torch.Sampler :members: .. autoclass:: tango.integrations.torch.DataCollator :members: :special-members: __call__ .. autoclass:: tango.integrations.torch.ConcatTensorDictsCollator :members: Callbacks ~~~~~~~~~ .. 
autoclass:: tango.integrations.torch.TrainCallback :members: :member-order: bysource .. autoclass:: tango.integrations.torch.EvalCallback :members: :member-order: bysource .. autoclass:: tango.integrations.torch.StopEarlyCallback .. autoclass:: tango.integrations.torch.StopEarly :members: ================================================ FILE: docs/source/api/integrations/transformers.rst ================================================ 🤗 Transformers =============== .. automodule:: tango.integrations.transformers :members: .. autofunction:: tango.integrations.transformers.ia3.modify_with_ia3 ================================================ FILE: docs/source/api/integrations/wandb.rst ================================================ ⚖️ Weights & Biases =================== .. automodule:: tango.integrations.wandb Reference --------- .. autoclass:: tango.integrations.wandb.WandbWorkspace .. autoclass:: tango.integrations.wandb.WandbStepCache .. autoclass:: tango.integrations.wandb.WandbTrainCallback .. autoclass:: tango.integrations.wandb.WandbFlaxTrainCallback ================================================ FILE: docs/source/api/logging.rst ================================================ Logging ======= .. automodule:: tango.common.logging Reference --------- .. autodata:: tango.common.logging.TANGO_LOG_LEVEL .. autodata:: tango.common.logging.FILE_FRIENDLY_LOGGING .. autodata:: tango.common.logging.cli_logger .. autofunction:: tango.common.logging.initialize_logging .. autofunction:: tango.common.logging.initialize_worker_logging .. autofunction:: tango.common.logging.initialize_prefix_logging .. autofunction:: tango.common.logging.teardown_logging .. autofunction:: tango.common.logging.file_handler ================================================ FILE: docs/source/api/sequences.rst ================================================ Sequences ========= This module contains some utilities to make sequences out of other sequences. 
All of these are lazy, so they take minimal time and memory when you create them. These work particularly well when used together. For example, you can concatenate two sequences (:class:`~tango.common.sequences.ConcatenatedSequence`), and then shuffle them (:class:`~tango.common.sequences.ShuffledSequence`). This module is not dependent on other Tango modules and can be used in isolation. .. automodule:: tango.common.sequences :members: ================================================ FILE: docs/source/api/settings.rst ================================================ Global settings --------------- Some command-line options can be set globally in a ``tango.yml`` or ``tango.yaml`` settings file. Tango will check the current directory and ``~/.config/``, in that order. The full spec of this config is defined by the :class:`~tango.settings.TangoGlobalSettings` class. .. autoclass:: tango.settings.TangoGlobalSettings :members: :exclude-members: path,find_or_default :member-order: bysource ================================================ FILE: docs/source/api/utilities.rst ================================================ Utilities ========= .. automodule:: tango.common :members: :exclude-members: det_hash ================================================ FILE: docs/source/conf.py ================================================ # Configuration file for the Sphinx documentation builder. # # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html import logging import os import sys from datetime import datetime # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
sys.path.insert(0, os.path.abspath("../../")) from tango.version import VERSION, VERSION_SHORT # noqa: E402 # -- Project information ----------------------------------------------------- project = "AI2 Tango" copyright = f"{datetime.today().year}, Allen Institute for Artificial Intelligence" author = "Allen Institute for Artificial Intelligence" version = VERSION_SHORT release = VERSION # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "myst_parser", "sphinx.ext.intersphinx", "sphinx.ext.viewcode", "sphinx.ext.doctest", "sphinx_copybutton", "sphinx_autodoc_typehints", ] suppress_warnings = ["myst.header"] # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build"] source_suffix = [".rst", ".md"] # -- Extension configuration ------------------------------------------------- intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "rich": ("https://rich.readthedocs.io/en/latest", None), "torch": ("https://pytorch.org/docs/stable", None), "flax": ("https://flax.readthedocs.io/en/latest", None), "fairscale": ("https://fairscale.readthedocs.io/en/latest/", None), "datasets": ("https://huggingface.co/docs/datasets/master/en", None), "transformers": ("https://huggingface.co/docs/transformers/master/en", None), "beaker": ("https://beaker-py.readthedocs.io/en/latest/", None), } # Tell myst-parser to assign header anchors for h1-h3. myst_heading_anchors = 3 # By default, sort documented members by type within classes and modules. 
autodoc_member_order = "groupwise" python_use_unqualified_type_names = True # Include default values when documenting parameter types. typehints_defaults = "comma" # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = "furo" html_title = f"ai2-tango v{VERSION}" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] html_css_files = ["css/custom.css"] html_favicon = "_static/favicon.ico" html_theme_options = { "light_css_variables": { "color-announcement-background": "#1B4596", "color-announcement-text": "#FFFFFF", }, "dark_css_variables": {}, "light_logo": "tango_final_squareish.png", "dark_logo": "tango_final_squareish.png", "footer_icons": [ { "name": "GitHub", "url": "https://github.com/allenai/tango", "html": """ """, # noqa: E501 "class": "", }, ], } # -- Hack to get rid of stupid warnings from sphinx_autodoc_typehints -------- class ShutupSphinxAutodocTypehintsFilter(logging.Filter): def filter(self, record: logging.LogRecord) -> bool: if "Cannot resolve forward reference" in record.msg: return False return True logging.getLogger("sphinx.sphinx_autodoc_typehints").addFilter(ShutupSphinxAutodocTypehintsFilter()) ================================================ FILE: docs/source/examples/euler.md ================================================ ```{include} ../../../examples/euler/README.md ``` ## Running the experiment If you haven't already, clone the [tango repository](https://github.com/allenai/tango) and then change directories into `examples/euler`. 
You can then run the experiment with: ```bash tango run euler_general.jsonnet -i complex_arithmetic -w workspace ``` This will leave its results in a subdirectory of `workspace/runs/` corresponding to the name of the run. The output it prints should look something like this: ``` Starting new run comic-heron Server started at http://localhost:8080/run/comic-heron [step i_times_pi] ● Starting step "i_times_pi"... [step i_times_pi] ✓ Finished step "i_times_pi" [step cos] ● Starting step "cos"... [step cos] ✓ Finished step "cos" [step sin] ● Starting step "sin"... [step sin] ✓ Finished step "sin" [step pow_e] ✓ Found output for step "i_times_pi" in cache (needed by "pow_e")... [step pow_e] ● Starting step "pow_e"... [step pow_e] ✓ Finished step "pow_e" [step i_times_sin] ✓ Found output for step "sin" in cache (needed by "i_times_sin")... [step i_times_sin] ● Starting step "i_times_sin"... [step i_times_sin] ✓ Finished step "i_times_sin" [step sum] ✓ Found output for step "cos" in cache (needed by "sum")... [step sum] ✓ Found output for step "i_times_sin" in cache (needed by "sum")... [step sum] ● Starting step "sum"... [step sum] ✓ Finished step "sum" [step sub] ✓ Found output for step "sum" in cache (needed by "sub")... [step sub] ✓ Found output for step "pow_e" in cache (needed by "sub")... [step sub] ● Starting step "sub"... [step sub] ✓ Finished step "sub" [step print] ✓ Found output for step "sub" in cache (needed by "print")... [step print] ● Starting step "print"... 
[step print] 0j [step print] ✓ Finished step "print" ✓ Finished run comic-heron ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Step Name ┃ Status ┃ Cached Result ┃ ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ cos │ ✓ succeeded │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │ │ i_times_pi │ ✓ succeeded │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae │ │ i_times_sin │ ✓ succeeded │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf │ │ pow_e │ ✓ succeeded │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │ │ print │ ✓ succeeded │ N/A │ │ sin │ ✓ succeeded │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │ │ sub │ ✓ succeeded │ workspace/cache/SubtractionStep-4ygj1UyLk6TCVBxN7DWTCccbMa7M1C5v │ │ sum │ ✓ succeeded │ workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP │ └─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘ ✓ 8 succeeded Use your workspace to get the cached result of a step, e.g. >>> from tango import Workspace >>> workspace = Workspace.from_url(...) >>> workspace.step_result_for_run("comic-heron", "sum") ``` A few things are of note here: 1. Tango assigns a name to your run. In this case, the name is "comic-heron". 2. In this configuration, the "print" step prints the output ("`0j`"). Most of the time though, you will look for the output in the output directories that are given in the table. 3. You might notice that the "print" step has no cached result (the table shows "N/A" for it). That's because it is uncacheable, and thus writes nothing to the cache. ## Change a step Let's make an update to a step! Open `complex_arithmetic.py` and change `AdditionStep`. The actual change you make in the `run()` method does not matter, but the important thing is to update the `VERSION` member of the `AdditionStep` class. 
`AdditionStep` does not yet have a `VERSION`, so we will give it one: ```Python @Step.register("cadd") class AdditionStep(Step): VERSION = "002" # This is the important change. def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex: # type: ignore return make_complex(a) + make_complex(b) ``` Now run the config again with ```bash tango run euler_general.jsonnet -i complex_arithmetic -w workspace ``` This time, the output will look like this: ``` Starting new run right-amoeba Server started at http://localhost:8080/run/right-amoeba [step sum] ✓ Found output for step "cos" in cache (needed by "sum")... [step sum] ✓ Found output for step "i_times_sin" in cache (needed by "sum")... [step sum] ● Starting step "sum"... [step sum] ✓ Finished step "sum" [step sub] ✓ Found output for step "sum" in cache (needed by "sub")... [step sub] ✓ Found output for step "pow_e" in cache (needed by "sub")... [step sub] ● Starting step "sub"... [step sub] ✓ Finished step "sub" [step print] ✓ Found output for step "sub" in cache (needed by "print")... [step print] ● Starting step "print"... 
[step print] 0j [step print] ✓ Finished step "print" ✓ Finished run right-amoeba ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Step Name ┃ Status ┃ Cached Result ┃ ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ cos │ - not run │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │ │ i_times_pi │ - not run │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae │ │ i_times_sin │ - not run │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf │ │ pow_e │ - not run │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │ │ print │ ✓ succeeded │ N/A │ │ sin │ - not run │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk │ │ sub │ ✓ succeeded │ workspace/cache/SubtractionStep-42mdcQBtrNAYvxYhmzdd1vj2uCG8N5Yf │ │ sum │ ✓ succeeded │ workspace/cache/AdditionStep-002-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP │ └─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘ ✓ 3 succeeded, 5 not run Use your workspace to get the cached result of a step, e.g. >>> from tango import Workspace >>> workspace = Workspace.from_url(...) >>> workspace.step_result_for_run("right-amoeba", "sum") ``` As you can see, it re-used the cached results for several of the steps, and only ran three steps anew. ```{eval-rst} :class:`tango.step.Step.VERSION` is just one of the ways in which you can change the behavior of a step. Head over to the documentation of the :class:`tango.step.Step` class to see the others. 
``` ================================================ FILE: docs/source/examples/eval_p3.md ================================================ ```{include} ../../../examples/eval_p3/README.md ``` ## `RougeScoreStep` `RougeScoreStep` is defined in `eval.py`: ```{literalinclude} ../../../examples/eval_p3/eval.py :language: py ``` ## Config The configuration file, `config.jsonnet`, uses some advanced [Jsonnet](https://jsonnet.org) concepts like `std.foldl` to create the same configuration for all 10 prompts: ```{literalinclude} ../../../examples/eval_p3/config.jsonnet ``` ## Run it You can run the experiment with: ```bash tango run config.jsonnet -i eval -d /tmp/workspace ``` ================================================ FILE: docs/source/examples/index.rst ================================================ Examples ======== Real-world examples of using Tango. You can find all of these `on GitHub <https://github.com/allenai/tango/tree/main/examples>`_ as well. .. toctree:: :maxdepth: 2 :caption: Examples euler train_lm eval_p3 ================================================ FILE: docs/source/examples/train_lm.md ================================================ # Fine-tuning a language model ```{include} ../../../examples/train_lm/README.md :start-after: :end-before: ``` ```{tip} You can find the full code for this example on [GitHub](https://github.com/allenai/tango/tree/main/examples/train_lm). ``` ## Components We'll need to write a step for tokenizing the data and preparing it for language model training. All of the other steps we need are provided by Tango integrations. So, create a file called `tokenize_step.py` with the following contents: ```{literalinclude} ../../../examples/train_lm/tokenize_step.py :language: py ``` ## Configuration file Next you'll need to create a configuration file that defines the experiment. 
Just copy over these contents into a file called `config.jsonnet`: ```{literalinclude} ../../../examples/train_lm/config.jsonnet ``` ## Run it Now we can run the experiment with: ```bash tango run config.jsonnet -i tokenize_step.py -d /tmp/results ``` ================================================ FILE: docs/source/faq.md ================================================ # FAQ ```{include} ../../README.md :start-after: :end-before: ``` ================================================ FILE: docs/source/first_steps.md ================================================ # First Steps ## What is a Step? Tango is a Python library for choreographing machine learning research experiments by executing a series of steps. A step can do anything, really, such as [prepare a dataset](tango.integrations.datasets.LoadDataset), [train a model](tango.integrations.torch.TorchTrainStep), send an email to your mother wishing her happy birthday, *etc*. Concretely, each step is just a subclass of {class}`~tango.step.Step`, where the {meth}`~tango.step.Step.run` method in particular defines what the step actually does. So anything that can be implemented in Python can be run as a step. Steps can also depend on other steps in that the output of one step can be part of the input to another step. Therefore, the steps that make up an experiment form a [directed graph](tango.step_graph.StepGraph). The concept of the {class}`~tango.step.Step` is the bread and butter that makes Tango so general and powerful. *So* powerful, in fact, that you might be wondering if Tango is [Turing-complete](https://en.wikipedia.org/wiki/Turing_completeness)? Well, we don't know yet, but we can say at least that Tango is **Tango-complete** 😉 ## Configuration files Experiments themselves are defined through JSON, [Jsonnet](https://jsonnet.org/), or YAML configuration files. 
At a minimum, these files must contain the "steps" field, which should be a mapping of arbitrary (yet unique) step names to the configuration of the corresponding step. For example, let's create a config file called `config.jsonnet` with the following contents: ```json { "steps": { "random_name": { "type": "random_choice", "choices": ["Turing", "Tango", "Larry"], }, "say_hello": { "type": "concat_strings", "string1": "Hello, ", "string2": { "type": "ref", "ref": "random_name" } }, "print": { "type": "print", "input": { "type": "ref", "ref": "say_hello" } } } } ``` *Can you guess what this experiment does?* There are three steps in this experiment graph: "random_name" is the name of one step, "say_hello" is the name of another, and "print" is the name of the last. The "type" parameter within the config of each step tells Tango which {class}`~tango.step.Step` class implementation to use for that step. So, within the "random_name" step config ```json "random_name": { "type": "random_choice", "choices": ["Turing", "Tango", "Larry"], } ``` the `"type": "random_choice"` part tells Tango to use the {class}`~tango.step.Step` subclass that is registered by the name "random_choice". But wait... what do we mean by *registered*? Tango keeps track of an internal registry for certain classes (such as the {class}`~tango.step.Step` class) that is just a mapping of arbitrary unique names to subclasses. When you look through Tango's source code, you'll see things like: ```python @Step.register("foo") class Foo(Step): ... ``` This is how subclasses get added to the registry. In this case the subclass `Foo` is added to the `Step` registry under the name "foo", so if you were to use `"type": "foo"` in your configuration file, Tango would understand that you mean to use the `Foo` class for the given step. ```{tip} Any class that inherits from {class}`~tango.common.registrable.Registrable` can have its own registry. ``` Now back to our example. 
The step classes referenced in our configuration file ("random_choice" and "concat_strings") don't actually exist in the Tango library (though the ["print" step](tango.steps.PrintStep) does), but we can easily implement and register them on our own. Let's put them in a file called `components.py`: ```python # file: components.py import random from typing import List from tango import Step @Step.register("random_choice") class RandomChoiceStep(Step): DETERMINISTIC = False def run(self, choices: List[str]) -> str: return random.choice(choices) @Step.register("concat_strings") class ConcatStringsStep(Step): def run(self, string1: str, string2: str) -> str: return string1 + string2 ``` ```{important} It's important that you use type hints in your code so that Tango can properly construct Python objects from the corresponding serialized (JSON) objects and warn you when the types don't match up. ``` So as long as Tango is able to import this module (`components.py`) these step implementations will be added to the registry and Tango will know how to instantiate and run them. There's also a short-hand way of implementing steps, using the {func}`@step() ` function decorator: ```python from tango import step @step(deterministic=False) def random_choice(choices: List[str]) -> str: return random.choice(choices) @step() def concat_strings(string1: str, string2: str) -> str: return string1 + string2 ``` This will register these steps under the name of the corresponding function, i.e. "random_choice" and "concat_strings", by default, though that can be overridden by specifying the "name" parameter to the decorator: ```python @step(name="random-string", deterministic=False) def random_choice(choices: List[str]) -> str: return random.choice(choices) ``` ## Executing an experiment At this point we've implemented our custom steps (`components.py`) and created our configuration file `config.jsonnet`, so we're ready to actually run this experiment. 
For that, just use the `tango run` command: ``` $ tango run config.jsonnet -i components ``` ```{tip} - The `-i` option is short for `--include-package`, which takes the name of a Python package which Tango will try to import. In this case our custom steps are in `components.py`, so we need Tango to import this module to find those steps. As long as `components.py` is in the current directory or somewhere else on the `PYTHONPATH`, Tango will be able to find and import this module when you pass `-i components` (note the lack of the `.py` at the end). ``` You should see something like this in the output: ``` Starting new run cute-kitten ● Starting step "random_name" ✓ Finished step "random_name" ● Starting step "say_hello" ✓ Finished step "say_hello" ● Starting step "print" Hello, Tango ✓ Finished step "print" ``` ## Step caching This particular experiment didn't write any results to disk, but in many situations you'll want to save the output of at least some of your steps. For example, if you're using the {class}`~tango.integrations.torch.TorchTrainStep` step, the output is a trained model, which is certainly a useful thing to keep around. In other cases, you may not actually care about the direct result of a particular step, but it could still be useful to save it when possible so that Tango doesn't need to run the step again unnecessarily. This is where Tango's caching mechanism comes in. To demonstrate this, let's look at another example that pretends to do some expensive computation. 
Here is the `config.jsonnet` file: ```json { "steps": { "add_numbers": { "type": "really_inefficient_addition", "num1": 34, "num2": 8 } } } ``` And let's implement "really_inefficient_addition": ```python # components.py import time from tango import Step, JsonFormat from tango.common import Tqdm @Step.register("really_inefficient_addition") class ReallyInefficientAdditionStep(Step): DETERMINISTIC = True CACHEABLE = True FORMAT = JsonFormat() def run(self, num1: int, num2: int) -> int: for _ in Tqdm.tqdm(range(100), desc="Computing...", total=100): time.sleep(0.05) return num1 + num2 ``` There are a couple of things to note about this step, other than the obvious inefficiencies; the class variables we've defined: {attr}`~tango.step.Step.DETERMINISTIC`, {attr}`~tango.step.Step.CACHEABLE`, and {attr}`~tango.step.Step.FORMAT`. `DETERMINISTIC = True` tells Tango that, given particular inputs, the output of this step will always be the same every time it is run, which has implications for caching. By default, Tango assumes steps are deterministic. You can override this by saying `DETERMINISTIC = False`. Tango will warn you when you try to cache a non-deterministic step. `CACHEABLE = True` tells Tango that it can cache this step and `FORMAT = JsonFormat()` defines which {class}`~tango.format.Format` Tango will use to serialize the result of the step. This time when we run the experiment we'll designate a specific directory for Tango to use: ```bash $ tango run config.jsonnet -i components -d workspace/ ``` ``` Starting new run live-tarpon ● Starting step "add_numbers" Computing...: 100%|##########| 100/100 [00:05<00:00, 18.99it/s] ✓ Finished step "add_numbers" ✓ The output for "add_numbers" is in workspace/runs/live-tarpon/add_numbers ``` The last line in the output tells us where we can find the result of our "add_numbers" step. `live-tarpon` is the name of the run. Run names are randomly generated and may be different on your machine. 
`add_numbers` is the name of the step in your config. The whole path is a symlink to a directory, which contains (among other things) a file `data.json`: ```bash $ cat workspace/runs/live-tarpon/add_numbers/data.json ``` ``` 42 ``` Now look what happens when we run this step again: ```bash $ tango run config.jsonnet -i components -d workspace/ ``` ``` Starting new run modest-shrimp ✓ Found output for "add_numbers" in cache ✓ The output for "add_numbers" is in workspace/runs/modest-shrimp/add_numbers ``` Tango didn't have to run our really inefficient addition step this time because it found the previous cached result. It put the results in the result directory for a different run (in our case, the `modest-shrimp` run), but once again it is a symlink that links to the same results from our first run. If we changed the inputs to the step in `config.jsonnet`: ```diff "add_numbers": { "type": "really_inefficient_addition", "num1": 34, - "num2": 8 + "num2": 2 } } } ``` And ran it again: ```bash $ tango run config.jsonnet -i components -d workspace/ ``` ``` Starting new run true-parrot ● Starting step "add_numbers" Computing...: 100%|##########| 100/100 [00:05<00:00, 19.13it/s] ✓ Finished step "add_numbers" ✓ The output for "add_numbers" is in workspace/runs/true-parrot/add_numbers ``` You'd see that Tango had to run our "add_numbers" step again. You may have noticed that `workspace/runs/true-parrot/add_numbers` is now a symlink that points to a different place than it did for the first two runs. That's because it produced a different result this time. All the result symlinks point into the `workspace/cache/` directory, where all the step's results are cached. This means that if we ran the experiment again with the original inputs, Tango would still find the cached result and wouldn't need to rerun the step. 
## Arbitrary objects as inputs ### `FromParams` So far the inputs to all of the steps in our examples have been built-in Python types that can be deserialized from JSON (e.g. {class}`int`, {class}`str`, etc.), but sometimes you need the input to a step to be an instance of an arbitrary Python class. Tango allows this as well, as it can infer from type hints what the class is and how to instantiate it. When writing your own classes, it's recommended that you have your class inherit from the {class}`~tango.common.from_params.FromParams` class, which will guarantee that Tango can instantiate it from a config file. For example, suppose we had a step like this: ```python from tango import Step from tango.common import FromParams class Bar(FromParams): def __init__(self, x: int) -> None: self.x = x @Step.register("foo") class FooStep(Step): def run(self, bar: Bar) -> int: return bar.x ``` ```{tip} If you've used [AllenNLP](https://github.com/allenai/allennlp) before, this will look familiar! In fact, it's the same system under the hood. ``` Then we could create a config like this: ```json { "steps": { "foo": { "type": "foo", "bar": {"x": 1} } } } ``` And Tango will figure out how to deserialize `{"x": 1}` into a `Bar` instance. You can also have `FromParams` objects nested within other `FromParams` objects or standard containers like {class}`list`: ```python from typing import List from tango import Step from tango.common import FromParams class Bar(FromParams): def __init__(self, x: int) -> None: self.x = x class Baz(FromParams): def __init__(self, bar: Bar) -> None: self.bar = bar @Step.register("foo") class FooStep(Step): def run(self, bars: List[Bar], baz: Baz) -> int: return sum([bar.x for bar in bars]) + baz.bar.x ``` ### `Registrable` The {class}`~tango.common.registrable.Registrable` class is a special kind of {class}`~tango.common.from_params.FromParams` class that allows you to specify from the config which subclass of an expected class to deserialize into. 
This is actually how we've been instantiating specific `Step` subclasses. Because {class}`~tango.step.Step` inherits from {class}`~tango.common.registrable.Registrable`, we can use the `"type"` fields in the config file to specify a `Step` subclass. This is also very useful when you're writing a step that requires a certain type as input, but you want to be able to change the exact subclass of the type from your config file. For example, the {class}`~tango.integrations.torch.TorchTrainStep` takes `Registrable` inputs such as {class}`~tango.integrations.torch.Model`. Model variants can then be subclasses that are specified in the config file by their registered names. A sketch of this might look like the following: ```python from tango import Step from tango.common import FromParams, Registrable class Model(torch.nn.Module, Registrable): ... @Model.register("variant1") class Variant1(Model): ... @Model.register("variant2") class Variant2(Model): ... @Step.register("torch::train") class TorchTrainerStep(Step): def run(self, model: Model, ...) -> Model: ... ``` And a sketch of the config file would be something like this: ```json { "steps": { "train": { "type": "torch::train", "model": { "type": "variant1", } } } } ``` As in the `FromParams` example, the specifications can be nested, but now we also denote the subclass with the `"type": "..."` field. To swap models we need only change "variant1" to "variant2" in the config. The value for "type" can either be the name that the class is registered under (e.g. "torch::train" for `TorchTrainStep`), or the fully qualified class name (e.g. `tango.integrations.torch.TorchTrainStep`). You'll see more examples of this in the [next section](examples/index). 
================================================ FILE: docs/source/index.md ================================================ # **AI2 Tango** ```{include} ../../README.md :start-after: :end-before: ``` ```{toctree} :maxdepth: 2 :hidden: :caption: Getting started installation first_steps examples/index faq ``` ```{toctree} :maxdepth: 2 :hidden: :caption: API Reference api/commands api/components/index api/integrations/index api/settings api/exceptions api/logging api/sequences api/det_hash api/utilities ``` ```{toctree} :hidden: :caption: Development CONTRIBUTING CHANGELOG License GitHub Repository ``` To learn about Tango in 5 minutes, head over to the [First Steps section](first_steps). If you'd rather learn from examples, check out the [Examples section](examples/index). ## Team ```{include} ../../README.md :start-after: :end-before: ``` ## License ```{include} ../../README.md :start-after: :end-before: ``` ## Indices and tables ```{eval-rst} * :ref:`genindex` * :ref:`modindex` ``` ================================================ FILE: docs/source/installation.md ================================================ Installation ============ ```{include} ../../README.md :start-after: :end-before: ``` ================================================ FILE: examples/euler/README.md ================================================ Euler ===== This is a toy example that proves Euler's identity using Tango. You can use this to play with the concept of a `Step` and see how Tango runs things without getting distracted by the details of what you're running. 
================================================ FILE: examples/euler/complex_arithmetic.py ================================================ import cmath from typing import Tuple, Union from tango import Step ComplexOrTuple = Union[complex, Tuple[float, float]] def make_complex(x: ComplexOrTuple) -> complex: if isinstance(x, complex): return x elif isinstance(x, (int, float)): return complex(x) else: return complex(*x) @Step.register("cadd") class AdditionStep(Step): def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex: # type: ignore return make_complex(a) + make_complex(b) @Step.register("csub") class SubtractionStep(Step): def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex: # type: ignore return make_complex(a) - make_complex(b) @Step.register("cexp") class ExponentiateStep(Step): def run(self, x: ComplexOrTuple, base: ComplexOrTuple = cmath.e) -> complex: # type: ignore return make_complex(base) ** make_complex(x) @Step.register("cmul") class MultiplyStep(Step): def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex: # type: ignore return make_complex(a) * make_complex(b) @Step.register("csin") class SineStep(Step): def run(self, x: ComplexOrTuple) -> complex: # type: ignore return cmath.sin(make_complex(x)) @Step.register("ccos") class CosineStep(Step): def run(self, x: ComplexOrTuple) -> complex: # type: ignore return cmath.cos(make_complex(x)) ================================================ FILE: examples/euler/euler.jsonnet ================================================ local i = [0.0, 1.0]; local pi = [3.1415926535, 0.0]; { "steps": { "i_times_pi": { "type": "cmul", "a": i, "b": pi }, "pow_e": { "type": "cexp", "x": { "type": "ref", "ref": "i_times_pi" } }, "plus_one": { "type": "cadd", "a": { "type": "ref", "ref": "pow_e" }, "b": [1, 0] }, "print": { "type": "print", "input": { "type": "ref", "ref": "plus_one" } } } } ================================================ FILE: examples/euler/euler_general.jsonnet 
================================================ local i = [0.0, 1.0]; local pi = [3.1415926535, 0.0]; { "steps": { "cos": { "type": "ccos", "x": pi }, "sin": { "type": "csin", "x": pi }, "i_times_sin": { "type": "cmul", "a": i, "b": { "type": "ref", "ref": "sin" } }, "sum": { "type": "cadd", "a": { "type": "ref", "ref": "cos" }, "b": { "type": "ref", "ref": "i_times_sin" }, }, "i_times_pi": { "type": "cmul", "a": i, "b": pi }, "pow_e": { "type": "cexp", "x": { "type": "ref", "ref": "i_times_pi" } }, "sub": { "type": "csub", "a": { "type": "ref", "ref": "sum" }, "b": { "type": "ref", "ref": "pow_e" }, }, "print": { "type": "print", "input": { "type": "ref", "ref": "sub" } } } } ================================================ FILE: examples/euler/run.sh ================================================ #!/bin/bash tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic ================================================ FILE: examples/eval_p3/README.md ================================================ # Evaluating T0 This example uses the `transformers::run_generation_dataset` step to run the [T0 model](https://api.semanticscholar.org/CorpusID:239009562). It runs the [XSum summarization data](https://github.com/EdinburghNLP/XSum), prompted in 10 different ways, and computes ROUGE scores for all variants. Finally, it computes an overall ROUGE score. This example uses mostly built-in Tango steps. You will need the `datasets` and `transformers` integrations. The only custom step in this example is the `RougeScoreStep`, which computes ROUGE scores from the generated text. 
================================================ FILE: examples/eval_p3/config.jsonnet ================================================ local model = "bigscience/T0_3B"; local batch_size = 8; local datasets = [ 'xsum_DOC_boils_down_to_simple_idea_that', 'xsum_DOC_given_above_write_one_sentence', 'xsum_DOC_how_would_you_rephrase_few_words', 'xsum_DOC_tldr', 'xsum_DOC_write_summary_of_above', 'xsum_article_DOC_summary', 'xsum_college_roommate_asked_DOC_so_I_recap', 'xsum_read_below_DOC_write_abstract', 'xsum_summarize_DOC', 'xsum_summarize_this_DOC_summary' ]; # This creates three steps for each of the datasets: # 1. Load the dataset. # 2. Generate output based on the dataset. # 3. Evaluate the output against the gold answers. local dataset_steps = std.foldl( function(x, dataset_name) x + { ["dataset_" + dataset_name]: { "type": "datasets::load", "path": "bigscience/P3", "name": dataset_name, }, ["generation_" + dataset_name]: { "type": "transformers::run_generation_dataset", "max_length": 200, "input": {"ref": "dataset_" + dataset_name}, "batch_size": batch_size, "model": model, "prompt_field": "inputs_pretokenized", "output_field": "generation", "splits": ["validation"] }, ["eval_" + dataset_name]: { "type": "rouge_score", "input": {"ref": "generation_" + dataset_name}, "input_split": "validation", "target_field": "targets_pretokenized", "prediction_field": "generation" } }, datasets, {} ); # In addition to the three steps per dataset, we also combine all the generations and # evaluate them all together. 
{ "steps": dataset_steps + { "all_generations": { "type": "dataset_combine", "inputs": std.map( function(dataset_name) {"ref": "generation_" + dataset_name}, datasets ) }, "all_evaluations": { "type": "rouge_score", "input": {"ref": "all_generations"}, "input_split": "validation", "target_field": "targets_pretokenized", "prediction_field": "generation" } } } ================================================ FILE: examples/eval_p3/eval.py ================================================ import logging from typing import Dict from torch import Tensor from torchmetrics.text.rouge import ROUGEScore from tango import Format, JsonFormat, Step from tango.common import DatasetDict from tango.common.tqdm import Tqdm logger = logging.getLogger(__name__) @Step.register("rouge_score") class RougeScoreStep(Step[Dict[str, Tensor]]): VERSION = "002" FORMAT: Format = JsonFormat() def run( # type: ignore self, input: DatasetDict, input_split: str, target_field: str, prediction_field: str, use_stemmer: bool = True, ) -> Dict[str, Tensor]: metric = ROUGEScore( use_stemmer=use_stemmer, rouge_keys=("rouge1", "rouge2", "rougeL"), accumulate="avg", ) for instance in Tqdm.tqdm(input[input_split], desc="Calculating scores"): target = instance[target_field] for prediction in instance[prediction_field]: metric.update(prediction, target) return metric.compute() ================================================ FILE: examples/finetune/__init__.py ================================================ ================================================ FILE: examples/finetune/config.jsonnet ================================================ ################## # Model settings # ################## local pretrained_model = "t5-base"; local load_with_low_cpu_mem_usage = false; local modules_to_wrap = ["[a-zA-Z_.]+\\.[0-9]+"]; # TODO: works for t5 and gpt2. confirm with other models too. #################### # Trainer settings # #################### # Trainer settings, adjust to your use-case. 
local training_steps = 20; # total number of optimization steps to train for local validate_every = 5; # how often to validate and save checkpoints local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) # This is the batch size per GPU, ignoring gradient accumulation: local batch_size = 2; # So the effective batch size is `batch_size * grad_accum * devices` local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) local amp = false; # use PyTorch's native automatic mixed precision local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. ###################### # Optimizer settings # ###################### local warmup_steps = 20; local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; # FullyShardedDataParallel config: local fsdp_config = if fsdp then { reshard_after_forward: true, move_params_to_cpu: cpu_offloading, move_grads_to_cpu: cpu_offloading, mixed_precision: amp, } else null; local training_engine = { type: if fsdp then "fairscale" else "torch", optimizer: { type: "torch::AdamW", lr: learning_rate, betas: [0.9, 0.95], eps: 1e-6, }, lr_scheduler: { type: "transformers::linear", num_warmup_steps: warmup_steps, num_training_steps: training_steps, }, amp: amp, [if fsdp then "fsdp_config" else null]: fsdp_config, }; local distributed_dataloader = { batch_size: batch_size, sampler: { type: "torch::DistributedSampler", shuffle: true, drop_last: true, }, }; local single_device_dataloader = { shuffle: true, batch_size: batch_size, }; 
local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; { steps: { raw_data: { type: "datasets::load", path: "snli", }, /*"subset_data": { type: "subset-data", data: { type: "ref", ref: "raw_data" }, max_samples: 10, },*/ processed_data: { type: "snli-text2text", data: { type: "ref", ref: "raw_data" }, }, trained_model: { type: "transformers::finetune", model: { type: "fairscale::with_wrapped_modules", model: { type: "transformers::finetune::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, tokenizer: { pretrained_model_name_or_path: pretrained_model }, dataset_dict: { type: "ref", ref: "processed_data" }, train_dataloader: dataloader, validation_split: "validation", grad_accum: grad_accum, train_steps: training_steps, validate_every: validate_every, checkpoint_every: validate_every, log_every: 1, device_count: devices, training_engine: training_engine, }, generations: { type: "transformers::run_generation_dataset", max_length: 5, input: {"type": "ref", "ref": "processed_data"}, batch_size: batch_size, model: {"type": "ref", "ref": "trained_model"}, prompt_field: "source", output_field: "generation", splits: ["validation"] } } } ================================================ FILE: examples/finetune/snli_steps.py ================================================ from typing import Union import datasets as ds from tango.integrations.datasets import DatasetsFormat from tango.step import Step @Step.register("subset-data") class SubsetData(Step): """ Creates a subset of the data; mostly to be used for testing/debugging. 
""" DETERMINISTIC = True CACHEABLE = True VERSION = "001" FORMAT = DatasetsFormat() def run( # type: ignore self, data: Union[ds.DatasetDict, ds.Dataset], max_samples: int = 5, ) -> Union[ds.DatasetDict, ds.Dataset]: """ Returns a copy of the `data` with number of samples limited to `max_samples` for each split. :param data: The dataset or dataset dict object. :param max_samples: The maximum number of samples to return per split. """ # Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`. def filter_fn(example, indices): return indices < max_samples return data.filter(filter_fn, with_indices=True) @Step.register("snli-text2text") class SnliText2Text(Step): """ Converts the snli dataset to a text-to-text format. Examples -------- original_instance = { "premise": "Two cats are sitting on a wall.", "hypothesis": "The cats are chasing a mouse.", "label": 2 # contradiction } returned_instance = { "source": "nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. label: " "target": "contradiction" } """ DETERMINISTIC = True CACHEABLE = True VERSION = "001" FORMAT = DatasetsFormat() def run( # type: ignore self, data: Union[ds.DatasetDict, ds.Dataset], source_prefix: str = "nli", premise_prefix: str = "premise", hypothesis_prefix: str = "hypothesis", label_prefix: str = "label", num_workers: int = 1, ) -> Union[ds.DatasetDict, ds.Dataset]: """ :param data: The snli `Dataset` or `DatasetDict` object. :param source_prefix: The str to add before the start of the source sequence. :param premise_prefix: The str to add before the start of the `premise` in the source sequence. :param hypothesis_prefix: The str to add before the start of the `hypothesis` in the source sequence. :param label_prefix: The str to add as the prompt for the label. :param num_workers: The number of workers to use for processing the data. 
""" def filter_no_gold(example, indices): if example["label"] == -1: return False return True data = data.filter(filter_no_gold, with_indices=True) label_map = {0: "entailment", 1: "neutral", 2: "contradiction"} def _mapper(example): return { "source": ( f'{source_prefix} {premise_prefix}: {example["premise"]} ' f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: ' ), "target": f'{label_map[example["label"]]}', } if isinstance(data, ds.Dataset): old_cols = data.column_names else: old_cols = list(data.column_names.values())[0] dataset = data.map( _mapper, batched=False, num_proc=num_workers, remove_columns=old_cols, # remove all old columns desc="Converting data to text-to-text format", ) return dataset ================================================ FILE: examples/finetune/test.py ================================================ import typing import datasets as ds import pytest from tango.common import Params from tango.common.testing import TangoTestCase, run_experiment class TestFinetuneSNLI(TangoTestCase): @pytest.mark.parametrize( "model, model_type", [("patrickvonplaten/t5-tiny-random", "t5"), ("sshleifer/tiny-gpt2", "gpt2")], ) @typing.no_type_check # mypy has become incompatible with the datasets library def test_config(self, model: str, model_type: str): overrides = { "steps.trained_model.model.model.pretrained_model_name_or_path": model, "steps.trained_model.tokenizer.pretrained_model_name_or_path": model, "steps.subset_data": { "type": "subset-data", "data": {"type": "ref", "ref": "raw_data"}, "max_samples": 10, }, "steps.processed_data.data.ref": "subset_data", } config = Params.from_file("config.jsonnet", params_overrides=overrides) # Make sure we've overrode the model entirely. 
flattened = config.as_flat_dict() for key, value in flattened.items(): if "model_name" in key or (isinstance(value, str) and model_type in value): assert value == model with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: assert (run_dir / "processed_data").is_dir() processed = ds.load_from_disk(run_dir / "processed_data" / "data") assert len(processed["train"][0].keys()) == 2 assert "source" in processed["train"][0].keys() assert "target" in processed["train"][0].keys() assert processed["train"][0]["source"].startswith("nli premise:") assert (run_dir / "trained_model").is_dir() ================================================ FILE: examples/finetune_resnet/.gitignore ================================================ data/ results/ extra_testing.py ================================================ FILE: examples/finetune_resnet/config.jsonnet ================================================ local input_size = 224; local batch_size = 32; local num_classes = 2; local val_size = 0.05; local model = "resnet"; local feature_extract = true; local distributed = false; local devices = if distributed then 2 else 1; local pretrained_model = "resnet_ft"; local training_steps = 500; local validate_every = 50; local image_url = "https://tinyurl.com/2p9xjvn9"; local distributed_dataloader = { batch_size: batch_size, sampler: { type: "torch::DistributedSampler", shuffle: true, drop_last: true, }, collate_fn: {"type": "image_collator"}, }; local single_device_dataloader = { shuffle: true, batch_size: batch_size, collate_fn: {"type": "image_collator"}, }; { steps: { raw_data: { type: "datasets::load", path: "nateraw/auto-cats-and-dogs", name: "cats_and_dogs", }, transform_data: { type: "transform_data", dataset: { type: 'ref', ref: 'raw_data' }, input_size: input_size, val_size: val_size, }, trained_model: { type: "torch::train", model: { type: pretrained_model, num_classes: num_classes, feature_extract: true, use_pretrained: true, }, training_engine: { 
optimizer: { type: "torch_adam", lr: 0.001, }, }, dataset_dict: {"type": "ref", "ref": "transform_data"}, train_dataloader: single_device_dataloader, validation_split: "val", val_metric_name: "accuracy", train_steps: training_steps, validate_every: validate_every, checkpoint_every: validate_every, log_every: 1, device_count: devices, minimize_val_metric: false, }, prediction: { type: "prediction", image_url: image_url, input_size: input_size, model: {"type": "ref", "ref": "trained_model"}, }, }, } ================================================ FILE: examples/finetune_resnet/resnet_steps.py ================================================ from typing import Any, Dict, List, Optional import datasets import torch from cached_path import cached_path from PIL import Image from torch import nn from torch.optim import Adam from torchvision import models, transforms from tango import Format, JsonFormat, Step from tango.integrations.torch import DataCollator, Model, Optimizer # Register the Adam optimizer as an `Optimizer` so we can use it in the train step. Optimizer.register("torch_adam")(Adam) # Wrapper class around the pre-trained ResNet-18 model that modifies the final layer. 
@Model.register("resnet_ft")
class ResNetWrapper(Model):
    """
    A torchvision ResNet-18 whose final fully-connected layer is replaced by a fresh
    linear layer with ``num_classes`` outputs.

    :param num_classes: Number of outputs of the new final layer.
    :param feature_extract: If ``True``, every parameter of the loaded ResNet is frozen
        before the new final layer is attached, so only the new layer is trainable.
    :param use_pretrained: Forwarded to ``models.resnet18`` as ``pretrained``.
    """

    def __init__(self, num_classes: int, feature_extract: bool, use_pretrained: bool):
        super().__init__()
        self.model_ft = models.resnet18(pretrained=use_pretrained)
        # Freeze first, then swap the head: the new ``nn.Linear`` is created afterwards and
        # therefore keeps the default ``requires_grad=True``.
        self.set_parameter_requires_grad(self.model_ft, feature_extract)
        num_features = self.model_ft.fc.in_features
        self.model_ft.fc = nn.Linear(num_features, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def set_parameter_requires_grad(self, model: nn.Module, feature_extracting: bool):
        """
        Disable gradients for all parameters of ``model`` when ``feature_extracting`` is
        ``True``; otherwise leave the parameters untouched.
        """
        if feature_extracting:
            for param in model.parameters():
                param.requires_grad = False

    def forward(  # type: ignore
        self, image: torch.Tensor, label: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run the network on ``image``.

        :returns: ``{"preds": ...}`` (arg-max over dimension 1 of the output) when ``label``
            is ``None``; otherwise ``{"loss": ..., "accuracy": ...}`` computed against ``label``.
        """
        output = self.model_ft(image)
        preds = torch.argmax(output, dim=1)
        if label is None:
            return {"preds": preds}
        loss = self.loss_fn(output, label)
        accuracy = (preds == label).float().mean()
        return {"loss": loss, "accuracy": accuracy}


# Custom data collator for images, that takes in a batch of images and labels and
# reformats the data so that it is suitable for the model.
@DataCollator.register("image_collator")
class ImageCollator(DataCollator[Dict[str, Any]]):
    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Collate a list of examples into a single batch.

        Reads each example's ``"image"`` tensor and ``"labels"`` value; returns the images
        stacked along a new leading dimension under ``"image"`` and the labels as one
        tensor under ``"label"`` (the argument names of ``ResNetWrapper.forward``).
        """
        return {
            # ``torch.stack`` adds the leading batch dimension directly.
            "image": torch.stack([item["image"] for item in batch], dim=0),
            "label": torch.tensor([item["labels"] for item in batch]),
        }


# Function that returns an image transformations dict with the appropriate image size.
def get_data_transforms(input_size: int):
    """
    Build the image transformation pipelines for a given ``input_size``.

    :returns: A dict with a ``"train"`` pipeline (random resized crop + random horizontal
        flip) and a ``"val"`` pipeline (resize + center crop); both end with ``ToTensor``
        and the same ``Normalize``.
    """
    # Per-channel mean / std shared by both pipelines
    # (presumably the usual ImageNet statistics -- TODO confirm).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    data_transforms = {
        "train": transforms.Compose(
            [
                transforms.RandomResizedCrop(input_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        ),
        "val": transforms.Compose(
            [
                transforms.Resize(input_size),
                transforms.CenterCrop(input_size),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        ),
    }
    return data_transforms


def _load_rgb_image(path: str):
    """Open the image file at ``path`` and return it converted to RGB."""
    with open(path, "rb") as f:
        # ``convert`` reads the pixel data, so it must happen while the file is still open.
        return Image.open(f).convert("RGB")


# loads an image and applies the appropriate transformation
def pil_loader(path: str, input_size: int, transform_type: str):
    """
    Load one image and transform it.

    :param path: Path of the image file.
    :param input_size: Size passed to :func:`get_data_transforms`.
    :param transform_type: Which pipeline to use, ``"train"`` or ``"val"``.
    :raises KeyError: If ``transform_type`` is not a key of the transforms dict.
    """
    transform = get_data_transforms(input_size=input_size)[transform_type]
    return transform(_load_rgb_image(path))


# calls the image loader on every image in a given batch
def image_loader(example_batch, input_size: int, transform_type: str):
    """
    Populate ``example_batch["image"]`` with the transformed image for every path in
    ``example_batch["file"]`` and return the (mutated) batch.
    """
    # Build the pipeline once per batch instead of once per image.
    transform = get_data_transforms(input_size=input_size)[transform_type]
    example_batch["image"] = [transform(_load_rgb_image(f)) for f in example_batch["file"]]
    return example_batch


# This step takes in raw image data, attaches the image transform, and splits off a validation set.
@Step.register("transform_data") class TransformData(Step): DETERMINISTIC = True CACHEABLE = False def run( # type: ignore self, dataset: datasets.DatasetDict, val_size: float, input_size: int ) -> datasets.DatasetDict: def image_loader_wrapper(example_batch): return image_loader(example_batch, input_size=input_size, transform_type="train") dataset = dataset.with_transform(image_loader_wrapper) train_val = dataset["train"].train_test_split(test_size=val_size) train_val["val"] = train_val.pop("test") return train_val # function to map integer labels to string labels def convert_to_label(int_label: int) -> str: if int_label == 0: return "cat" else: return "dog" @Step.register("prediction") class Prediction(Step): FORMAT: Format = JsonFormat() def run( # type: ignore self, image_url: str, input_size: int, model: models, device: Optional[str] = "cpu" ) -> Dict[str, Any]: # download and store image image_path = cached_path(image_url) transformed_image = pil_loader(image_path, input_size, transform_type="val") # pass image through transform transformed_image = transformed_image.unsqueeze(0).to(device) # pass image through model and get the prediction prediction = model(image=transformed_image, label=None)["preds"][0].float() label = convert_to_label(prediction) return {"image_url": image_url, "local_path": image_path, "label": label} ================================================ FILE: examples/flax/config.jsonnet ================================================ { "steps": { "data": { "type": "datasets::load", "path": "xsum", }, "tokenize": { "type": "tokenize_data", "dataset": { "type": "ref", "ref": "data" } }, "train": { "type": "flax::train", "model": { "type": "transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained", "pretrained_model_name_or_path": "facebook/bart-base" }, "dataset": { "type": "ref", "ref": "tokenize" }, "optimizer": { "type" : "optax::adamw", "learning_rate" : 2e-5 }, "train_dataloader": { "batch_size": 16, "drop_last": true }, "wrapper": { 
"type": "xsum_wrapper" }, "train_split": "train", "validation_split" : "validation", "validate_every" : 1000, "validation_dataloader": { "batch_size": 16, "drop_last": true }, "train_epoch": 5, "checkpoint_every": 1000, "log_every": 1000, "callbacks" : [ //{"type" : "wandb::log_flax"}, {"type": "flax::generate_step"} ] }, "eval": { "type": "flax::eval", "state": { "type": "ref", "ref": "train" }, "dataset": { "type": "ref", "ref": "tokenize" }, "dataloader": { "batch_size": 16, "drop_last": true }, "wrapper": { "type" : "xsum_wrapper" } } } } ================================================ FILE: examples/flax/run.sh ================================================ #!/bin/bash tango run config.jsonnet -d workspace --include-package xsum ================================================ FILE: examples/flax/xsum.py ================================================ import logging from typing import List, Optional import jax import jax.numpy as jnp import nltk import numpy as np import optax from datasets import load_metric from flax.training.common_utils import onehot from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSeq2SeqLM from tango.integrations.flax import FlaxWrapper from tango.integrations.flax.train_callback import TrainCallback from tango.step import Step """ XSum Summarization with facebook/bart-base """ @Step.register("tokenize_data") class PreProcessing(Step): DETERMINISTIC = False def run(self, dataset): tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") model = FlaxAutoModelForSeq2SeqLM.from_pretrained("facebook/bart-base") model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"]) shift_tokens_right_fn = getattr(model_module, "shift_tokens_right") config = AutoConfig.from_pretrained("facebook/bart-base") MAX_SOURCE_LENGTH = 512 MAX_TGT_LENGTH = 64 def preprocess_function(examples): inputs = examples["document"] targets = examples["summary"] inputs = [inp for inp in inputs] model_inputs = tokenizer( inputs, 
max_length=MAX_SOURCE_LENGTH, padding="max_length", truncation=True, return_tensors="np", ) # Setup the tokenizer for targets with tokenizer.as_target_tokenizer(): labels = tokenizer( targets, max_length=MAX_TGT_LENGTH, padding="max_length", truncation=True, return_tensors="np", ) model_inputs["labels"] = labels["input_ids"] decoder_input_ids = shift_tokens_right_fn( labels["input_ids"], config.pad_token_id, config.decoder_start_token_id ) model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids) # We need decoder_attention_mask so we can ignore pad tokens from loss model_inputs["decoder_attention_mask"] = labels["attention_mask"] return model_inputs column_names = dataset["train"].column_names dataset = dataset.map( preprocess_function, batched=True, remove_columns=column_names, desc="Running tokenizer on dataset", ) return dataset @FlaxWrapper.register("xsum_wrapper") # type: ignore class TransformerWrapper(FlaxWrapper): def loss_helper(self, logits, labels, batch): label_smoothing_factor = 0 padding_mask = batch["decoder_attention_mask"] vocab_size = logits.shape[-1] confidence = 1.0 - label_smoothing_factor low_confidence = (1.0 - confidence) / (vocab_size - 1) normalizing_constant = -( confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20) ) soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence) loss = optax.softmax_cross_entropy(logits, soft_labels) loss = loss - normalizing_constant # ignore padded tokens from loss loss = loss * padding_mask loss = loss.sum() / padding_mask.sum() return loss def train_loss(self, params, state, batch, dropout_rng, labels): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = self.loss_helper(logits, labels, batch) return loss def val_metrics(self, batch, logits, labels): loss = self.loss_helper(logits, labels, batch) metrics = {"loss": loss} return metrics def eval_metrics(self, batch, logits, 
labels): loss = self.loss_helper(logits, labels, batch) metrics = {"loss": loss} return metrics @TrainCallback.register("flax::generate_step") class GenerateCallback(TrainCallback): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.logger = logging.getLogger(GenerateCallback.__name__) def generate_step(self, params, batch): self.model.params = params gen_kwargs = {"max_length": 64, "num_beams": self.model.config.num_beams} output_ids = self.model.generate( batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs ) return output_ids.sequences def pre_train_loop(self) -> None: if len(jax.devices()) > 1: self.p_generate_step = jax.pmap(self.generate_step, axis_name="batch") def pre_val_loop(self, step: int, val_step: int, state) -> None: self.state = state self.eval_preds: List = [] self.eval_labels: List = [] def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None: labels = val_batch["labels"] if len(jax.devices()) > 1: generated_ids = self.p_generate_step(self.state.params, val_batch) else: generated_ids = self.generate_step(self.state.params, val_batch) self.eval_preds.extend(jax.device_get(generated_ids.reshape(-1, 64))) self.eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1]))) def postprocess_text(self, preds, labels): preds = [pred.strip() for pred in preds] labels = [label.strip() for label in labels] # rougeLSum expects newline after each sentence preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] return preds, labels def compute_metrics(self, preds, labels): tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) # Some simple post-processing decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels) metric = 
load_metric("rouge") result = metric.compute( predictions=decoded_preds, references=decoded_labels, use_stemmer=True ) # Extract a few results from ROUGE result = {key: value.mid.fmeasure * 100 for key, value in result.items()} prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] result["gen_len"] = np.mean(prediction_lens) result = {k: round(v, 4) for k, v in result.items()} return result def post_val_loop( self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float] ) -> None: rouge_metrics = self.compute_metrics(self.eval_preds, self.eval_labels) rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()]) self.logger.info(rouge_desc) ================================================ FILE: examples/train_lm/.gitignore ================================================ runs run ================================================ FILE: examples/train_lm/README.md ================================================ # Fine-tuning a language model This Tango example showcases how you could train or fine-tune a causal language model like [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) or [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset. It's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow. 
This example also depends on [FairScale](https://fairscale.readthedocs.io/en/latest/), which allows you to leverage [`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) (FSDP) and [activation checkpointing](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`. Without using CPU offloading you'll need at least 4 x 40GiB A100 GPUs, or a different configuration with a comparable amount of total GPU memory. To get started, just run ``` tango run config.jsonnet -i tokenize_step.py ``` ================================================ FILE: examples/train_lm/config.jsonnet ================================================ ################## # Model settings # ################## local pretrained_model = "gpt2"; # With 'fsdp' and 'activation_checkpointing' (see constants below), you should be able to train # a 6B model on 4x ~40GB GPUs: # local pretrained_model = "EleutherAI/gpt-j-6B"; # This doesn't seem to work with gpt2, but works fine with gpt-j. local load_with_low_cpu_mem_usage = std.startsWith(pretrained_model, "EleutherAI/gpt-j"); #################### # Trainer settings # #################### # Trainer settings, adjust to your use-case.
local training_steps = 200; # total number of optimization steps to train for local validate_every = 20; # how often to validate and save checkpoints local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) # This is the batch size per GPU, ignoring gradient accumulation: local batch_size = 8; # So the effective batch size is `batch_size * grad_accum * devices` local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) local amp = false; # use PyTorch's native automatic mixed precision local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. ###################### # Optimizer settings # ###################### local warmup_steps = 20; local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" # <----- you probably don't need to edit below this line ----> # assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; # FullyShardedDataParallel config: local fsdp_config = if fsdp then { reshard_after_forward: true, move_params_to_cpu: cpu_offloading, move_grads_to_cpu: cpu_offloading, mixed_precision: amp, } else null; local training_engine = { type: if fsdp then "fairscale" else "torch", optimizer: { type: "torch::AdamW", lr: learning_rate, betas: [0.9, 0.95], eps: 1e-6, }, lr_scheduler: { type: "transformers::linear", num_warmup_steps: warmup_steps, num_training_steps: training_steps, }, amp: amp, [if fsdp then "fsdp_config" else null]: fsdp_config, }; local distributed_dataloader = { batch_size: batch_size, collate_fn: { type: "transformers::DefaultDataCollator" }, sampler: { type: 
"torch::DistributedSampler", shuffle: true, drop_last: true, }, }; local single_device_dataloader = { shuffle: true, batch_size: batch_size, collate_fn: { type: "transformers::DefaultDataCollator" }, }; local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; { steps: { raw_data: { type: "datasets::load", path: "wikitext", name: "wikitext-2-raw-v1", }, tokenized_data: { type: "tokenize_data", dataset: { type: "ref", ref: "raw_data" }, tokenizer: { pretrained_model_name_or_path: pretrained_model } }, trained_model: { type: "torch::train", model: { type: "fairscale::with_wrapped_modules", model: { type: "transformers::AutoModelForCausalLM::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, dataset_dict: { type: "ref", ref: "tokenized_data" }, train_dataloader: dataloader, validation_split: "validation", grad_accum: grad_accum, train_steps: training_steps, validate_every: validate_every, checkpoint_every: validate_every, log_every: 1, device_count: devices, training_engine: training_engine, }, final_metrics: { type: "torch::eval", model: { type: "ref", ref: "trained_model" }, dataset_dict: { type: "ref", ref: "tokenized_data" }, dataloader: single_device_dataloader, test_split: "test", }, } } ================================================ FILE: examples/train_lm/test.py ================================================ from tango.common import Params from tango.common.testing import run_experiment def test_small_experiment(): model = "sshleifer/tiny-gpt2" dataloader = { "batch_size": 2, "collate_fn": {"type": "transformers::DefaultDataCollator"}, } steps = 4 overrides = { "steps.tokenized_data.block_size": 64, # Override the model in the config with the tiny alternative so 
training is fast. "steps.tokenized_data.tokenizer.pretrained_model_name_or_path": model, "steps.trained_model.model.model.pretrained_model_name_or_path": model, # Use a small number of training/validation/eval steps. "steps.trained_model.training_engine.lr_scheduler.num_warmup_steps": 1, "steps.trained_model.training_engine.lr_scheduler.num_training_steps": steps, "steps.trained_model.train_steps": steps, "steps.trained_model.validation_steps": 2, "steps.trained_model.validate_every": steps, "steps.final_metrics.eval_steps": 2, "steps.trained_model.checkpoint_every": steps, "steps.trained_model.device_count": 1, # Override data loaders. "steps.trained_model.train_dataloader": dataloader, "steps.trained_model.validation_dataloader": dataloader, "steps.final_metrics.dataloader": dataloader, } # Load the config. config = Params.from_file("config.jsonnet", params_overrides=overrides) # Make sure we've overrode the model entirely. flattened = config.as_flat_dict() for key, value in flattened.items(): if "model_name" in key or (isinstance(value, str) and "gpt" in value): assert value == model with run_experiment(config, include_package=["tokenize_step.py"]) as run_dir: assert (run_dir / "trained_model").is_dir() ================================================ FILE: examples/train_lm/tokenize_step.py ================================================ import datasets from tango import Step from tango.integrations.datasets import DatasetsFormat from tango.integrations.transformers import Tokenizer # We need a step to tokenize the raw data. The result of this step will be passed # directly into the "torch::train" step. 
@Step.register("tokenize_data") class TokenizeData(Step): DETERMINISTIC = True CACHEABLE = True FORMAT = DatasetsFormat() def run( # type: ignore[override] self, dataset: datasets.DatasetDict, tokenizer: Tokenizer, block_size: int = 1024, num_workers: int = 1, field_to_tokenize: str = "text", ) -> datasets.DatasetDict: def tokenize_function(example): return tokenizer(example[field_to_tokenize]) dataset = dataset.map( tokenize_function, batched=True, num_proc=num_workers, remove_columns=list(dataset.column_names.values())[0], # remove all old columns desc="Tokenizing dataset", ) def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} # type: ignore total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported # it instead of this drop, you can customize this part to your needs. if total_length >= block_size: total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result dataset = dataset.map( group_texts, batched=True, num_proc=num_workers, desc=f"Grouping texts into chunks of {block_size}", ) return dataset ================================================ FILE: integration_tests/README.md ================================================ # Integration tests These are a collection of longer running end-to-end tests of various parts of the Tango library. The easiest way to run any of these integration tests is by triggering the [**Integration tests**](https://github.com/allenai/tango/actions/workflows/integration_tests.yml) workflow on GitHub Actions. Just select the "Run workflow" dropdown, then pick the test to run and the Beaker cluster to run it on, and finally hit the "Run workflow" button. 
Each test should have a `run.sh` file in its folder that will run the relevant tango command. This is what the **Integration tests** workflow will call, and you can also use it to run the test manually. ================================================ FILE: integration_tests/fairscale_benchmarks/README.md ================================================ # FairScale Benchmarks This integration test is for checking the performance of the `FairScaleTrainingEngine` with various configurations. **When to run it:** It should be run every time there is a major PyTorch or FairScale upgrade. **Where to run it:** A server with 4 A100 GPUs. Make sure you set your `WANDB_API_KEY` environment variable. **How to run it:** From the root directory of this repository, run: ``` integration_tests/fairscale_benchmarks/run.sh ``` By default, not all configurations are run. If you want to change which configurations are run, open `config.jsonnet` and search for "enabled". Then toggle this `enabled` field to `true` or `false` for each configuration. **What to look for:** The training jobs shouldn't fail, for one. After `tango run` completes, check the corresponding Weights & Biases dashboard and inspect the results. Compare the various "fsdp" training runs with the baseline to ensure you see memory savings. ================================================ FILE: integration_tests/fairscale_benchmarks/config.jsonnet ================================================ ################## # Model settings # ################## local pretrained_model = "gpt2"; # local pretrained_model = "EleutherAI/gpt-j-6B"; # This doesn't seem to work with gpt2, but works fine with gpt-j-6B. local load_with_low_cpu_mem_usage = pretrained_model == "EleutherAI/gpt-j-6B"; #################### # Trainer settings # #################### # Trainer settings, adjust to your use-case.
local training_steps = 100; # total number of optimization steps to train for local validate_every = 20; # how often to validate and save checkpoints local devices = 4; local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) # This is the batch size per GPU, ignoring gradient accumulation: local batch_size = 8; # So the effective batch size is `batch_size * grad_accum * devices` ###################### # Optimizer settings # ###################### local warmup_steps = 20; local learning_rate = if pretrained_model == "EleutherAI/gpt-j-6B" then 0.00001 else 0.0001; # <----- you probably don't need to edit below this line ----> # local distributed_dataloader = { batch_size: batch_size, collate_fn: { type: "transformers::DefaultDataCollator" }, sampler: { type: "torch::DistributedSampler", shuffle: true, drop_last: true, }, }; local single_device_dataloader = { shuffle: true, batch_size: batch_size, collate_fn: { type: "transformers::DefaultDataCollator" }, }; local TrainStep(options) = local training_engine = { type: if options.fsdp_config != null then "fairscale" else "torch", optimizer: { type: "torch::AdamW", lr: learning_rate, betas: [0.9, 0.95], eps: 1e-6, }, lr_scheduler: { type: "transformers::linear", num_warmup_steps: warmup_steps, num_training_steps: training_steps, }, amp: options.amp, [if options.fsdp_config != null then "fsdp_config" else null]: options.fsdp_config, }; { type: "torch::train", model: { type: "fairscale::with_wrapped_modules", model: { type: "transformers::AutoModelForCausalLM::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, modules_to_wrap: ["transformer\\.h\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually fsdp_config: options.fsdp_config, activation_checkpointing: options.activation_checkpointing, }, dataset_dict: { type: "ref", ref: "tokenized_data" }, train_dataloader: distributed_dataloader, 
validation_split: "validation", grad_accum: grad_accum, train_steps: training_steps, validate_every: validate_every, checkpoint_every: validate_every, log_every: 1, device_count: devices, training_engine: training_engine, callbacks: [ { type: "wandb::log", entity: "allennlp", project: "tango-fairscale-benchmarks", wandb_config: options + { effective_batch_size: batch_size * devices * grad_accum, model: pretrained_model, }, }, ], }; { steps: { raw_data: { type: "datasets::load", path: "wikitext", name: "wikitext-2-raw-v1", }, tokenized_data: { type: "tokenize_data", dataset: { type: "ref", ref: "raw_data" }, tokenizer: { pretrained_model_name_or_path: pretrained_model } }, } + { ["trained_model_" + options.name]: TrainStep(options) for options in [ # NOTE: With 6B model, baseline and many others will fail with CUDA OOM. # FSDP and activation checkpointing will be required for a 6B model. { name: "baseline", enabled: false, amp: false, fsdp_config: null, activation_checkpointing: false, }, { name: "amp", enabled: false, amp: true, fsdp_config: null, activation_checkpointing: false, }, { name: "checkpointing", enabled: false, amp: false, fsdp_config: null, activation_checkpointing: true, }, { name: "amp_and_checkpointing", enabled: false, amp: true, fsdp_config: null, activation_checkpointing: true, }, { name: "fsdp", enabled: false, amp: false, activation_checkpointing: false, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "fsdp_no_reshard", enabled: false, amp: false, activation_checkpointing: false, fsdp_config: { reshard_after_forward: false, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "amp_and_fsdp", enabled: false, amp: true, activation_checkpointing: false, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "amp_and_fsdp_no_reshard", enabled: 
false, amp: true, activation_checkpointing: false, fsdp_config: { reshard_after_forward: false, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "amp_and_fsdp_mp", enabled: false, amp: true, activation_checkpointing: false, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: true, }, }, { name: "amp_and_fsdp_mp_no_reshard", enabled: false, amp: true, activation_checkpointing: false, fsdp_config: { reshard_after_forward: false, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: true, }, }, { name: "checkpointing_and_fsdp", enabled: false, amp: false, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "amp_and_checkpointing_and_fsdp", enabled: false, amp: true, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: false, }, }, { name: "amp_and_checkpointing_and_fsdp_mp", enabled: true, amp: true, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: true, }, }, { name: "checkpointing_and_fsdp_mp", enabled: false, amp: false, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: false, move_grads_to_cpu: false, mixed_precision: true, }, }, { # This configuration currently does not work. 
Tracking https://github.com/facebookresearch/fairscale/issues/918 name: "amp_and_checkpointing_and_fsdp_mp_with_partial_offloading", enabled: false, amp: true, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: true, move_grads_to_cpu: false, mixed_precision: true, }, }, { name: "amp_and_checkpointing_and_fsdp_mp_with_full_offloading", enabled: false, amp: true, activation_checkpointing: true, fsdp_config: { reshard_after_forward: true, move_params_to_cpu: true, move_grads_to_cpu: true, mixed_precision: true, }, }, ] if options.enabled } } ================================================ FILE: integration_tests/fairscale_benchmarks/run.sh ================================================ #!/bin/sh tango run integration_tests/fairscale_benchmarks/config.jsonnet -i examples/train_lm/tokenize_step.py ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [project] name = "ai2-tango" dynamic = ["version"] readme = "README.md" description = "A library for choreographing your machine learning research." 
classifiers=[ "Intended Audience :: Science/Research", "Development Status :: 3 - Alpha", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", "Topic :: Scientific/Engineering :: Artificial Intelligence", ] authors = [ {name = "Allen Institute for Artificial Intelligence", email = "contact@allenai.org"} ] license = {file = "LICENSE"} requires-python = ">=3.8.1" dependencies = [ "cached-path>=1.0,<2.0", "rjsonnet>=0.5.0", "GitPython>=3.0,<4.0", "PyYAML>=5.4.1,<7.0", "dill", "base58", "xxhash", "filelock>=3.4,<4.0", "click>=8.0,<8.1.4", "click-help-colors>=0.9.1,<0.10", "rich>=12.3,<14.0", "tqdm>=4.62,<5.0", "more-itertools>=8.0,<11.0", "sqlitedict", "glob2>=0.7", "petname>=2.6,<3.0", "pytz" ] [project.optional-dependencies] dev = [ "ruff", "mypy==1.2.0", "types-PyYAML", "types-setuptools", "types-pytz", "types-retry", "black==23.3.0", "isort==5.12.0", "pytest", "pytest-sphinx", "flaky", "twine>=1.11.0", "setuptools", "wheel", "build", "Sphinx==5.3.0", "furo==2023.3.27", "myst-parser==1.0.0", "sphinx-copybutton==0.5.2", "sphinx-autobuild==2021.3.14", "sphinx-autodoc-typehints<=1.23.0", "packaging" ] examples = [ "torchmetrics>=0.7.0" ] torch = [ "torch>=1.9,<2.1", "numpy", ] transformers = [ "torch>=1.9,<2.1", "numpy", "datasets>=1.12,<3.0", "transformers>=4.12.3", "sentencepiece==0.1.98", "sacremoses" ] datasets = [ "datasets>=1.12,<3.0" ] fairscale = [ "torch>=1.9,<2.1", "numpy", "fairscale>=0.4.6,<0.5" ] flax = [ "datasets>=1.12,<3.0", "jax", "jaxlib", "flax", "optax", "tensorflow-cpu>=2.9.1" ] wandb = [ "wandb>=0.16", "retry" ] beaker = [ "beaker-py>=1.14.0,<2.0" ] gs = [ "google-cloud-storage>=2.6.0", "google-cloud-datastore>=2.12.0" ] all = [ "ai2-tango[examples,torch,transformers,datasets,fairscale,flax,wandb,beaker,gs]" ] [project.scripts] tango = "tango.__main__:main" [project.urls] homepage = "https://github.com/allenai/tango" repository = "https://github.com/allenai/tango" [tool.setuptools.packages.find] exclude = [ 
"*.tests", "*.tests.*", "tests.*", "tests", "test_fixtures", "test_fixtures.*", "docs*", "scripts*", "examples*" ] [tool.setuptools.package-data] tango = ["py.typed"] "tango.integrations.beaker" = ["*.sh"] [tool.setuptools.dynamic] version = {attr = "tango.version.VERSION"} [tool.black] line-length = 100 include = '\.pyi?$' exclude = ''' ( __pycache__ | \.git | \.mypy_cache | \.pytest_cache | \.vscode | \.venv | \bdist\b | \bdoc\b ) ''' [tool.isort] profile = "black" multi_line_output = 3 [tool.ruff] line-length = 115 select = ["E"] exclude = [ ".venv", ".git", "__pycache__", ".mypy_cache", "docs/build", "dist" ] [tool.ruff.per-file-ignores] "__init__.py" = ["F401"] "*/**/**/__init__.py" = ["F401","E501"] [tool.mypy] ignore_missing_imports = true no_site_packages = false allow_redefinition = true check_untyped_defs = true [[tool.mypy.overrides]] module = "tests.*" strict_optional = false disable_error_code = [ "var-annotated", "no-redef", "dict-item" ] allow_redefinition = true [tool.pytest.ini_options] testpaths = "tests/" python_classes = [ "Test*", "*Test" ] log_format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" log_level = "DEBUG" markers = [ "gpu: marks tests that need GPUs" ] filterwarnings = [ 'ignore:.*Consider increasing the value of the `num_workers` argument.*:UserWarning:pytorch_lightning\.trainer\.data_loading', 'ignore:.*you defined a validation_step but have no val_dataloader.*:UserWarning:pytorch_lightning\.trainer\.configuration_validator', 'ignore::UserWarning:tango\.*', 'ignore::DeprecationWarning:pkg_resources', 'ignore::DeprecationWarning:google\.rpc' ] doctest_optionflags = "NORMALIZE_WHITESPACE" ================================================ FILE: scripts/entrypoint.sh ================================================ #!/bin/bash # Exit script if any commands fail. 
def main():
    """
    Insert a heading for the current ``VERSION`` into ``CHANGELOG.md``.

    The heading ``## [vX.Y.Z](<release url>) - YYYY-MM-DD`` is inserted (preceded by a
    blank line) directly below the ``## Unreleased`` heading. If a heading for
    ``VERSION`` already exists, the file is left untouched.

    :raises RuntimeError: If no ``## Unreleased`` heading is found before the first
        release heading (or anywhere in the file when there are no releases yet).
    """
    changelog = Path("CHANGELOG.md")

    with changelog.open() as f:
        lines = f.readlines()

    # -1 means "no 'Unreleased' heading seen yet". Tracking this explicitly (instead of
    # relying on ``for ... else``) makes the very first release work, where there is no
    # earlier "## [v" heading to break on, and turns a missing section into a clear error
    # rather than an unbound-variable crash.
    insert_index = -1
    for i, line in enumerate(lines):
        if line.startswith("## Unreleased"):
            insert_index = i + 1
        elif line.startswith(f"## [v{VERSION}]"):
            print("CHANGELOG already up-to-date")
            return
        elif line.startswith("## [v"):
            # Reached an older release; nothing further down is relevant.
            break

    if insert_index < 0:
        raise RuntimeError("Couldn't find 'Unreleased' section")

    lines.insert(insert_index, "\n")
    lines.insert(
        insert_index + 1,
        f"## [v{VERSION}](https://github.com/allenai/tango/releases/tag/v{VERSION}) - "
        f"{datetime.now().strftime('%Y-%m-%d')}\n",
    )

    with changelog.open("w") as f:
        f.writelines(lines)
def main():
    """
    Refresh the ``version:`` and ``date-released:`` entries of ``CITATION.cff``.

    Every line starting with ``version:`` is replaced by the current ``VERSION`` and
    every line starting with ``date-released:`` by today's date (``YYYY-MM-DD``).
    All other lines are written back unchanged.
    """
    cff_path = Path("CITATION.cff")

    with cff_path.open() as src:
        original = src.readlines()

    updated = []
    for entry in original:
        if entry.startswith("version:"):
            entry = f'version: "{VERSION}"\n'
        elif entry.startswith("date-released:"):
            today = datetime.now().strftime("%Y-%m-%d")
            entry = f'date-released: "{today}"\n'
        updated.append(entry)

    with cff_path.open("w") as dst:
        dst.writelines(updated)
def get_commit_history() -> str:
    """
    Build the "Commits" section of the release notes.

    Lists the first-parent commits between the latest earlier version tag and ``TAG``
    (up to ``TAG^``, i.e. excluding the tagged commit itself). Pre-release tags are
    ignored as a baseline unless ``TAG`` is itself a pre-release. If no earlier tag is
    found, the whole first-parent history is listed.

    :returns: Markdown text starting with a ``## Commits`` heading.
    """
    # Imported locally so this function does not depend on a new module-level import.
    import subprocess

    def git_stdout(*args: str) -> str:
        # Best-effort: a failing git command contributes whatever it printed (possibly
        # nothing) instead of raising. stderr is left attached to the terminal.
        return subprocess.run(["git", *args], stdout=subprocess.PIPE, text=True).stdout

    new_version = packaging.version.parse(TAG)

    # Wait for the fetch to finish before listing tags. A fire-and-forget
    # `os.popen("git fetch --tags")` returns immediately, so the tag listing below could
    # run before the remote tags have arrived.
    subprocess.run(["git", "fetch", "--tags"])

    # Get all tags sorted by version, latest first.
    all_tags = git_stdout("tag", "-l", "--sort=-version:refname", "v*").split("\n")

    # Out of `all_tags`, find the latest previous version so that we can collect all
    # commits between that version and the new version we're about to publish.
    # Note that we ignore pre-releases unless the new version is also a pre-release.
    last_tag: Optional[str] = None
    for tag in all_tags:
        if not tag.strip():  # could be blank line
            continue
        version = packaging.version.parse(tag)
        if new_version.pre is None and version.pre is not None:
            continue
        if version < new_version:
            last_tag = tag
            break

    if last_tag is not None:
        commits = git_stdout("log", f"{last_tag}..{TAG}^", "--oneline", "--first-parent")
    else:
        commits = git_stdout("log", "--oneline", "--first-parent")
    return "## Commits\n\n" + commits
testcode:: :hide: import subprocess output = subprocess.run("tango --help".split(" "), capture_output=True) output.check_returncode() print(output.stdout.decode().replace("\\n\\n", "\\n").strip()) .. testoutput:: Usage: tango [OPTIONS] COMMAND [ARGS]... Options: --version Show the version and exit. --settings FILE Path to a global tango.yml settings file. --log-level [debug|info|warning|error] Set the global log level. --file-friendly-logging Outputs progress bar status on separate lines and slows refresh rate. --start-method [fork|spawn|forkserver] Set the multiprocessing start method. --help Show this message and exit. Commands: info Get info about the current tango installation. run Run a tango experiment. settings Commands for initializing and updating global settings. To see all of the available arguments and options for a particular command, run .. code-block:: $ tango [COMMAND] --help For example, .. code-block:: $ tango run --help ``tango run`` ------------- The ``run`` command is used to execute a tango experiment from an experiment configuration file. See the `Configuration files `_ section in the overview for a quick introduction to the format. ``tango info`` -------------- The ``info`` command just prints out some useful information about the current tango installation, such as which integrations are available. ``tango settings`` ------------------ The ``settings`` group of commands can be used to initialize a :class:`~tango.settings.TangoGlobalSettings` file or update fields in it. 
@click.group(name=None, **_CLICK_GROUP_DEFAULTS)
@click.version_option(version=VERSION)
@click.option(
    "--settings",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
    help="Path to a global tango.yml settings file.",
)
@click.option(
    "--log-level",
    help="Set the global log level.",
    type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
    show_choices=True,
)
@click.option(
    "--file-friendly-logging",
    is_flag=True,
    help="Outputs progress bar status on separate lines and slows refresh rate.",
)
@click.option(
    "--start-method",
    help="Set the multiprocessing start method.",
    type=click.Choice(["fork", "spawn", "forkserver"], case_sensitive=True),
    show_choices=True,
)
@click.option(
    "--called-by-executor",
    is_flag=True,
    hidden=True,
)
@click.pass_context
def main(
    ctx,
    settings: Optional[str] = None,
    log_level: Optional[str] = None,
    file_friendly_logging: bool = False,
    start_method: Optional[str] = None,
    called_by_executor: bool = False,
):
    # Entry point of the `tango` command group: resolves the global settings, applies
    # command-line overrides, stores the result on the click context for sub-commands,
    # and initializes the CLI.
    #
    # NOTE: this is deliberately documented with comments rather than a docstring,
    # because click uses a callback's docstring as its `--help` text.
    resolved: TangoGlobalSettings = load_settings(settings)

    # Flags given on the command line take precedence over the loaded settings.
    if start_method is not None:
        resolved.multiprocessing_start_method = start_method
    if log_level is not None:
        resolved.log_level = log_level
    if file_friendly_logging:
        resolved.file_friendly_logging = file_friendly_logging

    # Sub-commands decorated with `click.pass_obj` receive this object.
    ctx.obj = SettingsObject(settings=resolved, called_by_executor=called_by_executor)

    initialize_cli(settings=resolved, called_by_executor=called_by_executor)
If you're using the default executors, a value of 0 (or left unspecified) means each step is run in the main process using the default executor, otherwise the multicore executor is used.""", ) @click.option( "-s", "--step-name", help="Execute a particular step (and its dependencies) in the experiment.", multiple=True, ) @click.option( "-n", "--name", type=str, help="""Specify the name for this run.""", ) @click.option( "-D", "--ext-var", type=str, help="""JSONNET external variables to use when loading the experiment config. For example, --ext-var 'pretrained_model=gpt2'.""", multiple=True, ) @click.pass_obj def run( obj: SettingsObject, experiment: str, workspace: Optional[str] = None, workspace_dir: Optional[Union[str, os.PathLike]] = None, overrides: Optional[str] = None, include_package: Optional[Sequence[str]] = None, parallelism: Optional[int] = None, step_name: Optional[Sequence[str]] = None, name: Optional[str] = None, ext_var: Optional[Sequence[str]] = None, ): """ Run a tango experiment. EXPERIMENT is the path to experiment's JSON/Jsonnet/YAML configuration file. """ if workspace_dir is not None: import warnings warnings.warn( "-d/--workspace-dir option is deprecated. 
@main.command(hidden=True)
@click.argument(
    "experiment",
    type=click.Path(exists=True, dir_okay=False, resolve_path=True),
)
@click.argument(
    "step_name",
    type=str,
)
@click.argument(
    "workspace_url",
    type=str,
)
@click.option(
    "-i",
    "--include-package",
    type=str,
    help="Python packages or modules to import for tango components.",
    multiple=True,
)
@click.option(
    "--log-level",
    help="Set the global log level.",
    type=click.Choice(["debug", "info", "warning", "error"], case_sensitive=False),
    show_choices=True,
)
def beaker_executor_run(
    experiment: str,
    step_name: str,
    workspace_url: str,
    include_package: Optional[Sequence[str]] = None,
    log_level: str = "debug",
):
    """
    This command is only used internally by the BeakerExecutor.
    """
    # NOTE(review): the `--log-level` option declares no click default, so the "debug"
    # default above presumably only applies to direct Python calls -- confirm.
    from tango.executor import Executor

    # Import extra packages first so their components are available before the
    # experiment file is loaded.
    for package_name in include_package or ():
        import_extra_module(package_name)

    # Locate the requested step within the experiment's step graph.
    graph = StepGraph.from_file(experiment)
    target_step = graph[step_name]

    # The plain default executor is used on purpose: the step simply runs locally in
    # this (main) process.
    target_workspace = Workspace.from_url(workspace_url)
    local_executor = Executor(workspace=target_workspace, include_package=include_package)

    initialize_logging(log_level=log_level, enable_cli_logs=True, file_friendly_logging=True)

    local_executor.execute_step(target_step)
@main.group(**_CLICK_GROUP_DEFAULTS)
@click.pass_obj
def settings(ctx):
    """
    Commands for initializing and updating global settings.
    """
    # NOTE: because of `click.pass_obj`, the `ctx` argument is the object stored on the
    # click context by `main` (a `SettingsObject`), not the click context itself.
    # The group has no logic of its own; it only namespaces its sub-commands.
@set_setting.command(**_CLICK_COMMAND_DEFAULTS)
@click.argument(
    "packages",
    type=str,
    nargs=-1,
)
@click.option(
    "-a",
    "--append",
    is_flag=True,
    help="Appends packages instead of overwriting.",
)
@click.option(
    "--validate/--no-validate",
    type=bool,
    # This flag checks that the packages import cleanly; the previous help text was
    # copied from the `workspace` command and described the wrong thing.
    help="Validate that the packages can be imported.",
    default=True,
)
@click.pass_obj
def include_package(
    obj: SettingsObject,
    packages: List[str],
    append: bool = False,
    validate: bool = True,
) -> TangoGlobalSettings:
    """
    Set or add modules to automatically import on 'tango run'.
    """
    # NOTE: the docstring above is kept verbatim because click shows it as help text.
    #
    # With --append, start from the packages already stored in the settings; otherwise
    # start from scratch. Duplicates are dropped while preserving order.
    new_include: List[str] = (obj.settings.include_package or []) if append else []
    for package in packages:
        if package not in new_include:
            new_include.append(package)
    obj.settings.include_package = new_include

    if validate:
        for package in new_include:
            try:
                import_module_and_submodules(package)
            except ImportError as exc:  # also covers ModuleNotFoundError (a subclass)
                # Chain the original error so the real import failure stays visible.
                raise click.ClickException(f"Failed to import '{package}'") from exc

    # The returned settings are handed to the `set` group's result callback.
    return obj.settings
def _run(
    settings: TangoGlobalSettings,
    experiment: str,
    workspace_url: Optional[str] = None,
    overrides: Optional[str] = None,
    include_package: Optional[Sequence[str]] = None,
    step_names: Optional[Sequence[str]] = None,
    parallelism: Optional[int] = None,
    multicore: Optional[bool] = None,
    name: Optional[str] = None,
    called_by_executor: bool = False,
    ext_var: Optional[Sequence[str]] = None,
) -> str:
    """
    Load an experiment config, build its step graph, and execute it.

    :param settings: Global settings; also a source of packages to import.
    :param experiment: Path of the experiment config, read with ``Params.from_file``.
    :param workspace_url: Forwarded to ``prepare_workspace``.
    :param overrides: Passed to ``Params.from_file`` as ``params_overrides``.
    :param include_package: Extra packages to import, in addition to the config's
        ``include_package`` entry and ``settings.include_package``.
    :param step_names: If given, only the sub-graph needed for these steps is executed.
    :param parallelism: Forwarded to ``prepare_executor``.
    :param multicore: Forwarded to ``prepare_executor``.
    :param name: Forwarded to ``execute_step_graph`` as the run name.
    :param called_by_executor: Forwarded to ``prepare_executor`` and
        ``execute_step_graph``.
    :param ext_var: ``key=value`` strings used as external variables for the config.

    :returns: The name of the run, as returned by ``execute_step_graph``.

    :raises CliRunError: If an ``ext_var`` entry contains no ``=``.
    """
    # Read params.
    ext_vars: Dict[str, str] = {}
    for var in ext_var or []:
        # Split on the FIRST "=" only, so values may themselves contain "="
        # (e.g. 'url=https://host/path?a=b').
        key, sep, value = var.partition("=")
        if not sep:
            raise CliRunError(f"Invalid --ext-var '{var}'")
        ext_vars[key] = value
    params = Params.from_file(experiment, params_overrides=overrides or "", ext_vars=ext_vars)

    # Import included packages to find registered components.
    # NOTE: The Executor imports these as well because it's meant to be used
    # directly, but we also need to import here in case the user is using a
    # custom Executor, StepCache, or Workspace.
    include_package: List[str] = list(include_package or [])
    include_package += params.pop("include_package", [])
    include_package += settings.include_package or []
    for package_name in include_package:
        import_extra_module(package_name)

    # Initialize step graph.
    step_graph: StepGraph = StepGraph.from_params(params.pop("steps"))
    params.assert_empty("'tango run'")

    if step_names:
        for step_name in step_names:
            # NOTE: kept as an assert so the raised exception type stays the same.
            assert step_name in step_graph, (
                f"You want to run a step called '{step_name}', but it cannot be found in the experiment config. "
                f"The config contains: {list(step_graph.keys())}."
            )
        step_graph = step_graph.sub_graph(*step_names)

    # Execute step graph in workspace
    workspace = prepare_workspace(settings=settings, workspace_url=workspace_url)

    executor = prepare_executor(
        settings=settings,
        workspace=workspace,
        include_package=include_package,
        parallelism=parallelism,
        multicore=multicore,
        called_by_executor=called_by_executor,
    )

    run_name = execute_step_graph(
        step_graph=step_graph,
        workspace=workspace,
        executor=executor,
        name=name,
        called_by_executor=called_by_executor,
        step_names=step_names,
    )

    return run_name
def prepare_workspace(
    settings: Optional[TangoGlobalSettings] = None,
    workspace_url: Optional[str] = None,
) -> Workspace:
    """
    Pick the workspace to run in.

    Order of precedence: an explicit ``workspace_url``, then the workspace configured
    in ``settings``, then ``tango.workspaces.default_workspace``.

    :param settings: Global settings; ``TangoGlobalSettings.default()`` is used when
        omitted.
    :param workspace_url: If given, the workspace is built with ``Workspace.from_url``.
    :returns: The selected workspace.
    """
    from tango.workspaces import default_workspace

    if settings is None:
        settings = TangoGlobalSettings.default()

    if workspace_url is not None:
        return Workspace.from_url(workspace_url)

    if settings.workspace is not None:
        return Workspace.from_params(settings.workspace)

    return default_workspace
executor is defined in %s", settings.path or "setting", ) executor = Executor.from_params( settings.executor, workspace=workspace, include_package=include_package, **(dict(parallelism=parallelism) if parallelism is not None else {}), # type: ignore ) else: # Determine if we can use the multicore executor. if multicore is None: if isinstance(workspace, MemoryWorkspace): # Memory workspace does not work with multiple cores. multicore = False elif "pydevd" in sys.modules: # Pydevd doesn't reliably follow child processes, so we disable multicore under the debugger. logger.warning("Debugger detected, disabling multicore.") multicore = False elif parallelism is None or parallelism == 0: multicore = False else: multicore = True if multicore: executor = MulticoreExecutor( workspace=workspace, include_package=include_package, parallelism=parallelism ) else: executor = Executor(workspace=workspace, include_package=include_package) return executor def execute_step_graph( step_graph: StepGraph, workspace: Optional[Workspace] = None, executor: Optional[Executor] = None, name: Optional[str] = None, called_by_executor: bool = False, step_names: Optional[Sequence[str]] = None, ) -> str: if workspace is None: workspace = prepare_workspace() executor = prepare_executor(workspace=workspace) elif executor is None: executor = prepare_executor(workspace=workspace) # Register run. run: "Run" if called_by_executor and name is not None: try: run = workspace.registered_run(name) except KeyError: raise RuntimeError( "The CLI was called by `MulticoreExecutor.execute_step_graph`, but " f"'{name}' is not already registered as a run. This should never happen!" 
) else: run = workspace.register_run((step for step in step_graph.values()), name) if called_by_executor: assert step_names is not None and len(step_names) == 1 from tango.common.aliases import EnvVarNames # We set this environment variable so that any steps that contain multiprocessing # and call `initialize_worker_logging` also log the messages with the `step_name` prefix. os.environ[EnvVarNames.LOGGING_PREFIX.value] = f"step {step_names[0]}" initialize_prefix_logging(prefix=f"step {step_names[0]}", main_process=False) # Capture logs to file. with workspace.capture_logs_for_run(run.name) if not called_by_executor else nullcontext(): if not called_by_executor: cli_logger.info("[green]Starting new run [bold]%s[/][/]", run.name) executor_output: ExecutorOutput = executor.execute_step_graph(step_graph, run_name=run.name) if executor_output.failed: cli_logger.error("[red]\N{ballot x} Run [bold]%s[/] finished with errors[/]", run.name) elif not called_by_executor: cli_logger.info("[green]\N{check mark} Finished run [bold]%s[/][/]", run.name) if executor_output is not None: if not called_by_executor: executor_output.display() if executor_output.failed: raise CliRunError return run.name ================================================ FILE: tango/common/__init__.py ================================================ from .aliases import PathOrStr from .dataset_dict import DatasetDict, DatasetDictBase, IterableDatasetDict from .det_hash import det_hash from .from_params import FromParams from .lazy import Lazy from .params import Params from .registrable import Registrable, RegistrableFunction, make_registrable from .tqdm import Tqdm from .util import filename_is_safe, threaded_generator __all__ = [ "PathOrStr", "DatasetDictBase", "DatasetDict", "IterableDatasetDict", "det_hash", "Params", "FromParams", "Registrable", "RegistrableFunction", "make_registrable", "Lazy", "Tqdm", "filename_is_safe", "threaded_generator", ] ================================================ FILE: 
@unique
class EnvVarNames(Enum):
    """
    Names of the environment variables referenced through this enum.

    Each member's value is the literal environment variable name.
    """

    FILE_FRIENDLY_LOGGING = "FILE_FRIENDLY_LOGGING"
    LOG_LEVEL = "TANGO_LOG_LEVEL"
    CLI_LOGGER_ENABLED = "TANGO_CLI_LOGGER_ENABLED"
    LOGGING_HOST = "TANGO_LOGGING_HOST"
    LOGGING_PORT = "TANGO_LOGGING_PORT"
    LOGGING_PREFIX = "TANGO_LOGGING_PREFIX"
    CONSOLE_WIDTH = "TANGO_CONSOLE_WIDTH"

    @classmethod
    def values(cls) -> Set[str]:
        """
        Return the set of all environment variable names defined by this enum.
        """
        return {member.value for member in cls}
class CustomDetHash:
    """
    Mixin for objects that want to control what goes into their deterministic hash.

    By default, :func:`det_hash()` pickles an object and returns the hash of the
    pickled representation. Derive from this class and implement
    :meth:`det_hash_object()` to have :func:`det_hash()` pickle the result of that
    method instead of the object itself. If the method returns ``None``,
    :func:`det_hash()` falls back to the original behavior and pickles the object.
    """

    # NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` does not
    # prevent instantiation; calling the base implementation raises instead.
    @abstractmethod
    def det_hash_object(self) -> Any:
        """
        Return the object to use for deterministic hashing in place of ``self``.
        """
        raise NotImplementedError
class _DetHashPickler(dill.Pickler):
    """
    A ``dill`` pickler that substitutes stable "persistent ids" for objects whose
    default pickled form is not suitable for deterministic hashing.
    """

    def __init__(self, buffer: io.BytesIO):
        super().__init__(buffer, protocol=_PICKLE_PROTOCOL)

        # Tracks, per object id, how deeply we are currently nested in pickling
        # that object. A class that returns `self` as part of `det_hash_object()`
        # would otherwise recurse forever: pickling the `det_hash_object()` pickles
        # `self`, which asks for `det_hash_object()` again, and so on.
        # We therefore only call `det_hash_object()` the first time we see an
        # object, and pickle it normally on the nested visit.
        # `DetHashWithVersion` relies on this.
        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()

    def save(self, obj, save_persistent_id=True):
        """Pickle ``obj`` while recording that we are inside its pickling."""
        obj_id = id(obj)
        self.recursively_pickled_ids[obj_id] += 1
        super().save(obj, save_persistent_id)
        self.recursively_pickled_ids[obj_id] -= 1

    def persistent_id(self, obj: Any) -> Any:
        """
        Return the stand-in to pickle instead of ``obj``, or ``None`` to pickle
        ``obj`` itself.
        """
        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:
            replacement = obj.det_hash_object()
            if replacement is None:
                # The object declined to customize its hash; pickle it as-is.
                return None
            obj_class = obj.__class__
            return obj_class.__module__, obj_class.__qualname__, replacement

        if isinstance(obj, type):
            # Classes are identified by where they live, not by their contents.
            return obj.__module__, obj.__qualname__

        if callable(obj):
            if hasattr(obj, "__module__") and hasattr(obj, "__qualname__"):
                return obj.__module__, obj.__qualname__
            return None

        if ndarray is not None and isinstance(obj, ndarray):
            # It's unclear why numpy arrays don't pickle in a consistent way.
            return obj.dumps()

        if TorchTensor is not None and isinstance(obj, TorchTensor):
            # It's unclear why torch tensors don't pickle in a consistent way.
            import torch

            with io.BytesIO() as buffer:
                torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)
                return buffer.getvalue()

        return None
class ConfigurationError(TangoError):
    """
    The exception raised when a Tango object fails to initialize from a config
    that's misconfigured (e.g. missing properties, invalid properties, unknown properties).

    :param message: Description of the configuration problem. It is stored on
        :attr:`message` and is what ``str()`` of the exception returns.
    """

    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
        # Rebuild the exception from `message` alone when unpickling, regardless
        # of what `args` currently holds (callers may prepend context to `args`).
        return type(self), (self.message,)

    def __init__(self, message: str):
        # Pass the message on to `Exception` so that `args` and `repr()` carry it;
        # with a bare `super().__init__()` they would both be empty.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message
class FileLock(_FileLock):  # type: ignore[valid-type,misc]
    """
    This is just a subclass of the `FileLock` class from the `filelock` library, except that
    it adds an additional argument to the `__init__` method: `read_only_ok`.

    By default this flag is `False`, in which case an exception is raised when a lock
    can't be acquired due to lack of write permissions.
    But if this flag is set to `True`, a warning will be emitted instead of an error when
    the lock already exists but the lock can't be acquired because write access is blocked.
    """

    def __init__(self, lock_file: PathOrStr, timeout=-1, read_only_ok: bool = False) -> None:
        super().__init__(str(lock_file), timeout=timeout)
        self._read_only_ok = read_only_ok

    def acquire(  # type: ignore[override]
        self,
        timeout: Optional[float] = None,
        poll_interval: float = 0.05,
    ) -> AcquireReturnProxy:
        """
        Acquire the lock, tolerating permission errors when `read_only_ok` is set.

        If the underlying `acquire()` fails with a permission-related `OSError`, the
        lock file already exists, and `read_only_ok` is `True`, a `UserWarning` is
        emitted and a proxy is returned without the lock actually being held.
        Any other `OSError` is re-raised.
        """
        try:
            return super().acquire(timeout=timeout, poll_interval=poll_interval)
        except OSError as err:
            # OSError could be a lot of different things, but what we're looking
            # for in particular are permission errors, such as:
            #  - errno 1  - EPERM  - "Operation not permitted"
            #  - errno 13 - EACCES - "Permission denied"
            #  - errno 30 - EROFS  - "Read-only file system"
            if err.errno not in (1, 13, 30):
                raise

            if os.path.isfile(self._lock_file) and self._read_only_ok:  # type: ignore
                warnings.warn(
                    f"Lacking permissions required to obtain lock '{self._lock_file}'. "  # type: ignore
                    "Race conditions are possible if other processes are writing to the same resource.",
                    UserWarning,
                )
                return AcquireReturnProxy(self)
            else:
                raise

    def acquire_with_updates(self, desc: Optional[str] = None) -> AcquireReturnProxy:
        """
        Same as :meth:`acquire()`, except that when the lock cannot be immediately acquired,
        it will keep trying and print status updates as it goes.

        :param desc: Text shown next to the elapsed time while waiting. Defaults to a
            message naming the lock file.
        """
        # Fast path: don't show any progress output if the lock is free right away.
        try:
            return self.acquire(timeout=0.1)
        except TimeoutError:
            pass

        from .tqdm import Tqdm

        if desc is None:
            desc = f"acquiring lock at {self._lock_file}"  # type: ignore

        progress = Tqdm.tqdm(desc=desc, bar_format="{desc} [{elapsed}]")
        try:
            while True:
                progress.update()
                try:
                    return self.acquire(timeout=1)
                except TimeoutError:
                    continue
        finally:
            # Release the progress display whether we got the lock or an
            # unexpected error (or interrupt) ended the wait.
            progress.close()
def infer_method_params(
    cls: Type[T], method: Callable, infer_kwargs: bool = True
) -> Dict[str, inspect.Parameter]:
    """
    Collect the parameters that ``method`` accepts, keyed by parameter name.

    Parameters whose names start with ``__`` and any ``*args`` parameter are dropped.
    String annotations are resolved to real types with ``typing.get_type_hints``.
    If ``method`` takes ``**kwargs`` and ``infer_kwargs`` is true, the constructor
    parameters of the first superclass of ``cls`` that derives from ``FromParams``
    are merged in, with the parameters of ``method`` taking precedence.

    :param cls: The class that ``method`` belongs to. It is only inspected when
        ``**kwargs`` have to be resolved, and for error messages.
    :param method: The callable whose signature is inspected.
    :param infer_kwargs: Whether to look at superclasses when ``method`` takes ``**kwargs``.
    :raises TypeError: If a string annotation can't be resolved.
    """
    parameters = dict(inspect.signature(method).parameters)

    has_kwargs = False
    var_positional_key = None
    # Resolved at most once, and only if some annotation is actually a string.
    type_hints: Optional[Dict[str, Any]] = None
    for param_name in list(parameters.keys()):
        # Ignore special private parameters.
        # This is necessary to make `FromParams` work with Pydantic, for example.
        if param_name.startswith("__"):
            del parameters[param_name]
            continue

        param = parameters[param_name]
        if param.kind == param.VAR_KEYWORD:
            has_kwargs = True
        elif param.kind == param.VAR_POSITIONAL:
            var_positional_key = param.name

        if isinstance(param.annotation, str):
            # For Python < 3.10, if the module where this class was defined used
            # `from __future__ import annotation`, the annotation will be a str,
            # so we need to resolve it using `get_type_hints` from the typing module.
            # See https://www.python.org/dev/peps/pep-0563/ for more info.
            try:
                if type_hints is None:
                    type_hints = get_type_hints(method)
                parameters[param_name] = param.replace(annotation=type_hints[param_name])
            except TypeError as e:
                if "'type' object is not subscriptable" not in str(e):
                    raise

                # This can happen when someone uses a type hint like `dict[str, str]`
                # instead of `Dict[str, str]`.
                err_msg = (
                    f"Failed to parse the type annotation `{param.annotation}` "
                    f"from `{cls.__qualname__}.{method.__name__}()`."
                )

                if "[" in param.annotation:
                    # Check if there is an equivalent generic in the `typing` module.
                    import typing

                    type_, *_ = param.annotation.split("[", 1)
                    # A tuple (not a set) so the order of the lookups is deterministic.
                    for possible_typing_equivalent in (type_, type_.title()):
                        if hasattr(typing, possible_typing_equivalent):
                            err_msg += (
                                f" Try using `{possible_typing_equivalent}` "
                                "from the `typing` module instead."
                            )
                            break

                raise TypeError(err_msg) from e

    if var_positional_key:
        del parameters[var_positional_key]

    if not has_kwargs or not infer_kwargs:
        return parameters

    # "mro" is "method resolution order". The first one is the current class, the next is the
    # first superclass, and so on. We take the first superclass we find that inherits from
    # FromParams.
    super_class = None
    # We have to be a little careful here because in some cases we might not have an
    # actual class. Instead we might just have a function that returns a class instance.
    if hasattr(cls, "mro"):
        for super_class_candidate in cls.mro()[1:]:
            if issubclass(super_class_candidate, FromParams):
                super_class = super_class_candidate
                break
    if super_class:
        super_parameters = infer_params(super_class)
    else:
        super_parameters = {}

    return {**super_parameters, **parameters}  # Subclass parameters overwrite superclass ones
def create_extras(cls: Type[T], extras: Dict[str, Any]) -> Dict[str, Any]:
    """
    Select the entries of ``extras`` that can be passed on to ``cls``.

    The callable that is inspected is ``cls.from_params`` when it exists and ``cls``
    itself otherwise. If that callable accepts ``**kwargs``, ``extras`` is returned
    unchanged; if not, only the entries whose keys are named arguments of the
    callable are returned, since any additional ones would cause a ``TypeError``.
    """
    # In some rare cases, we get a registered subclass that does _not_ have a
    # from_params method (this happens with Activations, for instance, where we
    # register pytorch modules directly). Falling back to the class itself is a
    # bit of a hack to make those work: the extras in the class constructor are
    # then what we are looking for, to pass on.
    target = getattr(cls, "from_params", cls)

    if takes_kwargs(target):
        # The target accepts **kwargs, so we need to pass everything along.
        # For example, `BasicTextFieldEmbedder.from_params` requires a Vocabulary
        # object, but `TextFieldEmbedder.from_params` does not.
        return extras

    return {name: value for name, value in extras.items() if takes_arg(target, name)}
def construct_arg(
    class_name: str,
    argument_name: str,
    popped_params: Params,
    annotation: Type,
    default: Any,
    try_from_step: bool = True,
) -> Any:
    """
    Convert ``popped_params`` into a value of the type described by ``annotation``.

    The branches below are tried in order: ``FromParams`` subclasses, ``StepIndexer``
    values, primitives (``int``/``bool``, ``str``, ``float``, ``Path``), mappings,
    tuples, sets, unions, ``Lazy``, other single-argument iterables, and arbitrary
    classes given a ``Params`` object. Anything else is returned as-is (a ``Params``
    object is converted to a plain dict first). Container element types are handled
    by recursive calls to this function.

    :param class_name: Name of the class being constructed. Used in error messages,
        and to skip step parsing for the ``config`` argument of ``StepInfo``.
    :param argument_name: Name of the argument being constructed. Used in error
        messages; nested elements get a ``.key`` / ``.index`` suffix.
    :param popped_params: The raw value taken from the params for this argument.
    :param annotation: The type annotation of the argument.
    :param default: The default value of the argument, or ``_NO_DEFAULT`` if it has none.
    :param try_from_step: Whether to first try parsing the value as ``Step[annotation]``
        when it looks like it could be, come from, or contain a step.
    :raises TypeError: If the value doesn't match a primitive, ``Path``, or mapping annotation,
        or is of an unsupported type for a ``FromParams`` annotation.
    :raises ConfigurationError: If a required ``FromParams`` value is missing, a step's return
        type doesn't match the expected type, or no member of a union could be constructed.
    """
    # If we have the default, we're already done :)
    if popped_params is default:
        return popped_params

    from tango.step import FunctionalStep, Step, StepIndexer, WithUnresolvedSteps

    # For a compound annotation like `Dict[str, int]`, `origin` is the generic
    # (`dict`) and `args` are its type arguments (`(str, int)`).
    origin = getattr(annotation, "__origin__", None)
    args = getattr(annotation, "__args__", [])

    # Try to guess if `popped_params` might be a step, come from a step, or contain a step.
    could_be_step = (
        try_from_step
        and (
            origin == Step
            or isinstance(popped_params, Step)
            or _params_contain_step(popped_params)
            or (isinstance(popped_params, (dict, Params)) and popped_params.get("type") == "ref")
        )
        and not (class_name == "StepInfo" and argument_name == "config")
    )
    if could_be_step:
        # If we think it might be a step, we try parsing as a step _first_.
        # Parsing as a non-step always succeeds, because it will fall back to returning a dict.
        # So we can't try parsing as a non-step first.
        # The attempt may modify `popped_params`, so keep a copy to restore on failure.
        backup_params = deepcopy(popped_params)
        try:
            return construct_arg(
                class_name,
                argument_name,
                popped_params,
                Step[annotation],  # type: ignore
                default,
                try_from_step=False,
            )
        except (ValueError, TypeError, ConfigurationError, AttributeError, IndexError):
            popped_params = backup_params

    # The parameter is optional if its default value is not the "no default" sentinel.
    optional = default != _NO_DEFAULT

    if (inspect.isclass(annotation) and issubclass(annotation, FromParams)) or (
        inspect.isclass(origin) and issubclass(origin, FromParams)
    ):
        if origin is None and isinstance(popped_params, annotation):
            # Already an instance of the requested (non-generic) type.
            return popped_params
        elif popped_params is not None:
            # In some cases we allow a string instead of a param dict, so
            # we need to handle that case separately.
            if isinstance(popped_params, str):
                if origin != Step:
                    # We don't allow single strings to be upgraded to steps.
                    # Since we try everything as a step first, upgrading strings to
                    # steps automatically would cause confusion every time a step
                    # name conflicts with any string anywhere in a config.
                    popped_params = Params({"type": popped_params})
            elif isinstance(popped_params, dict):
                popped_params = Params(popped_params)
            elif not isinstance(popped_params, (Params, Step)):
                raise TypeError(
                    f"Expected a `Params` object, found `{popped_params}` instead while constructing "
                    f"parameter '{argument_name}' for `{class_name}`"
                )

            result: Union[FromParams, WithUnresolvedSteps]
            if isinstance(popped_params, Step):
                result = popped_params
            else:
                if origin != Step and _params_contain_step(popped_params):
                    # Construction is wrapped in `WithUnresolvedSteps` instead of
                    # being performed here, because the params still contain steps.
                    result = WithUnresolvedSteps(annotation.from_params, popped_params)
                else:
                    result = annotation.from_params(popped_params)

            if isinstance(result, Step):
                # Check the step's declared return type against the type argument
                # of the annotation (the `X` in `Step[X]`), when there is one.
                expected_return_type = args[0] if args else None
                if isinstance(result, FunctionalStep):
                    return_type = inspect.signature(result.WRAPPED_FUNC).return_annotation
                else:
                    return_type = inspect.signature(result.run).return_annotation
                if return_type == inspect.Signature.empty:
                    logger.warning(
                        "Step %s has no return type annotation. Those are really helpful when "
                        "debugging, so we recommend them highly.",
                        result.__class__.__name__,
                    )
                else:
                    try:
                        if expected_return_type is not None and not issubclass(
                            return_type, expected_return_type
                        ):
                            raise ConfigurationError(
                                f"Step {result.name} returns {return_type}, but "
                                f"we expected {expected_return_type}."
                            )
                    except TypeError:
                        # `issubclass` raises TypeError when given something that
                        # isn't a class; in that case the check is skipped.
                        pass

            return result
        elif not optional:
            # Not optional and not supplied, that's an error!
            raise ConfigurationError(f"expected key {argument_name} for {class_name}")
        else:
            return default

    # For StepIndexer, we just return as-is and hope for the best.
    # At worst, user will get an error at runtime if they are trying to index a step
    # result that can't be indexed.
    # TODO (epwalsh): we could check the return type of the wrapped step here
    # and make sure that:
    #   1. It's an index-able object,
    #   2. The item in the index-able object matches `annotation`.
    #
    # But that's complex and might have false negatives.
    elif type(popped_params) == StepIndexer:
        return popped_params

    # If the parameter type is a Python primitive, just pop it off
    # using the correct casting pop_xyz operation.
    elif annotation in {int, bool}:
        if type(popped_params) in {int, bool}:
            return annotation(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be {annotation.__name__}, "
                f"found {popped_params} ({type(popped_params)})."
            )
    elif annotation == str:
        # Strings are special because we allow casting from Path to str.
        if isinstance(popped_params, str) or isinstance(popped_params, Path):
            return str(popped_params)  # type: ignore
        else:
            raise TypeError(
                f"Expected {argument_name} to be a string, found {popped_params} ({type(popped_params)})"
            )
    elif annotation == float:
        # Floats are special because in Python, you can put an int wherever you can put a float.
        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html
        if type(popped_params) in {int, float}:
            return popped_params
        else:
            raise TypeError(f"Expected {argument_name} to be numeric.")
    elif annotation == Path:
        if isinstance(popped_params, (str, Path)):
            return Path(popped_params)
        else:
            raise TypeError(
                f"Expected {argument_name} to be a str or Path, found {popped_params} ({type(popped_params)})"
            )

    # This is special logic for handling types like Dict[str, TokenIndexer],
    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],
    # which it creates by instantiating each value from_params and returning the resulting structure.
    elif origin in {collections.abc.Mapping, Mapping, Dict, dict} and len(args) == 2:
        value_cls = annotation.__args__[-1]
        value_dict = {}
        if not isinstance(popped_params, Mapping):
            raise TypeError(
                f"Expected {argument_name} to be a Mapping (probably a dict or a Params object) "
                f"found {popped_params} ({type(popped_params)})."
            )

        # Only the values are constructed; the keys are used as they are.
        for key, value_params in popped_params.items():
            value_dict[key] = construct_arg(
                str(value_cls),
                argument_name + "." + key,
                value_params,
                value_cls,
                _NO_DEFAULT,
            )

        return value_dict

    elif origin in (Tuple, tuple):
        value_list = []
        value_types = list(annotation.__args__)
        if value_types[-1] == Ellipsis:
            # Variable length tuples, e.g. 'Tuple[int, ...]', we set value_types to '[int] * len(popped_params)'.
            value_types = value_types[:-1] + [value_types[-2]] * (
                len(popped_params) - len(annotation.__args__) + 1
            )

        for i, (value_cls, value_params) in enumerate(zip(value_types, popped_params)):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_list.append(value)

        return tuple(value_list)

    elif origin in (Set, set) and len(args) == 1:
        value_cls = annotation.__args__[0]
        value_set = set()
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_set.add(value)

        return value_set

    elif origin == Union or isinstance(annotation, UnionType):
        # Storing this so we can recover it later if we need to.
        backup_params = deepcopy(popped_params)

        # We'll try each of the given types in the union sequentially, returning the first one that
        # succeeds.
        error_chain: Optional[Exception] = None
        for arg_annotation in args:
            try:
                return construct_arg(
                    str(arg_annotation),
                    argument_name,
                    popped_params,
                    arg_annotation,
                    default,
                )
            except (ValueError, TypeError, ConfigurationError, AttributeError) as e:
                # Our attempt to construct the argument may have modified popped_params, so we
                # restore it here.
                popped_params = deepcopy(backup_params)
                # Link the failures together so that the final error shows every attempt.
                e.args = (f"While constructing an argument of type {arg_annotation}",) + e.args
                e.__cause__ = error_chain
                error_chain = e

        # If none of them succeeded, we crash.
        config_error = ConfigurationError(
            f"Failed to construct argument {argument_name} with type {annotation}."
        )
        config_error.__cause__ = error_chain
        raise config_error
    elif origin == Lazy:
        # Construction is deferred: the `Lazy` object keeps its own copy of the params.
        value_cls = args[0]
        return Lazy(value_cls, params=deepcopy(popped_params))  # type: ignore

    # For any other kind of iterable, we will just assume that a list is good enough, and treat
    # it the same as List. This condition needs to be at the end, so we don't catch other kinds
    # of Iterables with this branch.
    elif origin in {collections.abc.Iterable, Iterable, List, list} and len(args) == 1:
        value_cls = annotation.__args__[0]
        value_list = []
        for i, value_params in enumerate(popped_params):
            value = construct_arg(
                str(value_cls),
                argument_name + f".{i}",
                value_params,
                value_cls,
                _NO_DEFAULT,
            )
            value_list.append(value)

        return value_list

    elif (inspect.isclass(annotation) or inspect.isclass(origin)) and isinstance(
        popped_params, Params
    ):
        # Constructing arbitrary classes from params
        arbitrary_class = origin or annotation
        constructor_to_inspect = arbitrary_class.__init__
        constructor_to_call = arbitrary_class
        # This has to be checked before `create_kwargs()`, which pops entries off `popped_params`.
        params_contain_step = _params_contain_step(popped_params)
        kwargs = create_kwargs(constructor_to_inspect, arbitrary_class, popped_params)
        from tango.step import WithUnresolvedSteps

        if origin != Step and params_contain_step:
            return WithUnresolvedSteps(constructor_to_call, *[], **kwargs)
        else:
            return constructor_to_call(**kwargs)  # type: ignore
    else:
        # Pass it on as is and hope for the best.   ¯\_(ツ)_/¯
        if isinstance(popped_params, Params):
            return popped_params.as_dict()
        return popped_params
""" @classmethod def from_params( cls: Type[T], params_: Union[Params, dict, str], constructor_to_call: Optional[Callable[..., T]] = None, constructor_to_inspect: Optional[Union[Callable[..., T], Callable[[T], None]]] = None, **extras, ) -> T: """ This is the automatic implementation of ``from_params``. Any class that subclasses from ``FromParams`` (or :class:`~tango.common.Registrable`, which itself subclasses ``FromParams``) gets this implementation for free. If you want your class to be instantiated from params in the "obvious" way -- pop off parameters and hand them to your constructor with the same names -- this provides that functionality. If you need more complex logic in your from ``from_params`` method, you'll have to implement your own method that overrides this one. The ``constructor_to_call`` and ``constructor_to_inspect`` arguments deal with a bit of redirection that we do. We allow you to register particular ``@classmethods`` on a class as the constructor to use for a registered name. This lets you, e.g., have a single ``Vocabulary`` class that can be constructed in two different ways, with different names registered to each constructor. In order to handle this, we need to know not just the class we're trying to construct (``cls``), but also what method we should inspect to find its arguments (``constructor_to_inspect``), and what method to call when we're done constructing arguments (``constructor_to_call``). These two methods are the same when you've used a ``@classmethod`` as your constructor, but they are ``different`` when you use the default constructor (because you inspect ``__init__``, but call ``cls()``). 
""" from tango.common.registrable import ( Registrable, # import here to avoid circular imports ) params = params_ logger.debug( f"instantiating class {cls} from params {getattr(params, 'params', params)} " f"and extras {set(extras.keys())}" ) if params is None: return None if isinstance(params, str): params = Params({"type": params}) if not isinstance(params, Params): if isinstance(params, dict): params = Params(params) else: raise ConfigurationError( "from_params was passed a `params` object that was not a `Params`. This probably " "indicates malformed parameters in a configuration file, where something that " "should have been a dictionary was actually a list, or something else. " f"This happened when constructing an object of type {cls}." ) if issubclass(cls, Registrable) and not constructor_to_call: # We know `cls` inherits from Registrable, so we'll use a cast to make mypy happy. as_registrable = cast(Type[Registrable], cls) if "type" in params and params["type"] not in as_registrable.list_available(): as_registrable.search_modules(params["type"]) # Resolve the subclass and constructor. if is_base_registrable(cls) or "type" in params: default_to_first_choice = as_registrable.default_implementation is not None choice = params.pop_choice( "type", choices=as_registrable.list_available(), default_to_first_choice=default_to_first_choice, ) # We allow users to register methods and functions, not just classes. # So we have to handle both here. subclass_or_factory_func, constructor_name = as_registrable.resolve_class_name( choice ) if inspect.isclass(subclass_or_factory_func): # We have an actual class. subclass = subclass_or_factory_func if constructor_name is not None: constructor_to_inspect = cast( Callable[..., T], getattr(subclass, constructor_name) ) constructor_to_call = constructor_to_inspect else: constructor_to_inspect = subclass.__init__ constructor_to_call = subclass else: # We have a function that returns an instance of the class. 
factory_func = cast(Callable[..., T], subclass_or_factory_func) return_type = inspect.signature(factory_func).return_annotation if return_type == inspect.Signature.empty: subclass = cls else: subclass = return_type constructor_to_inspect = factory_func constructor_to_call = factory_func else: # Must be trying to instantiate the given class directly. subclass = cls constructor_to_inspect = cls.__init__ constructor_to_call = cast(Callable[..., T], cls) if hasattr(subclass, "from_params"): # We want to call subclass.from_params. extras = create_extras(subclass, extras) # mypy can't follow the typing redirection that we do, so we explicitly cast here. retyped_subclass = cast(Type[T], subclass) return retyped_subclass.from_params( params, constructor_to_call=constructor_to_call, constructor_to_inspect=constructor_to_inspect, **extras, ) else: # In some rare cases, we get a registered subclass that does _not_ have a # from_params method (this happens with Activations, for instance, where we # register pytorch modules directly). This is a bit of a hack to make those work, # instead of adding a `from_params` method for them somehow. We just trust that # you've done the right thing in passing your parameters, and nothing else needs to # be recursively constructed. kwargs = create_kwargs(constructor_to_inspect, subclass, params, extras) # type: ignore return constructor_to_call(**kwargs) # type: ignore else: # This is not a base class, so convert our params and extras into a dict of kwargs. # See the docstring for an explanation of what's going on here. if not constructor_to_inspect: constructor_to_inspect = cls.__init__ if not constructor_to_call: constructor_to_call = cls if constructor_to_inspect == object.__init__: # This class does not have an explicit constructor, so don't give it any kwargs. # Without this logic, create_kwargs will look at object.__init__ and see that # it takes *args and **kwargs and look for those. 
def to_params(self) -> Params:
    """
    Serialize this object into a ``Params`` object that ``.from_params()`` can
    consume to rebuild an equivalent object.

    The raw data comes from ``_to_params()``; this method only converts that data
    into plain, JSON-like values. To customize serialization in a ``FromParams``
    subclass, override ``_to_params()`` rather than this method.
    """
    primitive_types = (str, float, int, bool)

    def _convert(obj: Any) -> Any:
        # Nested ``FromParams`` objects serialize themselves recursively.
        if isinstance(obj, FromParams):
            return obj.to_params().as_dict(quiet=True)
        # Every sequence-like container is flattened into a plain list.
        if isinstance(obj, (list, tuple, set)):
            return [_convert(item) for item in obj]
        if isinstance(obj, dict):
            return {key: _convert(value) for key, value in obj.items()}
        if isinstance(obj, Path):
            return str(obj)
        if obj is None or isinstance(obj, primitive_types):
            return obj
        raise NotImplementedError(
            f"Unexpected type encountered in to_params(): {obj} ({type(obj)})\n"
            "You may need to implement a custom '_to_params()'."
        )

    raw_params = self._to_params()
    return Params(_convert(raw_params))
class Lazy(Generic[T], CustomDetHash):
    """
    This class is for use when constructing objects using :class:`~tango.common.FromParams`,
    when an argument to a constructor has a `sequential dependency` with another argument to the
    same constructor.

    For example, in a ``Trainer`` class you might want to take a ``Model`` and an ``Optimizer`` as
    arguments, but the ``Optimizer`` needs to be constructed using the parameters from the
    ``Model``. You can give the type annotation ``Lazy[Optimizer]`` to the optimizer argument, then
    inside the constructor call ``optimizer.construct(parameters=model.parameters)``.

    This is only recommended for use when you have registered a ``@classmethod`` as the constructor
    for your class, instead of using ``__init__``. Having a ``Lazy[]`` type annotation on an
    argument to an ``__init__`` method makes your class completely dependent on being constructed
    using the ``FromParams`` pipeline, which is not a good idea.

    The actual implementation here is incredibly simple; the logic that handles the lazy
    construction is actually found in ``FromParams``, where we have a special case for a ``Lazy``
    type annotation.

    Examples
    --------

    ::

        @classmethod
        def my_constructor(
            cls,
            some_object: Lazy[MyObject],
            optional_object: Lazy[MyObject] = None,
            # or:
            #  optional_object: Optional[Lazy[MyObject]] = None,
            optional_object_with_default: Optional[Lazy[MyObject]] = Lazy(MyObjectDefault),
            required_object_with_default: Lazy[MyObject] = Lazy(MyObjectDefault),
        ) -> MyClass:
            obj1 = some_object.construct()
            obj2 = None if optional_object is None else optional_object.construct()
            obj3 = (
                None
                if optional_object_with_default is None
                else optional_object_with_default.construct()
            )
            obj4 = required_object_with_default.construct()

    """

    def __init__(
        self,
        constructor: Union[Type[T], Callable[..., T]],
        params: Optional[Params] = None,
        constructor_extras: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """
        :param constructor:
            The class or callable that will eventually build the object.
        :param params:
            Parameters handed to ``constructor.from_params()`` when ``constructor`` is a
            ``FromParams`` subclass. Defaults to an empty ``Params``.
        :param constructor_extras:
            Extra keyword arguments to pass at construction time.
        :param kwargs:
            Additional extras; these override same-named keys in ``constructor_extras``.
        """
        self._constructor = constructor
        self._params = params or Params({})
        # Build a fresh dict so that merging ``kwargs`` never mutates the dictionary
        # the caller passed in (it may be shared between several ``Lazy`` objects).
        self._constructor_extras: Dict[str, Any] = {**(constructor_extras or {}), **kwargs}

    @property
    def constructor(self) -> Callable[..., T]:
        """
        The callable that actually builds the object.

        For ``FromParams`` subclasses this is a wrapper that calls ``from_params()`` with a
        deep copy of the stored params; otherwise it is the stored constructor itself.
        """
        from tango.common.from_params import FromParams

        if inspect.isclass(self._constructor) and issubclass(self._constructor, FromParams):

            def constructor_to_use(**kwargs):
                # Deep copy so the stored params survive repeated ``construct()`` calls.
                return self._constructor.from_params(  # type: ignore[union-attr]
                    copy.deepcopy(self._params),
                    **kwargs,
                )

            return constructor_to_use
        else:
            return self._constructor

    def construct(self, **kwargs) -> T:
        """
        Call the constructor to create an instance of ``T``.
        """
        # If there are duplicate keys between self._constructor_extras and kwargs,
        # this will overwrite the ones in self._constructor_extras with what's in kwargs.
        constructor_kwargs = {**self._constructor_extras, **kwargs}
        return self.constructor(**constructor_kwargs)

    def det_hash_object(self) -> Any:
        """
        Return the object to hash in place of ``self``.

        When the class that would be constructed can be resolved and it subclasses
        ``DetHashWithVersion``, its ``VERSION`` is included in the result. Returns ``None``
        if the stored params are of an unsupported type.
        """
        from tango.common.from_params import FromParams

        class_to_construct: Union[Type[T], Callable[..., T]] = self._constructor
        if isinstance(class_to_construct, type) and issubclass(class_to_construct, FromParams):
            # Work on a copy: ``pop_choice`` below consumes the "type" key.
            params = copy.deepcopy(self._params)
            if params is None:
                params = Params({})
            elif isinstance(params, str):
                params = Params({"type": params})
            elif isinstance(params, dict):
                params = Params(params)
            elif not isinstance(params, Params):
                return None

            from tango.common import Registrable

            if issubclass(class_to_construct, Registrable):
                as_registrable = cast(Type[Registrable], class_to_construct)

                if "type" in params and params["type"] not in as_registrable.list_available():
                    as_registrable.search_modules(params["type"])

                # Resolve the subclass and constructor.
                from .from_params import is_base_registrable

                if is_base_registrable(class_to_construct) or "type" in params:
                    default_to_first_choice = as_registrable.default_implementation is not None
                    choice = params.pop_choice(
                        "type",
                        choices=as_registrable.list_available(),
                        default_to_first_choice=default_to_first_choice,
                    )
                    subclass_or_factory_func, _ = as_registrable.resolve_class_name(choice)
                    if inspect.isclass(subclass_or_factory_func):
                        class_to_construct = subclass_or_factory_func
                    else:
                        # We have a function that returns an instance of the class.
                        factory_func = cast(Callable[..., T], subclass_or_factory_func)
                        return_type = inspect.signature(factory_func).return_annotation
                        if return_type != inspect.Signature.empty:
                            class_to_construct = return_type

        if isinstance(class_to_construct, type) and issubclass(
            class_to_construct, DetHashWithVersion
        ):
            return class_to_construct.VERSION, self
        else:
            return self
Logging from worker processes or threads ---------------------------------------- If you have steps or other functions that spawn workers, and you want to enable logging within those workers, you can call the :func:`initialize_worker_logging()` function to configure logging within each worker. This assumes that you've called :func:`initialize_logging()` from the main process (the tango CLI does this for you). For example, .. testcode:: import logging import multiprocessing as mp from tango import Step from tango.common.logging import initialize_worker_logging @Step.register("multiprocessing_step_example") class MultiprocessingStep(Step): def run(self, num_proc: int = 2) -> bool: # type: ignore workers = [] for i in range(num_proc): worker = mp.Process(target=_worker_function, args=(i,)) workers.append(worker) worker.start() for worker in workers: worker.join() return True def _worker_function(worker_id: int): initialize_worker_logging(worker_rank=worker_id) logger = logging.getLogger(MultiprocessingStep.__name__) logger.info("Hello from worker %d!", worker_id) """ import logging import logging.handlers import os import pickle import socketserver import struct import sys import threading from contextlib import contextmanager from datetime import datetime from pathlib import Path from typing import Callable, ClassVar, ContextManager, Generator, List, Optional, Union import rich from rich.console import Console, ConsoleRenderable, Group from rich.highlighter import NullHighlighter from rich.padding import Padding from rich.syntax import Syntax from rich.table import Table from rich.text import Text from .aliases import EnvVarNames, PathOrStr from .exceptions import CancellationError, CliRunError, SigTermReceived from .util import _parse_bool, _parse_optional_int FILE_FRIENDLY_LOGGING: bool = _parse_bool( os.environ.get(EnvVarNames.FILE_FRIENDLY_LOGGING.value, False) ) """ If this flag is set to ``True``, we remove special styling characters from log messages, add 
newlines to :class:`~tango.common.tqdm.Tqdm` output even on an interactive terminal, and we slow down :class:`~tango.common.tqdm.Tqdm`'s output to only once every 10 seconds. .. attention:: Unfortunately this won't affect ``tqdm`` output from other libraries that don't use Tango's :class:`~tango.common.tqdm.Tqdm` wrapper. By default, it is set to ``False``. It can be changed by setting the corresponding environment variable (``FILE_FRIENDLY_LOGGING``) or field in a :class:`~tango.__main__.TangoGlobalSettings` file (``file_friendly_logging``) to "true" or "false", or from the command line with the ``--file-friendly-logging`` flag. For example, .. code-block:: $ tango --file-friendly-logging run ... """ TANGO_LOG_LEVEL: Optional[str] = os.environ.get(EnvVarNames.LOG_LEVEL.value, None) """ The log level to use globally. The value can be set from the corresponding environment variable (``TANGO_LOG_LEVEL``) or field in a :class:`~tango.__main__.TangoGlobalSettings` file (``log_level``), or from the command line with the ``--log-level`` option. Possible values are "debug", "info", "warning", or "error" (not case sensitive). For example, .. code-block:: $ tango --log-level info run ... .. note:: This does not affect the :data:`~tango.common.logging.cli_logger` or logs from :class:`~tango.common.Tqdm` progress bars. """ TANGO_CONSOLE_WIDTH: Optional[int] = _parse_optional_int( os.environ.get(EnvVarNames.CONSOLE_WIDTH.value, None) ) # Click logger disabled by default in case nobody calls initialize_logging(). TANGO_CLI_LOGGER_ENABLED: bool = _parse_bool( os.environ.get(EnvVarNames.CLI_LOGGER_ENABLED.value, False) ) # Keep track of exceptions logged so we don't log duplicates from our custom excepthook. _EXCEPTIONS_LOGGED: List[BaseException] = [] class LevelFilter(logging.Filter): """ Filters out everything that is above `max_level` or higher. This is meant to be used with a stdout handler when a stderr handler is also configured. 
class LogRecordStreamHandler(socketserver.StreamRequestHandler):
    """Handler for a streaming logging request.

    This basically logs the record using whatever logging policy is
    configured locally.

    Taken from
    `the logging cookbook <https://docs.python.org/3/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network>`_.
    """

    def handle(self):
        """
        Handle multiple requests - each expected to be a 4-byte length,
        followed by the LogRecord in pickle format. Logs the record
        according to whatever policy is configured locally.

        Returns as soon as the peer closes the connection, even if that happens
        in the middle of a record; the truncated record is discarded.
        """
        while True:
            header = self._recv_exactly(4)
            if header is None:
                break
            (payload_length,) = struct.unpack(">L", header)
            payload = self._recv_exactly(payload_length)
            if payload is None:
                # The peer went away mid-record; there is nothing sensible to log.
                break
            record = logging.makeLogRecord(self.unPickle(payload))
            self.handleLogRecord(record)

    def _recv_exactly(self, size):
        """
        Read exactly ``size`` bytes from the connection.

        Returns ``None`` if the connection is closed before ``size`` bytes arrive.
        ``recv()`` returns ``b""`` on EOF, so without this check a reader that keeps
        waiting for the remaining bytes would spin forever.
        """
        buffer = bytearray()
        while len(buffer) < size:
            chunk = self.connection.recv(size - len(buffer))
            if not chunk:
                return None
            buffer += chunk
        return bytes(buffer)

    def unPickle(self, data):
        # NOTE(security): unpickling can execute arbitrary code. This is only acceptable
        # because the receiver is meant for log records sent by our own worker processes;
        # do not expose the logging socket to untrusted peers.
        return pickle.loads(data)

    def handleLogRecord(self, record):
        """Dispatch ``record`` to the local logger with the same name."""
        name = record.name
        logger = logging.getLogger(name)
        # N.B. EVERY record gets logged. This is because Logger.handle
        # is normally called AFTER logger-level filtering. If you want
        # to do filtering, do it at the client end to save wasting
        # cycles and network bandwidth!
        logger.handle(record)
def emit(self, record: logging.LogRecord) -> None:
    """
    Print ``record`` to the handler's console.

    ``Syntax`` and ``Table`` messages are printed with padding, other rich renderables
    (anything exposing ``__rich__`` or ``__rich_console__``) are printed as-is, and
    everything else is formatted and rendered with the time/level/path columns.

    Any failure while formatting or printing is routed through ``handleError()``,
    so a broken log call does not raise into the code that issued it.
    """
    try:
        if isinstance(record.msg, (Syntax, Table)):
            self.console.print(Padding(record.msg, (1, 0, 1, 1)))
        elif hasattr(record.msg, "__rich__") or hasattr(record.msg, "__rich_console__"):
            self.console.print(record.msg)
        else:
            message = self.format(record)
            message_renderable = self.render_message(record, message)
            log_renderable = self.render(record=record, message_renderable=message_renderable)
            self.console.print(log_renderable)
    except RecursionError:
        # Same policy as the standard library handlers: never swallow runaway recursion.
        raise
    except Exception:
        self.handleError(record)
def get_path_text(self, record: logging.LogRecord, length_so_far: int) -> Text:
    """
    Build the right-aligned ``path:lineno`` column for ``record``.

    The source path is shortened relative to the first ``sys.path`` entry that
    contains it. ``length_so_far`` is the number of characters already placed on
    the line, used to work out how much padding right-aligns the text.
    """
    source_path = Path(record.pathname)
    for root in sys.path:
        try:
            source_path = source_path.relative_to(Path(root))
        except ValueError:
            # Not located under this root; try the next one.
            continue
        break

    location = f"{source_path}:{record.lineno}"
    used_on_last_row = length_so_far % self.console.width
    target_width = self.console.width - used_on_last_row - 3
    return Text(location.rjust(target_width), style="log.path")
# The CLI logger has its own handlers, so its records must not bubble up to the root logger.
cli_logger.propagate = False
# ``Logger.disabled`` is the inverse of "enabled", so the flag has to be negated here.
# This keeps the CLI logger silent by default when nobody calls initialize_logging(),
# and matches ``_initialize_logging()``, which sets ``disabled = not enable_cli_logs``.
cli_logger.disabled = not TANGO_CLI_LOGGER_ENABLED
def initialize_prefix_logging(
    *, log_level: Optional[str] = None, prefix: Optional[str] = None, main_process: bool = False
):
    """
    Initialize logging with a prefix.

    :param log_level:
        Can be one of "debug", "info", "warning", "error".
        Defaults to the value of :data:`TANGO_LOG_LEVEL`, if set, or "warning".
    :param prefix:
        The string prefix to add to the log message. If a global logging prefix is
        configured through the environment, it is prepended to this one.
    :param main_process:
        ``True`` to configure logging for the main process (console handlers plus the
        socket server that receives records from workers), ``False`` to configure a
        worker that forwards its records to the main process's logging socket.
    """
    return _initialize_logging(log_level=log_level, prefix=prefix, main_process=main_process)
cli_logger.setLevel(logging.DEBUG) cli_logger.disabled = not enable_cli_logs # We also want to enable the tqdm logger so that the progress bar lines end up in the log file. tqdm_logger.setLevel(logging.DEBUG) root_logger = logging.getLogger() root_logger.setLevel(level) root_logger.handlers.clear() if main_process: # Create stdout and stderr handlers so that we can route DEBUG and INFO # messages to stdout, and WARNING and ERROR messages to stderr. stdout_handler = get_handler(level) stdout_handler.addFilter(LevelFilter(logging.INFO)) stderr_handler = get_handler(max(level, logging.WARNING), stderr=True) stderr_handler.addFilter(LevelFilter(logging.CRITICAL, min_level=logging.WARNING)) root_logger.addHandler(stdout_handler) root_logger.addHandler(stderr_handler) # Configure cli_logger so that if log level <= INFO, it will behave # like a regular logger, otherwise it prints directly to stdout. cli_logger.handlers.clear() if enable_cli_logs: for handler_level in (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR): cli_handler = get_handler( handler_level, stderr=handler_level >= logging.WARNING, enable_markup=True, show_time=level <= handler_level, show_level=(level <= handler_level) or handler_level >= logging.WARNING, show_path=level <= handler_level, ) cli_handler.addFilter(LevelFilter(handler_level)) cli_logger.addHandler(cli_handler) # Add prefix. if prefix: for logger in (root_logger, cli_logger, tqdm_logger): for handler in logger.handlers: handler.addFilter(PrefixLogFilter(prefix)) # Main process: set formatter and handlers, initialize logging socket and server. # Set up logging socket to emit log records from worker processes/threads. 
def teardown_logging():
    """
    Cleanup any logging fixtures created from :func:`initialize_logging()`.

    Stops the log-record server thread, closes the server's listening socket, and
    restores the default ``sys.excepthook``.

    Should be called at the end of your script.
    """
    global _LOGGING_HOST, _LOGGING_PORT, _LOGGING_SERVER, _LOGGING_SERVER_THREAD

    server = _LOGGING_SERVER
    if server is not None:
        # Ask ``serve_until_stopped()`` to exit its polling loop.
        server.abort = True

    if _LOGGING_SERVER_THREAD is not None:
        _LOGGING_SERVER_THREAD.join()
        _LOGGING_SERVER_THREAD = None

    if server is not None:
        # Release the listening socket explicitly instead of leaving it to garbage
        # collection. We close the socket directly rather than calling ``server_close()``
        # because a threading server's ``server_close()`` may block waiting for request
        # handler threads whose clients are still connected.
        server.socket.close()
        _LOGGING_SERVER = None

    sys.excepthook = sys.__excepthook__  # type: ignore[assignment]
@contextmanager
def file_handler(filepath: PathOrStr) -> Generator[None, None, None]:
    """
    A context manager that can be used to route logs to a file for the duration of a
    ``with`` block.

    Two handlers writing to the same file are installed through :func:`insert_handlers()`:
    one created with ``markup=True`` and ``CliFilter(filter_out=False)``, and one created
    with ``markup=False`` and ``CliFilter(filter_out=True)``.

    The file is opened for writing (truncating any existing content) when the ``with``
    block is entered, and it is closed when the block exits, whether normally or through
    an exception.

    For example,

    .. code-block::

        from tango.common.logging import initialize_logging, file_handler, teardown_logging

        initialize_logging(log_level="info")

        logger = logging.getLogger()
        logger.info("Hi!")

        with file_handler("log.out"):
            logger.info("This message should also go into 'log.out'")

        teardown_logging()

    :param filepath: Path of the log file to write.
    """
    # Owning the file with a ``with`` statement guarantees the handle is flushed and closed
    # once the handlers have been removed; previously the handle was never closed.
    with open(filepath, "w") as log_file:
        console = Console(
            color_system=None,
            file=log_file,
            force_terminal=False,
            width=TANGO_CONSOLE_WIDTH or 160,
        )
        handlers: List[logging.Handler] = []
        for is_cli_handler in (True, False):
            handler = RichHandler(
                console=console,
                markup=is_cli_handler,
            )
            handler.addFilter(CliFilter(filter_out=not is_cli_handler))
            handlers.append(handler)

        # ``insert_handlers`` removes the handlers again in its own ``finally`` clause, which
        # runs before the enclosing ``with`` closes the file.
        with insert_handlers(*handlers):
            yield None
could_be_class_name logger = logging.getLogger(__name__) def infer_and_cast(value: Any): """ In some cases we'll be feeding params dicts to functions we don't own; for example, PyTorch optimizers. In that case we can't use ``pop_int`` or similar to force casts (which means you can't specify ``int`` parameters using environment variables). This function takes something that looks JSON-like and recursively casts things that look like (bool, int, float) to (bool, int, float). """ if isinstance(value, (int, float, bool)): # Already one of our desired types, so leave as is. return value elif isinstance(value, list): # Recursively call on each list element. return [infer_and_cast(item) for item in value] elif isinstance(value, dict): # Recursively call on each dict value. return {key: infer_and_cast(item) for key, item in value.items()} elif isinstance(value, str): # If it looks like a bool, make it a bool. if value.lower() == "true": return True elif value.lower() == "false": return False else: # See if it could be an int. try: return int(value) except ValueError: pass # See if it could be a float. try: return float(value) except ValueError: # Just return it as a string. return value else: raise ValueError(f"cannot infer type of {value}") def _is_encodable(value: str) -> bool: """ We need to filter out environment variables that can't be unicode-encoded to avoid a "surrogates not allowed" error in jsonnet. """ # Idiomatically you'd like to not check the != b"" # but mypy doesn't like that. return (value == "") or (value.encode("utf-8", "ignore") != b"") def _environment_variables() -> Dict[str, str]: """ Wraps ``os.environ`` to filter out non-encodable values. 
def with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = "") -> T:
    """
    Return a deep copy of ``original`` with the entries of ``overrides_dict`` applied.

    Override keys address nested values with dots, e.g. ``"model.dim"`` or ``"steps.0.name"``
    (list positions are addressed by their index). A key that matches an entry exactly replaces
    that entry wholesale; keys of the form ``"<entry>.<rest>"`` are applied recursively to that
    entry. For dicts, override keys without a dot that are not present in ``original`` are added
    as new entries.

    :param original: The dict or list to apply the overrides to. It is not modified.
    :param overrides_dict: Mapping from (possibly dotted) keys to replacement values.
    :param prefix: The dotted path of ``original`` within the top-level object, including a
        trailing dot. Only used to build error messages.

    :raises ValueError: If ``original`` is neither a list nor a dict, or if some override keys
        could not be applied.
    """
    merged: T
    keys: Union[Iterable[str], Iterable[int]]
    if isinstance(original, dict):
        merged = {}
        # Dot-free override keys that are missing from the original become new entries,
        # visited after all of the original keys.
        brand_new_keys = (k for k in overrides_dict if "." not in k and k not in original)
        keys = chain(original.keys(), brand_new_keys)
    elif isinstance(original, list):
        merged = [None] * len(original)
        keys = range(len(original))
    elif prefix:
        raise ValueError(
            f"overrides for '{prefix[:-1]}.*' expected list or dict in original, "
            f"found {type(original)} instead"
        )
    else:
        raise ValueError(f"expected list or dict, found {type(original)} instead")

    used_override_keys: Set[str] = set()
    for key in keys:
        exact_key = str(key)
        if exact_key in overrides_dict:
            # An exact match replaces the whole entry.
            merged[key] = copy.deepcopy(overrides_dict[exact_key])
            used_override_keys.add(exact_key)
            continue

        # Otherwise gather the overrides that point somewhere inside this entry,
        # stripping the leading "<key>." from their names.
        nested_prefix = f"{key}."
        overrides_subdict = {
            o_key[len(nested_prefix) :]: o_value
            for o_key, o_value in overrides_dict.items()
            if o_key.startswith(nested_prefix)
        }
        if overrides_subdict:
            used_override_keys.update(nested_prefix + sub_key for sub_key in overrides_subdict)
            merged[key] = with_overrides(
                original[key], overrides_subdict, prefix=prefix + nested_prefix
            )
        else:
            merged[key] = copy.deepcopy(original[key])

    unused_override_keys = [prefix + key for key in set(overrides_dict.keys()) - used_override_keys]
    if unused_override_keys:
        raise ValueError(f"overrides dict contains unused keys: {unused_override_keys}")

    return merged
""" if isinstance(obj, dict): return False elif isinstance(obj, list): return all(_is_dict_free(item) for item in obj) else: return True def pop_choice( params: Dict[str, Any], key: str, choices: List[Any], default_to_first_choice: bool = False, history: str = "?.", allow_class_names: bool = True, ) -> Any: """ Performs the same function as ``Params.pop_choice``, but is required in order to deal with places that the Params object is not welcome, such as inside Keras layers. See the docstring of that method for more detail on how this function works. This method adds a ``history`` parameter, in the off-chance that you know it, so that we can reproduce ``Params.pop_choice`` exactly. We default to using "?." if you don't know the history, so you'll have to fix that in the log if you want to actually recover the logged parameters. """ value = Params(params, history).pop_choice( key, choices, default_to_first_choice, allow_class_names=allow_class_names ) return value def _replace_none(params: Any) -> Any: if isinstance(params, str) and params == "None": return None elif isinstance(params, (dict, Params)): if isinstance(params, Params): params = params.as_dict(quiet=True) for key, value in params.items(): params[key] = _replace_none(value) return params elif isinstance(params, list): return [_replace_none(value) for value in params] return params def remove_keys_from_params(params: "Params", keys: List[str] = ["pretrained_file", "initializer"]): if isinstance(params, Params): # The model could possibly be a string, for example. 
class Params(MutableMapping):
    """
    A :class:`~collections.abc.MutableMapping` that represents a parameter dictionary with a history,
    and contains other functionality around parameter passing and validation for AI2 Tango.

    There are currently two benefits of a ``Params`` object over a plain dictionary for parameter
    passing:

    1. We handle a few kinds of parameter validation, including making sure that parameters
       representing discrete choices actually have acceptable values, and making sure no extra
       parameters are passed.
    2. We log all parameter reads, including default values.  This gives a more complete
       specification of the actual parameters used than is given in a JSON file, because
       those may not specify what default values were used, whereas this will log them.

    .. important::
        The convention for using a ``Params`` object in Tango is that you will consume the parameters
        as you read them, so that there are none left when you've read everything you expect.  This
        lets us easily validate that you didn't pass in any ``extra`` parameters, just by making sure
        that the parameter dictionary is empty.  You should do this when you're done handling
        parameters, by calling :meth:`Params.assert_empty()`.
    """

    # This allows us to check for the presence of "None" as a default argument,
    # which we require because we make a distinction between passing a value of "None"
    # and passing no value to the default parameter of "pop".
    DEFAULT = object()

    def __init__(self, params: "MutableMapping[str, Any]", history: str = "") -> None:
        """
        :param params: The underlying mapping. If this is itself a ``Params`` object, its
            underlying mapping is shared (not copied). Otherwise string values equal to
            ``"None"`` are replaced with ``None``.
        :param history: The dotted path (with trailing dot) of this object within the top-level
            configuration. Used as a prefix in log and error messages.
        """
        if isinstance(params, Params):
            self.params: MutableMapping = params.params
        else:
            self.params = _replace_none(params)
        self.history = history

    def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:
        """
        Performs the functionality associated with ``dict.pop(key)``, along with checking for
        returned dictionaries, replacing them with Param objects with an updated history
        (unless keep_as_dict is True, in which case we leave them as dictionaries).

        If ``key`` is not present in the dictionary, and no default was specified, we raise a
        :class:`~tango.common.exceptions.ConfigurationError`, instead of the typical ``KeyError``.
        """
        if default is self.DEFAULT:
            try:
                value = self.params.pop(key)
            except KeyError:
                msg = f'key "{key}" is required'
                if self.history:
                    msg += f' at location "{self.history}"'
                raise ConfigurationError(msg)
        else:
            value = self.params.pop(key, default)

        logger.debug(f"{self.history}{key} = {value}")

        if keep_as_dict or _is_dict_free(value):
            return value
        else:
            return self._check_is_dict(key, value)

    def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:
        """
        Performs a pop and coerces to an int. A popped value of ``None`` is returned as ``None``.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return int(value)

    def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:
        """
        Performs a pop and coerces to a float. A popped value of ``None`` is returned as ``None``.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        else:
            return float(value)

    def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:
        """
        Performs a pop and coerces to a bool.

        ``None`` and actual booleans are returned as they are, and the strings ``"true"`` and
        ``"false"`` are converted to the corresponding boolean.

        :raises ValueError: If the popped value is anything else.
        """
        value = self.pop(key, default)
        if value is None:
            return None
        elif isinstance(value, bool):
            return value
        elif value == "true":
            return True
        elif value == "false":
            return False
        else:
            # Format instead of concatenating, so that non-string values (e.g. ``1``) produce
            # the documented ``ValueError`` rather than a ``TypeError`` from ``str + int``.
            raise ValueError(f"Cannot convert variable to bool: {value}")

    def get(self, key: str, default: Any = DEFAULT):
        """
        Performs the functionality associated with ``dict.get(key)`` but also checks for returned
        dicts and returns a ``Params`` object in their place with an updated history.

        If no ``default`` is given, ``None`` is returned for missing keys.
        """
        default = None if default is self.DEFAULT else default
        value = self.params.get(key, default)
        return self._check_is_dict(key, value)

    def pop_choice(
        self,
        key: str,
        choices: List[Any],
        default_to_first_choice: bool = False,
        allow_class_names: bool = True,
    ) -> Any:
        """
        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of
        the given choices. Note that this ``pops`` the key from params, modifying the dictionary,
        consistent with how parameters are processed in this codebase.

        :param key:
            Key to get the value from in the param dictionary

        :param choices:
            A list of valid options for values corresponding to ``key``.  For example, if you're
            specifying the type of encoder to use for some part of your model, the choices might be
            the list of encoder classes we know about and can instantiate.  If the value we find in
            the param dictionary is not in ``choices``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because
            the user specified an invalid value in their parameter file.

        :param default_to_first_choice:
            If this is ``True``, we allow the ``key`` to not be present in the parameter
            dictionary.  If the key is not present, we will return the first choice in the
            ``choices`` list.  If this is ``False``, we raise a
            :class:`~tango.common.exceptions.ConfigurationError`, because
            specifying the ``key`` is required (e.g., you ``have`` to specify your model class
            when running an experiment, but you can feel free to use default settings for
            encoders if you want).

        :param allow_class_names:
            If this is ``True``, then we allow unknown choices that look like fully-qualified class
            names.  This is to allow e.g. specifying a model type as
            ``my_library.my_model.MyModel`` and importing it on the fly. Our check for "looks like"
            is extremely lenient and consists of checking that the value contains a '.'.
        """
        default = choices[0] if default_to_first_choice else self.DEFAULT
        value = self.pop(key, default)
        ok_because_class_name = allow_class_names and could_be_class_name(value)
        if value not in choices and not ok_because_class_name:
            key_str = self.history + key
            message = (
                f"'{value}' not in acceptable choices for {key_str}: {choices}. "
                "You should either use the --include-package flag to make sure the correct module "
                "is loaded, or use a fully qualified class name in your config file like "
                """{"model": "my_module.models.MyModel"} to have it imported automatically."""
            )
            raise ConfigurationError(message)
        return value

    def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):
        """
        Sometimes we need to just represent the parameters as a dict, for instance when we pass
        them to PyTorch code.

        :param quiet:
            If ``True``, the parameters are returned without being logged. If ``False``
            (the default), every parameter is logged at debug level before returning.

        :param infer_type_and_cast:
            If ``True``, we infer types and cast (e.g. things that look like floats to floats).
        """
        if infer_type_and_cast:
            params_as_dict = infer_and_cast(self.params)
        else:
            params_as_dict = self.params

        if quiet:
            return params_as_dict

        def log_recursively(parameters, history):
            for key, value in parameters.items():
                if isinstance(value, dict):
                    new_local_history = history + key + "."
                    log_recursively(value, new_local_history)
                else:
                    logger.debug(f"{history}{key} = {value}")

        log_recursively(self.params, self.history)
        return params_as_dict

    def as_flat_dict(self) -> Dict[str, Any]:
        """
        Returns the parameters of a flat dictionary from keys to values.
        Nested structure is collapsed with periods.
        """
        flat_params = {}

        def recurse(parameters, path):
            for key, value in parameters.items():
                newpath = path + [key]
                if isinstance(value, dict):
                    recurse(value, newpath)
                else:
                    flat_params[".".join(newpath)] = value

        recurse(self.params, [])
        return flat_params

    def duplicate(self) -> "Params":
        """
        Uses ``copy.deepcopy()`` to create a duplicate (but fully distinct)
        copy of these Params.
        """
        return copy.deepcopy(self)

    def assert_empty(self, name: str):
        """
        Raises a :class:`~tango.common.exceptions.ConfigurationError` if ``self.params`` is not
        empty.  We take ``name`` as an argument so that the error message gives some idea of where
        an error happened, if there was one. For example, ``name`` could be the name of the
        ``calling`` class that got extra parameters (if there are any).
        """
        if self.params:
            raise ConfigurationError("Extra parameters passed to {}: {}".format(name, self.params))

    def __getitem__(self, key):
        if key in self.params:
            return self._check_is_dict(key, self.params[key])
        else:
            raise KeyError(str(key))

    def __setitem__(self, key, value):
        self.params[key] = value

    def __delitem__(self, key):
        del self.params[key]

    def __iter__(self):
        return iter(self.params)

    def __len__(self):
        return len(self.params)

    def _check_is_dict(self, new_history, value):
        """
        Wrap ``value`` in a ``Params`` object if it is a dict, and do the same recursively for
        the elements of a list. ``new_history`` is the key under which ``value`` was found; it
        is appended to this object's history to form the history of the wrapped object.
        """
        if isinstance(value, dict):
            new_history = self.history + new_history + "."
            return Params(value, history=new_history)
        if isinstance(value, list):
            value = [self._check_is_dict(f"{new_history}.{i}", v) for i, v in enumerate(value)]
        return value

    @classmethod
    def from_file(
        cls,
        params_file: "PathOrStr",
        params_overrides: Union[str, Dict[str, Any]] = "",
        ext_vars: Optional[dict] = None,
    ) -> "Params":
        """
        Load a ``Params`` object from a configuration file.

        :param params_file:
            The path to the configuration file to load. Can be JSON, Jsonnet, or YAML.

        :param params_overrides:
            A dict of overrides that can be applied to final object.
            e.g. ``{"model.embedding_dim": 10}`` will change the value of "embedding_dim"
            within the "model" object of the config to 10. If you wanted to override the entire
            "model" object of the config, you could do ``{"model": {"type": "other_type", ...}}``.
            A string is evaluated as a Jsonnet snippet that produces such a dict.

        :param ext_vars:
            Our config files are Jsonnet, which allows specifying external variables
            for later substitution. Typically we substitute these using environment
            variables; however, you can also specify them here, in which case they
            take priority over environment variables.
            e.g. ``{"HOME_DIR": "/Users/allennlp/home"}``

        :raises FileNotFoundError: If the resolved path is not a file.
        """
        if ext_vars is None:
            ext_vars = {}

        # redirect to cache, if necessary
        from cached_path import cached_path

        params_file: Path = Path(cached_path(params_file))
        if not params_file.is_file():
            raise FileNotFoundError(params_file)

        file_dict: Dict[str, Any]
        if params_file.suffix in {".yml", ".yaml"}:
            with open(params_file) as f:
                file_dict = yaml.safe_load(f)
        else:
            # Fall back to JSON/Jsonnet.
            ext_vars = {**_environment_variables(), **ext_vars}
            json_str = evaluate_file(params_file.name, str(params_file.parent), ext_vars=ext_vars)
            file_dict = json.loads(json_str)

        if isinstance(params_overrides, dict):
            params_overrides = json.dumps(params_overrides)
        overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)

        if overrides_dict:
            param_dict = with_overrides(file_dict, overrides_dict)
        else:
            param_dict = file_dict

        return cls(param_dict)

    def to_file(
        self, params_file: "PathOrStr", preference_orders: Optional[List[List[str]]] = None
    ) -> None:
        """
        Write the params to file as indented JSON, with keys ordered according to
        :meth:`as_ordered_dict()`.
        """
        with open(params_file, "w") as handle:
            json.dump(self.as_ordered_dict(preference_orders), handle, indent=4)

    def as_ordered_dict(self, preference_orders: Optional[List[List[str]]] = None) -> OrderedDict:
        """
        Returns an ``OrderedDict`` of ``Params`` from list of partial order preferences.

        :param preference_orders:
            ``preference_orders`` is list of partial preference orders. ["A", "B", "C"] means
            "A" > "B" > "C". For multiple preference_orders first will be considered first.
            Keys not found, will have last but alphabetical preference. A final preference of
            ``["type"]`` is always added after the given ones. The list that is passed in is
            not modified.
        """
        params_dict = self.as_dict(quiet=True)
        # Work on a copy: appending the implicit ["type"] preference to the caller's list would
        # make it grow by one entry on every call.
        preference_orders = list(preference_orders) if preference_orders else []
        preference_orders.append(["type"])

        def order_func(key):
            # Makes a tuple to use for ordering.  The tuple is an index into each of the `preference_orders`,
            # followed by the key itself.  This gives us integer sorting if you have a key in one of the
            # `preference_orders`, followed by alphabetical ordering if not.
            order_tuple = [
                order.index(key) if key in order else len(order) for order in preference_orders  # type: ignore
            ]
            return order_tuple + [key]

        def order_dict(dictionary, order_func):
            # Recursively orders dictionary according to scoring order_func
            result = OrderedDict()
            for key, val in sorted(dictionary.items(), key=lambda item: order_func(item[0])):
                result[key] = order_dict(val, order_func) if isinstance(val, dict) else val
            return result

        return order_dict(params_dict, order_func)

    def get_hash(self) -> str:
        """
        Returns a hash code representing the current state of this ``Params`` object.  We don't
        want to implement ``__hash__`` because that has deeper python implications (and this is a
        mutable object), but this will give you a representation of the current state.
        We use ``zlib.adler32`` instead of Python's builtin ``hash`` because the random seed for the
        latter is reset on each new program invocation, as discussed here:
        https://stackoverflow.com/questions/27954892/deterministic-hashing-in-python-3.
        """
        dumped = json.dumps(self.params, sort_keys=True)
        hashed = zlib.adler32(dumped.encode())
        return str(hashed)

    def __str__(self) -> str:
        return f"{self.history}Params({self.params})"
class Registrable(FromParams):
    """
    Any class that inherits from ``Registrable`` gains access to a named registry for its
    subclasses. To register them, just decorate them with the classmethod
    ``@BaseClass.register(name)``.

    After which you can call ``BaseClass.list_available()`` to get the keys for the
    registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass.
    Note that the registry stores the subclasses themselves; not class instances.
    In most cases you would then call :meth:`~tango.common.from_params.FromParams.from_params()`
    on the returned subclass.

    You can specify a default by setting ``BaseClass.default_implementation``.
    If it is set, it will be the first element of :meth:`list_available()`.

    Note that if you use this class to implement a new ``Registrable`` abstract class,
    you must ensure that all subclasses of the abstract class are loaded when the module is
    loaded, because the subclasses register themselves in their respective files. You can
    achieve this by having the abstract class and all subclasses in the ``__init__.py`` of the
    module in which they reside (as this causes any import of either the abstract class or
    a subclass to load all other subclasses and the abstract class).
    """

    # Maps each base class to its own registry, which in turn maps a registered name to a
    # ``(subclass, constructor method name or None)`` pair. The mapping is shared by all
    # ``Registrable`` classes and is always accessed as ``Registrable._registry[cls]``.
    _registry: ClassVar[DefaultDict[type, _SubclassRegistry]] = defaultdict(dict)

    # Registered name of the default implementation, if any. See ``list_available()``.
    default_implementation: Optional[str] = None

    @classmethod
    def register(
        cls, name: str, constructor: Optional[str] = None, exist_ok: bool = False
    ) -> Callable[[Type[_T]], Type[_T]]:
        """
        Register a class under a particular name.

        :param name:
            The name to register the class under.
        :param constructor:
            The name of the method to use on the class to construct the object.  If this is given,
            we will use this method (which must be a ``@classmethod``) instead of the default
            constructor.
        :param exist_ok:
            If True, overwrites any existing models registered under ``name``. Else,
            throws an error if a model is already registered under ``name``.

        :raises ConfigurationError:
            If ``name`` is ``"ref"`` and ``cls`` is the ``Step`` class, or (when the decorator
            is applied) if ``name`` is already taken and ``exist_ok`` is ``False``. Name clashes
            where one of the two classes lives in ``__main__`` never raise; see the comments
            in the implementation.

        Examples
        --------

        To use this class, you would typically have a base class that inherits from ``Registrable``::

            class Vocabulary(Registrable):
                ...

        Then, if you want to register a subclass, you decorate it like this::

            @Vocabulary.register("my-vocabulary")
            class MyVocabulary(Vocabulary):
                def __init__(self, param1: int, param2: str):
                    ...

        Registering a class like this will let you instantiate a class from a config file, where you
        give ``"type": "my-vocabulary"``, and keys corresponding to the parameters of the ``__init__``
        method (note that for this to work, those parameters must have type annotations).

        If you want to have the instantiation from a config file call a method other than the
        constructor, either because you have several different construction paths that could be
        taken for the same object (as we do in ``Vocabulary``) or because you have logic you want to
        happen before you get to the constructor (as we do in ``Embedding``), you can register a
        specific ``@classmethod`` as the constructor to use, like this::

            @Vocabulary.register("my-vocabulary-from-instances", constructor="from_instances")
            @Vocabulary.register("my-vocabulary-from-files", constructor="from_files")
            class MyVocabulary(Vocabulary):
                def __init__(self, some_params):
                    ...

                @classmethod
                def from_instances(cls, some_other_params) -> MyVocabulary:
                    ...  # construct some_params from instances
                    return cls(some_params)

                @classmethod
                def from_files(cls, still_other_params) -> MyVocabulary:
                    ...  # construct some_params from files
                    return cls(some_params)
        """
        if _cls_is_step(cls) and name == "ref":
            raise ConfigurationError(
                "You cannot use the name 'ref' to register a step. This name is reserved."
            )

        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[_T]) -> Type[_T]:
            # Add to registry, raise an error if key has already been used.
            if name in registry:
                already_in_use_for = registry[name][0]
                if already_in_use_for.__module__ == "__main__":
                    # Sometimes the same class shows up under module.submodule.Class and __main__.Class, and we
                    # don't want to make a fuss in that case. We prefer the class without __main__, so we go
                    # ahead and overwrite the entry.
                    pass
                elif subclass.__module__ == "__main__":
                    # We don't want to overwrite the entry because the new one comes from the __main__ module.
                    # Note that the decorator then evaluates to the previously registered class.
                    return already_in_use_for
                elif exist_ok:
                    message = (
                        f"Registering {_fullname(subclass)} as a {_fullname(cls)} under the name {name} "
                        f"overwrites existing entry {_fullname(already_in_use_for)}, which is fine because "
                        "you said exist_ok=True."
                    )
                    logger.info(message)
                else:
                    message = (
                        f"Attempting to register {_fullname(subclass)} as a {_fullname(cls)} under the name "
                        f"'{name}' failed. {_fullname(already_in_use_for)} is already registered under that name."
                    )
                    raise ConfigurationError(message)
            registry[name] = (subclass, constructor)
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[_RegistrableT], name: str) -> Callable[..., _RegistrableT]:
        """
        Returns a callable that constructs an instance of the class registered under ``name``.
        Because you can register particular functions as constructors for specific names, this
        isn't necessarily the ``__init__`` method of some class: if a constructor method was
        registered, that bound method is returned, otherwise the class itself.

        The name is resolved with :meth:`resolve_class_name()`, so it may also be a
        fully-qualified class name.
        """
        logger.debug(f"instantiating registered subclass {name} of {cls}")
        subclass, constructor = cls.resolve_class_name(name)
        if not constructor:
            return cast(Type[_RegistrableT], subclass)
        else:
            return cast(Callable[..., _RegistrableT], getattr(subclass, constructor))

    @classmethod
    def search_modules(cls: Type[_RegistrableT], name: str):
        """
        Search for and import modules where ``name`` might be registered.

        Nothing is searched if ``name`` looks like a class name, is already registered for
        ``cls``, or is the reserved step name ``"ref"``. Otherwise candidates are imported in the
        following order, stopping as soon as ``name`` shows up in the registry of ``cls``:

        1. the integration called ``name``, if there is one,
        2. if ``name`` contains ``"::"``, the integration named by the part before the first ``"::"``,
        3. Python files (except ``setup.py``) in the current working directory,
        4. packages found through ``glob("**/__init__.py")`` relative to the current working
           directory, skipping ``tango`` and anything starting with ``test``,
        5. the submodules returned by ``find_submodules(exclude={"tango.integrations*"}, recursive=False)``,
        6. all remaining integrations.

        This method always returns ``None``; callers have to check the registry afterwards.
        """
        if (
            could_be_class_name(name)
            or name in Registrable._registry[cls]
            or (_cls_is_step(cls) and name == "ref")
        ):
            return None

        def try_import(module, recursive: bool = True):
            # Import ``module``, tolerating a missing integration dependency and the case where
            # ``module`` itself cannot be found. An ``ImportError`` naming any other module is
            # re-raised.
            try:
                import_module_and_submodules(module, recursive=recursive)
            except IntegrationMissingError:
                pass
            except ImportError as e:
                if e.name != module:
                    raise

        # Maps the last component of each integration's module path to the full module path.
        integrations = {m.split(".")[-1]: m for m in find_integrations()}
        # Keeps track of what was already tried, so the final loop doesn't import it again.
        integrations_imported: Set[str] = set()
        if name in integrations:
            try_import(integrations[name], recursive=False)
            integrations_imported.add(name)
            if name in Registrable._registry[cls]:
                return None

        if "::" in name:
            # Try to guess the integration that it comes from.
            maybe_integration = name.split("::")[0]
            if maybe_integration in integrations:
                try_import(integrations[maybe_integration], recursive=False)
                integrations_imported.add(maybe_integration)
                if name in Registrable._registry[cls]:
                    return None

        # Check Python files and modules in the current directory.
        from glob import glob
        from pathlib import Path

        for pyfile in glob("*.py"):
            module = str(Path(pyfile).with_suffix(""))
            if module == "setup":
                continue
            try:
                try_import(module)
                if name in Registrable._registry[cls]:
                    return None
            except:  # noqa: E722
                # Best effort: these are arbitrary files that happen to be in the working
                # directory, so anything raised while importing one of them just skips that file.
                continue
        for pyinit in glob("**/__init__.py"):
            module = str(Path(pyinit).parent)
            if module == "tango" or module.startswith("test"):
                continue
            try:
                try_import(module)
                if name in Registrable._registry[cls]:
                    return None
            except:  # noqa: E722
                # Best effort, same as for the plain Python files above.
                continue

        # Search all other modules in Tango.
        for module in find_submodules(exclude={"tango.integrations*"}, recursive=False):
            try_import(module)
            if name in Registrable._registry[cls]:
                return None

        # Try importing all other integrations.
        for integration_name, module in integrations.items():
            if integration_name not in integrations_imported:
                try_import(module, recursive=False)
                integrations_imported.add(integration_name)
                if name in Registrable._registry[cls]:
                    return None

    @classmethod
    def resolve_class_name(
        cls: Type[_RegistrableT],
        name: str,
        search_modules: bool = True,
    ) -> Tuple[Type[_RegistrableT], Optional[str]]:
        """
        Returns the subclass that corresponds to the given ``name``, along with the name of the
        method that was registered as a constructor for that ``name``, if any.

        This method also allows ``name`` to be a fully-specified module name, instead of a name that
        was already added to the registry.  In that case, you cannot use a separate function as
        a constructor (as you need to call ``cls.register()`` in order to tell us what separate
        function to use).

        If the ``name`` given is not in the registry and ``search_modules`` is ``True``,
        it will search for and import modules where the class might be defined according to
        :meth:`search_modules()`.

        :raises ConfigurationError:
            If ``name`` looks like a fully-qualified class name but its module cannot be
            imported, or the module has no attribute with that class name.
        :raises RegistryKeyError:
            If ``name`` is not a qualified class name and is not registered, even after
            searching modules (when enabled).
        """
        if name in Registrable._registry[cls]:
            subclass, constructor = Registrable._registry[cls][name]
            return subclass, constructor
        elif could_be_class_name(name):
            # This might be a fully qualified class name, so we'll try importing its "module"
            # and finding it there.
            parts = name.split(".")
            submodule = ".".join(parts[:-1])
            class_name = parts[-1]

            try:
                module = importlib.import_module(submodule)
            except ModuleNotFoundError:
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to import module {submodule}"
                )

            try:
                subclass = getattr(module, class_name)
                constructor = None
                return subclass, constructor
            except AttributeError:
                raise ConfigurationError(
                    f"tried to interpret {name} as a path to a class "
                    f"but unable to find class {class_name} in {submodule}"
                )
        else:
            # is not a qualified class name
            if search_modules:
                # Import candidate modules, then retry exactly once with searching disabled so
                # that a name which is still unknown falls through to the error below.
                cls.search_modules(name)
                return cls.resolve_class_name(name, search_modules=False)

            available = cls.list_available()
            suggestion = _get_suggestion(name, available)
            raise RegistryKeyError(
                (
                    f"'{name}' is not a registered name for '{cls.__name__}'"
                    + (". " if not suggestion else f", did you mean '{suggestion}'? ")
                )
                + "If your registered class comes from custom code, you'll need to import "
                "the corresponding modules. If you're using Tango or AllenNLP from the command-line, "
                "this is done by using the '--include-package' flag, or by specifying your imports "
                "in a 'tango.yml' settings file. "
                "Alternatively, you can specify your choices "
                """using fully-qualified paths, e.g. {"model": "my_module.models.MyModel"} """
                "in which case they will be automatically imported correctly."
            )

    @classmethod
    def list_available(cls) -> List[str]:
        """
        List the names registered for this class, with the default implementation first if
        ``default_implementation`` is set.

        If a default is set but not registered yet, :meth:`search_modules()` is used to try to
        import it first.

        :raises ConfigurationError: If the default implementation is still not registered after
            searching for it.
        """
        keys = list(Registrable._registry[cls].keys())
        default = cls.default_implementation

        if default is None:
            return keys

        if default not in keys:
            cls.search_modules(default)

        # Re-read the registry, since the search above may have registered new names.
        keys = list(Registrable._registry[cls].keys())
        if default not in keys:
            raise ConfigurationError(f"Default implementation '{default}' is not registered")
        else:
            return [default] + [k for k in keys if k != default]
This is to allow referring to functions by their name in tango configurations. """ WRAPPED_FUNC: ClassVar[Callable] def __call__(self, *args, **kwargs): return self.__class__.WRAPPED_FUNC(*args, **kwargs) def make_registrable(name: Optional[str] = None, *, exist_ok: bool = False): """ A decorator to create a :class:`RegistrableFunction` from a function. :param name: A name to register the function under. By default the name of the function is used. :param exist_ok: If True, overwrites any existing function registered under the same ``name``. Else, throws an error if a function is already registered under ``name``. """ def function_wrapper(func): @RegistrableFunction.register(name or func.__name__, exist_ok=exist_ok) class WrapperFunc(RegistrableFunction): WRAPPED_FUNC = func return WrapperFunc() return function_wrapper def _get_suggestion(name: str, available: List[str]) -> Optional[str]: # Check for simple mistakes like using '-' instead of '_', or vice-versa. for ch, repl_ch in (("_", "-"), ("-", "_")): suggestion = name.replace(ch, repl_ch) if suggestion in available: return suggestion return None def _fullname(c: type) -> str: return f"{c.__module__}.{c.__qualname__}" def _cls_is_step(c: type) -> bool: # NOTE (epwalsh): importing the actual Step class here would result in a circular # import, even though the import wouldn't be at the top of the module (believe me, I've tried). # So instead we just check the fully qualified name of the class. return _fullname(c) == "tango.step.Step" ================================================ FILE: tango/common/remote_utils.py ================================================ import logging from typing import Union from tango.step import Step from tango.step_info import StepInfo logger = logging.getLogger(__name__) class RemoteConstants: """ Common constants to be used as prefixes and filenames in remote workspaces. 
""" RUN_ARTIFACT_PREFIX: str = "tango-run-" RUN_DATA_FNAME: str = "run.json" STEP_ARTIFACT_PREFIX: str = "tango-step-" STEP_INFO_FNAME: str = "step_info.json" STEP_RESULT_DIR: str = "result" STEP_GRAPH_ARTIFACT_PREFIX: str = "tango-step-graph-" STEP_EXPERIMENT_PREFIX: str = "tango-step-" STEP_GRAPH_FILENAME: str = "config.json" GITHUB_TOKEN_SECRET_NAME: str = "TANGO_GITHUB_TOKEN" RESULTS_DIR: str = "/tango/output" INPUT_DIR: str = "/tango/input" LOCK_ARTIFACT_SUFFIX: str = "-lock" @classmethod def step_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str: return f"{cls.STEP_ARTIFACT_PREFIX}{step if isinstance(step, str) else step.unique_id}" @classmethod def step_lock_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str: return f"{cls.step_artifact_name(step)}{cls.LOCK_ARTIFACT_SUFFIX}" @classmethod def run_artifact_name(cls, name: str) -> str: return f"{cls.RUN_ARTIFACT_PREFIX}{name}" ================================================ FILE: tango/common/sequences.py ================================================ import bisect import os import random import shutil from collections import abc from os import PathLike from typing import Any, Callable, Iterable, MutableSequence, Optional, Sequence, Union class ShuffledSequence(abc.Sequence): """ Produces a shuffled view of a sequence, such as a list. This assumes that the inner sequence never changes. If it does, the results are undefined. :param inner_sequence: the inner sequence that's being shuffled :param indices: Optionally, you can specify a list of indices here. If you don't, we'll just shuffle the inner sequence randomly. If you do specify indices, element ``n`` of the output sequence will be ``inner_sequence[indices[n]]``. This gives you great flexibility. You can repeat elements, leave them out completely, or slice the list. A Python :class:`slice` object is an acceptable input for this parameter, and so are other sequences from this module. Example: .. 
class ConcatenatedSequence(abc.Sequence):
    """
    Produces a sequence that's the lazy concatenation of multiple other sequences. It does
    not copy any of the elements of the original sequences.

    This assumes that the inner sequences never change. If they do, the results are undefined.

    :param sequences: the inner sequences to concatenate

    Example:

    .. testcode::

        from tango.common.sequences import ConcatenatedSequence
        l1 = [1, 2, 3]
        l2 = [4, 5]
        l3 = [6]
        cat_l = ConcatenatedSequence(l1, l2, l3)

        assert len(cat_l) == 6
        for i in cat_l:
            print(i)

    This will print the following:

    .. testoutput::

        1
        2
        3
        4
        5
        6
    """

    def __init__(self, *sequences: Sequence):
        self.sequences = sequences
        # cumulative_sequence_lengths[k] is the global index of the first element of
        # sequences[k]; the final entry is the total length.
        self.cumulative_sequence_lengths = [0]
        for sequence in sequences:
            self.cumulative_sequence_lengths.append(
                self.cumulative_sequence_lengths[-1] + len(sequence)
            )

    def __len__(self):
        return self.cumulative_sequence_lengths[-1]

    def __getitem__(self, i: Union[int, slice]):
        """
        Return element ``i`` (negative indices count from the end), or a lazy
        ``SlicedSequence`` view when ``i`` is a slice.

        :raises IndexError: if an integer index is out of range.
        """
        if isinstance(i, int):
            if i < 0:
                i += len(self)
            if i < 0 or i >= len(self):
                raise IndexError("list index out of range")
            # bisect_right skips past empty inner sequences that share the same offset.
            sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, i) - 1
            i -= self.cumulative_sequence_lengths[sequence_index]
            return self.sequences[sequence_index][i]
        else:
            return SlicedSequence(self, i)

    def __contains__(self, item) -> bool:
        # Use the `in` operator rather than calling `__contains__` directly, so that inner
        # sequences that only implement `__getitem__`/`__len__` (or `__iter__`) also work.
        return any(item in s for s in self.sequences)
class SqliteSparseSequence(MutableSequence[Any]):
    """
    This is a sparse sequence that pickles elements to a Sqlite database.

    When you read from the sequence, elements are retrieved and unpickled lazily. That means
    creating/opening a sequence is very fast and does not depend on the length of the sequence.

    This is a "sparse sequence" because you can set element ``n`` before you set element ``n-1``:

    .. testcode::
        :hide:

        from tango.common.sequences import SqliteSparseSequence
        import tempfile
        dir = tempfile.TemporaryDirectory()
        from pathlib import Path
        filename = Path(dir.name) / "test.sqlite"

    .. testcode::

        s = SqliteSparseSequence(filename)
        element = "Big number, small database."
        s[2**32] = element
        assert len(s) == 2**32 + 1
        assert s[2**32] == element
        assert s[1000] is None
        s.close()

    .. testcode::
        :hide:

        dir.cleanup()

    You can use a ``SqliteSparseSequence`` from multiple processes at the same time. This is
    useful, for example, if you're filling out a sequence and you are partitioning ranges to
    processes.

    :param filename: the filename at which to store the data
    :param read_only: Set this to ``True`` if you only want to read.
    """

    def __init__(self, filename: Union[str, PathLike], read_only: bool = False):
        from sqlitedict import SqliteDict

        self.table = SqliteDict(filename, "sparse_sequence", flag="r" if read_only else "c")

    def __del__(self):
        # `table` may be missing entirely if `__init__` raised before assigning it.
        table = getattr(self, "table", None)
        if table is not None:
            table.close(force=True)
            self.table = None

    def __getitem__(self, i: Union[int, slice]) -> Any:
        """
        Return element ``i``, ``None`` for an in-range element that was never set, or a lazy
        ``SlicedSequence`` view for a slice.

        :raises IndexError: if an integer index is out of range (in either direction).
        :raises TypeError: if ``i`` is neither an integer nor a slice.
        """
        if isinstance(i, int):
            if i >= 0:
                # Fast path: a stored element needs a single lookup.
                try:
                    return self.table[str(i)]
                except KeyError:
                    pass  # Either an unset (sparse) slot or out of range; decided below.
            current_length = len(self)
            if i < 0:
                i += current_length
                if i < 0:
                    raise IndexError("list index out of range")
                return self.table.get(str(i))
            if i >= current_length:
                raise IndexError("list index out of range")
            return None
        elif isinstance(i, slice):
            return SlicedSequence(self, i)
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def __setitem__(self, i: Union[int, slice], value: Any):
        """
        Store ``value`` at index ``i``. Non-negative indices past the end grow the sequence.

        :raises IndexError: if a negative index points before the start of the sequence.
        :raises TypeError: if ``i`` is not an integer.
        """
        if isinstance(i, int):
            current_length = len(self)
            if i < 0:
                i += current_length
                if i < 0:
                    raise IndexError("list assignment index out of range")
            self.table[str(i)] = value
            self.table["_len"] = max(i + 1, current_length)
            self.table.commit()
        else:
            raise TypeError(f"list indices must be integers, not {i.__class__.__name__}")

    def __delitem__(self, i: Union[int, slice]):
        """
        Delete the element(s) at ``i`` and shift all later elements down.

        :raises IndexError: if an integer index is out of range.
        :raises TypeError: if ``i`` is neither an integer nor a slice.
        """
        current_length = len(self)
        if isinstance(i, int):
            if i < 0:
                i += current_length
            if i < 0 or i >= current_length:
                raise IndexError("list assignment index out of range")
            for index in range(i + 1, current_length):
                self.table[str(index - 1)] = self.table.get(str(index))
            # The last slot may never have been stored (sparse), so only drop it if present.
            last_key = str(current_length - 1)
            if last_key in self.table:
                del self.table[last_key]
            self.table["_len"] = current_length - 1
            self.table.commit()
        elif isinstance(i, slice):
            # This isn't very efficient for continuous slices.
            # Always delete from the highest index down, whatever the slice direction, so that
            # the indices still to be deleted are not shifted by earlier deletions.
            for index in sorted(range(*i.indices(current_length)), reverse=True):
                del self[index]
        else:
            raise TypeError(f"list indices must be integers or slices, not {i.__class__.__name__}")

    def extend(self, values: Iterable[Any]) -> None:
        """Append all ``values`` to the end of the sequence, committing once at the end."""
        current_length = len(self)
        index = -1
        for index, value in enumerate(values):
            self.table[str(index + current_length)] = value
        if index < 0:
            # Nothing was added, so there is nothing to commit.
            return
        self.table["_len"] = current_length + index + 1
        self.table.commit()

    def insert(self, i: int, value: Any) -> None:
        """
        Insert ``value`` before index ``i``, shifting later elements up.

        Negative indices are interpreted like :meth:`list.insert`. An index past the end
        places the value at exactly that index, leaving unset slots in between.
        """
        current_length = len(self)
        if i < 0:
            i = max(i + current_length, 0)
        for index in reversed(range(i, current_length)):
            self.table[str(index + 1)] = self.table.get(str(index))
        self.table[str(i)] = value
        self.table["_len"] = max(i + 1, current_length + 1)
        self.table.commit()

    def __len__(self) -> int:
        # The length is stored explicitly because unset slots have no rows of their own.
        try:
            return self.table["_len"]
        except KeyError:
            return 0

    def clear(self) -> None:
        """
        Clears the entire sequence
        """
        self.table.clear()
        self.table.commit()

    def close(self) -> None:
        """
        Closes the underlying Sqlite table. Do not use this sequence afterwards!
        """
        if self.table is not None:
            self.table.close()
            self.table = None

    def copy_to(self, target: Union[str, PathLike]):
        """
        Make a copy of this sequence at a new location.

        :param target: the location of the copy

        This will attempt to make a hardlink, which is very fast, but only works if ``target``
        is on the same drive. If the hardlink fails because ``target`` is on a different
        device, it falls back to making a regular copy; any other error is re-raised.
        As a result, there is no guarantee whether you will get a hardlink or a copy.
        If you get a hardlink, future edits in the source sequence will also appear in the
        target sequence.

        This is why we recommend to not use :meth:`copy_to()` until you are done with the
        sequence. This is not ideal, but it is a compromise we make for performance.
        """
        import errno

        try:
            os.link(self.table.filename, target)
        except OSError as e:
            if e.errno == errno.EXDEV:  # Cross-device link
                shutil.copy(self.table.filename, target)
            else:
                raise
@contextmanager
def run_experiment(
    config: Union[PathOrStr, Dict[str, Any], Params],
    overrides: Optional[Union[Dict[str, Any], str]] = None,
    file_friendly_logging: bool = True,
    include_package: Optional[List[str]] = None,
    workspace_url: Optional[str] = None,
    parallelism: Optional[int] = 1,
    multicore: Optional[bool] = False,
    name: Optional[str] = None,
    settings: Optional[TangoGlobalSettings] = None,
):
    """
    A context manager to make testing experiments easier. On ``__enter__`` it runs the
    experiment and returns the path to the run directory, a temporary directory that will be
    cleaned up on ``__exit__``.

    Logging is initialized before the run and always torn down afterwards, even if setting
    up, running, or cleaning up the experiment raises.
    """
    initialize_logging(enable_cli_logs=True, file_friendly_logging=file_friendly_logging)
    test_case = TangoTestCase()
    try:
        # Set up outside the inner `try`, so that a failed setup is not followed by a
        # teardown of a half-initialized test case (which would mask the original error).
        test_case.setup_method()
        try:
            yield test_case.run(
                config,
                overrides=overrides,
                include_package=include_package,
                workspace_url=workspace_url,
                parallelism=parallelism,
                multicore=multicore,
                name=name,
                settings=settings,
            )
        finally:
            test_case.teardown_method()
    finally:
        # Runs even if `teardown_method()` itself raises.
        teardown_logging()
@Step.register("random_string")
class RandomStringStep(Step):
    """Test step that produces a random string of ASCII letters."""

    def run(self, length: int = 10) -> str:  # type: ignore
        """
        :param length: the number of characters to generate.
        :returns: ``length`` letters, each drawn independently with ``random.choice``.
        """
        letters = []
        for _ in range(length):
            letters.append(random.choice(ascii_letters))
        return "".join(letters)
@Step.register("multiprocessing_step")
class MultiprocessingStep(Step):
    """
    Mainly used to test that logging works properly in child processes.
    """

    def run(self, num_proc: int = 2) -> bool:  # type: ignore
        """
        Show a progress bar in the main process, then run ``num_proc`` worker processes
        (each executing ``_worker_function``) and wait for all of them to finish.
        """
        main_progress = Tqdm.tqdm(list(range(10)), desc="progress from main process")
        for _ in main_progress:
            time.sleep(0.1)

        workers = [
            mp.Process(target=_worker_function, args=(worker_id,))
            for worker_id in range(num_proc)
        ]
        for proc in workers:
            proc.start()
        for proc in workers:
            proc.join()

        return True
def replace_cr_with_newline(message: str) -> str:
    """
    TQDM and requests use carriage returns to get the training line to update for each batch
    without adding more lines to the terminal output. Displaying those in a file won't work
    correctly, so we'll just make sure that each batch shows up on its own line.

    :param message: the raw text written by the progress bar.
    :returns: ``message`` with carriage returns, newlines, and "cursor up" control sequences
        removed and a single trailing newline added, or ``""`` if nothing is left.
    """
    # In addition to carriage returns, nested progress-bars will contain extra new-line
    # characters and this special control sequence which tells the terminal to move the
    # cursor one line up. The sequence is spelled with an explicit escape so that it can't
    # get lost the way a raw ESC byte in the source file can.
    for control_sequence in ("\r", "\n", "\x1b[A"):
        message = message.replace(control_sequence, "")
    if message and message[-1] != "\n":
        message += "\n"
    return message
""" @staticmethod def tqdm(*args, **kwargs): new_kwargs = Tqdm.get_updated_kwargs(**kwargs) return _tqdm(*args, **new_kwargs) @staticmethod @contextmanager def wrapattr(*args, **kwargs): new_kwargs = Tqdm.get_updated_kwargs(**kwargs) with _tqdm.wrapattr(*args, **new_kwargs) as t: yield t @staticmethod def get_updated_kwargs(**kwargs): # Use a slower interval when FILE_FRIENDLY_LOGGING is set. default_mininterval = 2.0 if common_logging.FILE_FRIENDLY_LOGGING else 0.1 return { "file": TqdmToLogsWriter(), "mininterval": default_mininterval, **kwargs, } @staticmethod def set_lock(lock): _tqdm.set_lock(lock) @staticmethod def get_lock(): return _tqdm.get_lock() ================================================ FILE: tango/common/util.py ================================================ import importlib import pkgutil import signal import string import sys import traceback from collections import OrderedDict from dataclasses import asdict, is_dataclass from datetime import datetime, tzinfo from enum import Enum from pathlib import Path from typing import Any, Iterable, Optional, Set, Tuple, Union import pytz from .exceptions import SigTermReceived def tango_cache_dir() -> Path: """ Returns a directory suitable for caching things from Tango, defaulting to ``$HOME/.cache/tango``. 
def resolve_module_name(package_name: str) -> Tuple[str, Path]:
    """
    Convert a filesystem path to a module or package into a dotted module name.

    Starting from the path's parent, this walks up the directory tree for as long as the
    directories contain an ``__init__.py``. The first directory that doesn't becomes the
    base path, and the module name is the path relative to it.

    :param package_name: path to a ``.py`` file, an ``__init__.py`` file, or a package directory.
    :returns: a tuple of the dotted module name and the base path the name is relative to.
    :raises ValueError: if the path does not exist, or does not resolve to a module name.
    """
    base_path = Path(".")
    package_path = Path(package_name)
    if not package_path.exists():
        raise ValueError(f"'{package_path}' looks like a path, but the path does not exist")

    parent = package_path.parent
    while parent != parent.parent:
        if (parent / "__init__.py").is_file():
            parent = parent.parent
        else:
            base_path = parent
            break

    # Build the name from the path components instead of replacing "/" in the string form,
    # which does not work where the native path separator is "\".
    parts = list(package_path.relative_to(base_path).parts)
    if parts and package_path.is_file():
        if package_path.name == "__init__.py":
            # If `__init__.py` file, resolve to the parent module.
            parts.pop()
        elif parts[-1].endswith(".py"):
            parts[-1] = parts[-1][:-3]

    package_name = ".".join(parts)
    if not package_name:
        raise ValueError(f"invalid package path '{package_path}'")
    return package_name, base_path
if "/" in package_name or package_name.endswith(".py"): package_name, base_path = resolve_module_name(package_name) else: base_path = Path(".") base_path = base_path.resolve() if exclude and package_name in exclude: return importlib.invalidate_caches() # Ensure `base_path` is first in `sys.path`. if str(base_path) not in sys.path: sys.path.insert(0, str(base_path)) else: sys.path.insert(0, sys.path.pop(sys.path.index(str(base_path)))) # Certain packages might mess with sys.excepthook which we don't like since # we mess with sys.excepthook ourselves. If it looks like the package is overriding # the hook for a reason, we'll leave it be but also make sure our hook is still called. excepthook = sys.excepthook # Import at top level try: module = importlib.import_module(package_name) finally: if sys.excepthook != excepthook: if sys.excepthook.__module__.startswith("rich"): # We definitely don't want rich's traceback hook because that will mess # with our exception handling. sys.excepthook = excepthook else: new_hook = sys.excepthook def excepthook_wrapper(exctype, value, traceback): excepthook(exctype, value, traceback) new_hook(exctype, value, traceback) sys.excepthook = excepthook_wrapper path = getattr(module, "__path__", []) path_string = "" if not path else path[0] if recursive: # walk_packages only finds immediate children, so need to recurse. for module_finder, name, _ in pkgutil.walk_packages(path): # Sometimes when you import third-party libraries that are on your path, # `pkgutil.walk_packages` returns those too, so we need to skip them. 
def _parse_bool(value: Union[bool, str]) -> bool:
    """
    Interpret ``value`` as a boolean.

    Actual booleans are returned unchanged. Otherwise the result is ``True`` only for the
    exact strings "1", "true", "True" and "TRUE"; everything else is ``False``.
    """
    truthy_strings = {"1", "true", "True", "TRUE"}
    return value if isinstance(value, bool) else value in truthy_strings
def could_be_class_name(name: str) -> bool:
    """
    Check whether ``name`` looks like a fully-qualified path to a class, e.g.
    ``"my_module.models.MyModel"``.

    This is a purely syntactic check: the name must contain at least one ".", must not end
    with one, and every dot-separated part must be a valid Python identifier.
    """
    if "." not in name or name.endswith("."):
        return False
    return all(_is_valid_python_name(part) for part in name.split("."))


def _is_valid_python_name(name: str) -> bool:
    """
    Check whether ``name`` is a valid Python identifier.

    Unlike a check based on ``isalpha()``/``isalnum()``, this accepts identifiers that
    start with an underscore, such as ``_private`` modules or classes.
    """
    return name.isidentifier()
def jsonify(o: Any) -> Any:
    """
    Transform an object into a JSON-serializable equivalent (if there is one)
    in a deterministic way. For example, tuples and sets are turned into lists,
    dictionaries are turned into ordered dictionaries where the order depends on the sorting
    of the keys, and datetimes are turned into formatted strings.

    Lists are converted element by element as well, so values nested inside them are also
    transformed. Sets are emitted in sorted order whenever their converted elements can be
    compared with each other; otherwise they are left in iteration order.

    Objects without a known conversion are returned unchanged.
    """
    if isinstance(o, (list, tuple)):
        return [jsonify(x) for x in o]
    elif isinstance(o, (set, frozenset)):
        items = [jsonify(x) for x in o]
        try:
            # Set iteration order is not stable across processes, so sort where possible.
            return sorted(items)
        except TypeError:
            # Elements that can't be ordered (e.g. mixed types): keep iteration order.
            return items
    elif isinstance(o, dict):
        return OrderedDict((k, jsonify(v)) for k, v in sorted(o.items(), key=lambda x: x[0]))
    elif isinstance(o, datetime):
        return o.strftime("%Y-%m-%dT%H:%M:%S")
    elif is_dataclass(o):
        return jsonify(asdict(o))
    elif isinstance(o, Path):
        return str(o)
    else:
        return o
""" successful: Dict[str, ExecutionMetadata] = field(default_factory=dict) """Steps which ran successfully or were found in the cache.""" failed: Dict[str, ExecutionMetadata] = field(default_factory=dict) """Steps that failed.""" not_run: Dict[str, ExecutionMetadata] = field(default_factory=dict) """Steps that were ignored (usually because of failed dependencies).""" def display(self) -> None: table = Table(caption_style="") table.add_column("Step Name", justify="left", style="cyan") table.add_column("Status", justify="left") table.add_column("Results", justify="left") all_steps = dict(self.successful) all_steps.update(self.failed) all_steps.update(self.not_run) for step_name in sorted(all_steps): status_str: str result_str: str = "[grey62]N/A[/]" if step_name in self.failed: status_str = "[red]\N{ballot x} failed[/]" execution_metadata = self.failed[step_name] if execution_metadata.logs_location is not None: result_str = f"[cyan]{execution_metadata.logs_location}[/]" elif step_name in self.not_run: status_str = "[yellow]- not run[/]" elif step_name in self.successful: status_str = "[green]\N{check mark} succeeded[/]" execution_metadata = self.successful[step_name] if execution_metadata.result_location is not None: result_str = f"[cyan]{execution_metadata.result_location}[/]" elif execution_metadata.logs_location is not None: result_str = f"[cyan]{execution_metadata.logs_location}[/]" else: continue table.add_row(step_name, status_str, result_str) caption_parts: List[str] = [] if self.failed: caption_parts.append(f"[red]\N{ballot x}[/] [italic]{len(self.failed)} failed[/]") if self.successful: caption_parts.append( f"[green]\N{check mark}[/] [italic]{len(self.successful)} succeeded[/]" ) if self.not_run: caption_parts.append(f"[italic]{len(self.not_run)} not run[/]") table.caption = ", ".join(caption_parts) if logger.isEnabledFor(logging.INFO): logger.info(table) elif cli_logger.isEnabledFor(logging.INFO): cli_logger.info(table) else: get_console().print(table) 
class Executor(Registrable):
    """
    An ``Executor`` is a class that is responsible for running steps and caching their results.

    This is the base class and default implementation, registered as "default".

    .. note::
        The ``parallelism`` parameter has no effect with this default :class:`Executor`,
        but is part of the API because most subclass implementations allow configuring
        parallelism.
    """

    default_implementation = "default"

    def __init__(
        self,
        workspace: Workspace,
        include_package: Optional[Sequence[str]] = None,
        parallelism: Optional[int] = None,
    ) -> None:
        self.workspace = workspace
        self.include_package = include_package
        self.parallelism = parallelism

    def execute_step(self, step: "Step") -> None:
        """
        Run a single step against this executor's workspace.

        The packages listed in ``include_package`` are imported first so that any
        components they register can be found.
        """
        for package_name in self.include_package or ():
            import_extra_module(package_name)

        # Cacheable steps only need their result to exist in the workspace;
        # uncacheable steps have to be computed.
        run = step.ensure_result if step.cache_results else step.result
        run(self.workspace)

    def execute_step_graph(
        self, step_graph: StepGraph, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute a :class:`~tango.step_graph.StepGraph`. This attempts to execute
        every step in order. If a step fails, its dependent steps are not run,
        but unrelated steps are still executed. Step failures will be logged, but
        no exceptions will be raised.
        """
        if self.parallelism is not None:
            warnings.warn(
                "The 'parallelism' parameter has no effect with the default Executor. "
                "If you want to run steps in parallel, consider using the MulticoreExecutor.",
                UserWarning,
            )

        succeeded: Dict[str, ExecutionMetadata] = {}
        errored: Dict[str, ExecutionMetadata] = {}
        skipped: Dict[str, ExecutionMetadata] = {}
        leaf_steps = step_graph.uncacheable_leaf_steps()

        for step in step_graph.values():
            if not step.cache_results and step not in leaf_steps:
                # If a step is uncacheable and required for another step, it will be
                # executed as part of the downstream step's execution.
                continue

            if any(dep.name in errored for dep in step.recursive_dependencies):
                skipped[step.name] = ExecutionMetadata()
                continue

            try:
                self.execute_step(step)
                succeeded[step.name] = ExecutionMetadata(
                    result_location=self.workspace.step_info(step).result_location
                )
            except Exception as exc:
                errored[step.name] = ExecutionMetadata()
                log_exception(exc, logger)

        return ExecutorOutput(successful=succeeded, failed=errored, not_run=skipped)

    # NOTE: The reason for having this method instead of just using `execute_step()` to run
    # a single step is that the certain executors, such as the BeakerExecutor, need to
    # serialize steps somehow, and the easiest way to serialize a step is by serializing the
    # whole step config (which can be accessed via the step graph).

    def execute_sub_graph_for_steps(
        self, step_graph: StepGraph, *step_names: str, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute the sub-graph associated with a particular step in a
        :class:`~tango.step_graph.StepGraph`.
        """
        return self.execute_step_graph(step_graph.sub_graph(*step_names), run_name=run_name)


Executor.register("default")(Executor)
@Executor.register("multicore")
class MulticoreExecutor(Executor):
    """
    A ``MulticoreExecutor`` runs the steps in parallel and caches their results.

    Every step is started as its own ``tango run`` subprocess; at most
    ``parallelism`` subprocesses are running at any one time.
    """

    def __init__(
        self,
        workspace: Workspace,
        include_package: Optional[Sequence[str]] = None,
        parallelism: Optional[int] = 1,
        num_tries_to_sync_states: int = 3,
        wait_seconds_to_sync_states: int = 3,
    ) -> None:
        """
        :param workspace: The workspace that step states are read from and that child
            processes are pointed at.
        :param include_package: Packages passed to every child process with ``-i``.
        :param parallelism: Maximum number of concurrently running step processes.
            ``None`` and ``0`` are treated as ``1``; a negative value selects
            ``min(32, os.cpu_count() or 1)``.
        :param num_tries_to_sync_states: How many times to try reading the step states
            when the workspace raises ``IOError``. Values below ``1`` are treated as ``1``.
        :param wait_seconds_to_sync_states: Seconds to sleep between those tries.
        """
        super().__init__(workspace, include_package=include_package, parallelism=parallelism or 1)
        assert self.parallelism is not None
        if self.parallelism < 0:
            self.parallelism = min(32, os.cpu_count() or 1)

        # Perhaps there's a better way to do this without these being passed as args.
        self._num_tries_to_sync_states = num_tries_to_sync_states
        self._wait_seconds_to_sync_states = wait_seconds_to_sync_states

    def execute_step_graph(
        self, step_graph: StepGraph, run_name: Optional[str] = None
    ) -> ExecutorOutput:
        """
        Execute a :class:`tango.step_graph.StepGraph`. This attempts to execute steps in parallel.
        If a step fails, its dependent steps are not run, but unrelated steps are still executed.
        Step failures will be logged, but no exceptions will be raised.
        """

        _running: OrderedDict[str, subprocess.Popen] = OrderedDict({})
        _successful: Dict[str, ExecutionMetadata] = {}
        _failed: Dict[str, ExecutionMetadata] = {}
        _queued_steps: List[str] = []

        uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()

        def _sync_step_states() -> Dict[str, StepState]:
            """
            Update the StepState info.

            Although, this is not really elegant. The issue is as follows: The main multicore executor
            process queues a step, and starts step execution in a different process. If we try to read
            the StepState before that process has had time to update the StepState, the Workspace will
            throw the out of sync error (IOError: process should be running but it's considered incomplete...).

            Hence, we try to read a few times, so that the child process has time to update the step's state.
            """
            # Always make at least one attempt. With a non-positive setting the retry
            # loop used to be skipped entirely, leaving nothing to return.
            max_attempts = max(1, self._num_tries_to_sync_states)
            attempts = 0
            while True:
                attempts += 1
                try:
                    return {step.name: self._get_state(step) for step in step_graph.values()}
                except IOError:
                    if attempts >= max_attempts:
                        raise
                    time.sleep(self._wait_seconds_to_sync_states)

        def _has_incomplete_steps(step_states: Dict[str, StepState]) -> bool:
            """
            Are there any steps in the graph that are currently:
            1) running, or
            2) queued, or
            3) incomplete (with no failed dependencies).

            If there are any failed dependencies for a step, it will never manage to run.
            """

            def _failed_dependencies(step: Step) -> bool:
                for dependency in step.recursive_dependencies:
                    if (
                        step_states[dependency.name] == StepState.FAILED
                        or dependency.name in _failed
                    ):
                        return True
                return False

            uncacheable_leaf_step_names = {step.name for step in uncacheable_leaf_steps}
            for step_name, step_state in step_states.items():
                if (
                    step_name in _running
                    or step_name in _queued_steps
                    or (
                        # If the workspace already has a previous run, we disregard the failure.
                        step_state in [StepState.INCOMPLETE, StepState.FAILED]
                        and not _failed_dependencies(step_graph[step_name])
                        # We check for failures in this run.
                        and step_name not in _failed
                    )
                    or (
                        # Uncacheable leaf steps need to run, but their StepState will always be UNCACHEABLE.
                        step_name in uncacheable_leaf_step_names
                        and step_name not in _successful
                        and step_name not in _failed
                        and not _failed_dependencies(step_graph[step_name])
                    )
                ):
                    return True
            return False

        def _update_running_steps(step_states: Dict[str, StepState]) -> List[str]:
            """
            Check the running processes for status. We use poll_status to check if the process ended,
            but the StepState for checking completion/failure status, because after the process ends,
            the lock release etc. sometimes takes a beat longer.

            Returns the names of the steps whose processes ended in failure.
            """
            done = []
            errors = []
            for step_name, process in _running.items():
                poll_status = process.poll()
                if poll_status is not None:
                    # The step may have finished since we synced step states.
                    if step_states[step_name] == StepState.RUNNING:
                        step_states[step_name] = self._get_state(step_graph[step_name])

                    if step_states[step_name] == StepState.UNCACHEABLE:
                        # Nothing is recorded in the workspace for uncacheable steps, so the
                        # exit code is the only signal available.
                        if poll_status == 0:
                            done.append(step_name)
                        else:
                            errors.append(step_name)
                    elif step_states[step_name] == StepState.COMPLETED:
                        done.append(step_name)
                    elif (
                        step_states[step_name] == StepState.FAILED
                        or step_states[step_name] == StepState.INCOMPLETE
                    ):
                        # TODO: look into why the step status changes from running back to incomplete sometimes.
                        # Possibly it's due to the workspace being aggressive in marking it as incomplete when
                        # it thinks that the process is not running.
                        errors.append(step_name)
                    else:
                        raise RuntimeError(
                            f"Step '{step_name}' has the state {step_states[step_name]}, "
                            "but the corresponding process has ended!"
                        )

            # Only mutate `_running` once we are done iterating over it.
            for step_name in done + errors:
                _running.pop(step_name)

            for step_name in done:
                step = step_graph[step_name]
                _successful[step_name] = ExecutionMetadata(
                    result_location=None
                    if not step.cache_results
                    else self.workspace.step_info(step).result_location
                )

            for step_name in errors:
                _failed[step_name] = ExecutionMetadata()

            return errors

        def _get_steps_to_run(step_states: Dict[str, StepState]) -> Set[str]:
            """
            Returns the steps that can be queued to run immediately.
            Criteria:
                1) All dependencies are available.
                2) Step is not already running or queued.
                3) Step has not run in the past and failed.
                4) Step's state is INCOMPLETE (or FAILED from a previous run), or
                   step's state is UNCACHEABLE and it is a leaf step.

            (We only run uncacheable steps if they are needed for another step downstream,
            as part of the downstream step).
            """

            def _are_dependencies_available(step: Step) -> bool:
                for dependency in step.dependencies:
                    if step_states[dependency.name] not in [
                        StepState.COMPLETED,
                        StepState.UNCACHEABLE,
                    ]:
                        return False
                return True

            to_run: Set[str] = set()
            for step in step_graph.values():
                if (
                    _are_dependencies_available(step)
                    and step.name not in _running  # Not already running.
                    and step.name not in _queued_steps  # Not queued to run.
                    and step.name not in _failed  # Not already failed.
                    # See comment in _has_incomplete_steps
                    and (
                        step_states[step.name] in [StepState.INCOMPLETE, StepState.FAILED]
                        or (
                            step_states[step.name] == StepState.UNCACHEABLE
                            and step in uncacheable_leaf_steps
                            and step.name not in _successful
                        )
                    )
                ):
                    to_run.add(step.name)
            return to_run

        def _queue_step(step_name: str) -> None:
            _queued_steps.append(step_name)
            logger.debug(f"Step {step_name} added to the queue for execution.")

        def _try_to_execute_next_step(config_path: str, run_name: Optional[str] = None) -> None:
            """
            If there are queued steps, try to start processes for them (limited by `parallelism`).
            """
            if len(_queued_steps) == 0:
                logger.debug("No steps in queue!")
                return
            if len(_running) < (self.parallelism or 1):
                step_name = _queued_steps.pop(0)
                command: List[str] = [
                    "tango",
                    "--called-by-executor",
                    "run",
                    config_path,
                    "-s",
                    step_name,
                    "-w",
                    self.workspace.url,
                ]
                if self.include_package is not None:
                    for package in self.include_package:
                        command += ["-i", package]
                if run_name is not None:
                    command += ["-n", run_name]
                process = subprocess.Popen(command, shell=False)
                _running[step_name] = process
            else:
                logger.debug(
                    f"{self.parallelism or 1} steps are already running. Will attempt to execute later."
                )

        # Creates a temporary file in which to store the config. This is passed as a command line
        # argument to child step processes.
        with NamedTemporaryFile(prefix="step-graph-to-file-run", suffix=".jsonnet") as file_ref:
            step_graph.to_file(file_ref.name, include_unique_id=True)
            assert os.path.exists(file_ref.name)

            step_states = _sync_step_states()

            while _has_incomplete_steps(step_states):
                # Cleanup previously running steps.
                _update_running_steps(step_states)

                # Get steps that are ready to run.
                to_run = _get_steps_to_run(step_states)
                if to_run:
                    logger.debug(f"Steps ready to run: {to_run}")

                for step_name in to_run:
                    _queue_step(step_name)

                # Begin processes for any queued steps (if not enough processes are already running).
                while len(_queued_steps) > 0 and len(_running) < (self.parallelism or 1):
                    _try_to_execute_next_step(config_path=file_ref.name, run_name=run_name)

                # Re-sync the StepState info.
                step_states = _sync_step_states()

        assert not _running and not _queued_steps
        _not_run: Dict[str, ExecutionMetadata] = {}
        for step_name, step in step_graph.items():
            if step_name in _successful or step_name in _failed:
                # tried to execute directly
                continue
            elif not step.cache_results and step not in uncacheable_leaf_steps:
                # uncacheable interior step; didn't execute directly.
                continue
            elif (
                step.cache_results
                and step_name in step_states
                and step_states[step_name] == StepState.COMPLETED
            ):
                # step result was found in cache.
                # NOTE: since neither `Step.result()` nor `Step.ensure_result()` will have been
                # called, we invoke the CLI logger here to let users know that we didn't run this
                # step because we found it in the cache.
                step.log_cache_hit()
                _successful[step_name] = ExecutionMetadata(
                    result_location=self.workspace.step_info(step_graph[step_name]).result_location
                )
            else:
                # step wasn't executed because parents failed, or
                # step is uncacheable leaf step, so we do care about what happened to it.
                _not_run[step_name] = ExecutionMetadata()

        return ExecutorOutput(successful=_successful, failed=_failed, not_run=_not_run)

    def _get_state(self, step: Step) -> StepState:
        """
        Returns the StepState as determined by the workspace.
        """
        return self.workspace.step_info(step).state
# Maps every accepted spelling of a compression name to the function that opens
# files of that kind.
_OPEN_FUNCTIONS: Dict[Optional[str], Callable[[PathLike, str], IO]] = {
    None: open,
    "None": open,
    "none": open,
    "null": open,
    "gz": gzip.open,  # type: ignore
    "gzip": gzip.open,  # type: ignore
    "bz": bz2.open,  # type: ignore
    "bz2": bz2.open,  # type: ignore
    "bzip": bz2.open,  # type: ignore
    "bzip2": bz2.open,  # type: ignore
    "lzma": lzma.open,
}

# File-name suffix associated with each open function.
_SUFFIXES: Dict[Callable, str] = {
    open: "",
    gzip.open: ".gz",
    bz2.open: ".bz2",
    lzma.open: ".xz",
}


def _open_compressed(filename: PathOrStr, mode: str) -> IO:
    """
    Open ``filename`` in ``mode``, choosing the open function from the file's suffix.

    The first entry of ``_SUFFIXES`` whose non-empty suffix ends the file name wins;
    when none matches, the file is opened with the built-in ``open``.
    """
    path = str(filename)
    opener: Callable = next(
        (fn for fn, suffix in _SUFFIXES.items() if suffix and path.endswith(suffix)),
        open,
    )
    return opener(path, mode)
""" VERSION = "001" def __init__(self, compress: Optional[str] = None): if compress not in _OPEN_FUNCTIONS: raise ConfigurationError(f"The {compress} compression format does not exist.") self.compress = compress def write(self, artifact: T, dir: PathOrStr): filename = self._get_artifact_path(dir) open_method = _OPEN_FUNCTIONS[self.compress] with open_method(filename, "wb") as f: pickler = dill.Pickler(file=f) pickler.dump(self.VERSION) if hasattr(artifact, "__next__"): pickler.dump(True) for item in cast(Iterable, artifact): pickler.dump(item) else: pickler.dump(False) pickler.dump(artifact) def read(self, dir: PathOrStr) -> T: filename = self._get_artifact_path(dir) open_method = _OPEN_FUNCTIONS[self.compress] with open_method(filename, "rb") as f: unpickler = dill.Unpickler(file=f) version = unpickler.load() if version > self.VERSION: raise ValueError( f"File {filename} is too recent for this version of {self.__class__}." ) iterator = unpickler.load() if iterator: return DillFormatIterator(filename) # type: ignore else: return unpickler.load() def _get_artifact_path(self, dir: PathOrStr) -> Path: return Path(dir) / ("data.dill" + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]) class DillFormatIterator(Iterator[T], Generic[T]): """ An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.DillFormat.read`. """ def __init__(self, filename: PathOrStr): self.f: Optional[IO[Any]] = _open_compressed(filename, "rb") self.unpickler = dill.Unpickler(self.f) version = self.unpickler.load() if version > DillFormat.VERSION: raise ValueError(f"File {filename} is too recent for this version of {self.__class__}.") iterator = self.unpickler.load() if not iterator: raise ValueError( f"Tried to open {filename} as an iterator, but it does not store an iterator." 
@Format.register("json")
class JsonFormat(Format[T], Generic[T]):
    """This format writes the artifact as a single file in json format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.

    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.
    """

    VERSION = "002"

    def __init__(self, compress: Optional[str] = None):
        """
        :param compress: Name of the compression to use; must be a key of ``_OPEN_FUNCTIONS``.
        :raises ConfigurationError: If ``compress`` is not a known compression name.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress

    @staticmethod
    def _encoding_fallback(unencodable: Any):
        """
        ``default=`` hook for ``json.dump``, called for objects json cannot encode itself.

        Zero-dimensional torch tensors (when torch is installed) become plain Python
        scalars. Dataclasses become dicts with an extra ``"_dataclass"`` key naming the
        class so that :meth:`_decoding_fallback` can rebuild them.

        :raises TypeError: For tensors with dimensions and for any other unsupported type.
        """
        try:
            import torch

            if isinstance(unencodable, torch.Tensor):
                if len(unencodable.shape) == 0:
                    return unencodable.item()
                else:
                    raise TypeError(
                        "Tensors must have 1 element and no dimensions to be JSON serializable."
                    )
        except ImportError:
            pass

        if dataclasses.is_dataclass(unencodable):
            result = dataclasses.asdict(unencodable)
            module = type(unencodable).__module__
            qualname = type(unencodable).__qualname__
            if module == "builtins":
                result["_dataclass"] = qualname
            else:
                result["_dataclass"] = [module, qualname]
            return result

        raise TypeError(f"Object of type {type(unencodable)} is not JSON serializable")

    @staticmethod
    def _decoding_fallback(o: Dict) -> Any:
        """
        ``object_hook=`` for ``json.load``: rebuilds dataclasses written by
        :meth:`_encoding_fallback` and returns every other dict unchanged.

        :raises RuntimeError: If the ``"_dataclass"`` entry has an unexpected shape.
        """
        if "_dataclass" in o:
            classname: Union[str, List[str]] = o.pop("_dataclass")
            if isinstance(classname, list) and len(classname) == 2:
                module, classname = classname
                # NOTE(review): the module to import is named by the file being read, so this
                # should only ever be pointed at trusted data.
                constructor: Callable = importlib.import_module(module)  # type: ignore
                for item in classname.split("."):
                    constructor = getattr(constructor, item)
            elif isinstance(classname, str):
                constructor = globals()[classname]
            else:
                raise RuntimeError(f"Could not parse {classname} as the name of a dataclass.")
            return constructor(**o)
        return o

    def write(self, artifact: T, dir: PathOrStr):
        """
        Write ``artifact`` into ``dir``. Iterators (anything with ``__next__``) are consumed
        and written one JSON document per line; everything else is written as one document.
        """
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    json.dump(item, f, default=self._encoding_fallback)
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                json.dump(artifact, f, default=self._encoding_fallback)

    def read(self, dir: PathOrStr) -> T:
        """
        Read the artifact stored in ``dir``.

        If both the single-document file and the line-wise file exist, a warning is logged
        and the single-document file is used.

        :returns: The decoded artifact, or a lazy :class:`JsonFormatIterator` when only the
            line-wise file exists.
        :raises IOError: If neither file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        iterator_exists = iterator_filename.exists()
        non_iterator_exists = non_iterator_filename.exists()

        if non_iterator_exists:
            if iterator_exists:
                self.logger.warning(
                    "Both %s and %s exist. Ignoring %s.",
                    iterator_filename,
                    non_iterator_filename,
                    iterator_filename,
                )
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return json.load(f, object_hook=self._decoding_fallback)

        if iterator_exists:
            return JsonFormatIterator(iterator_filename)  # type: ignore

        # Exceptions do not interpolate %-style arguments, and a two-argument IOError would
        # treat the message as ``errno`` and the path as ``strerror``.
        raise IOError(f"Attempting to read non-existing data from {dir}")

    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        """Return the path of the data file in ``dir``, including the compression suffix."""
        return Path(dir) / (
            ("data.jsonl" if iterator else "data.json") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )


class JsonFormatIterator(Iterator[T], Generic[T]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.JsonFormat.read`.
    """

    def __init__(self, filename: PathOrStr):
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")

    def __iter__(self) -> Iterator[T]:
        return self

    def __next__(self) -> T:
        """Decode and return the next line; closes the file once it is exhausted."""
        if self.f is None:
            raise StopIteration()
        try:
            line = self.f.readline()
            if len(line) <= 0:
                raise EOFError()
            return json.loads(line, object_hook=JsonFormat._decoding_fallback)
        except EOFError:
            self.f.close()
            self.f = None
            raise StopIteration()


@Format.register("text")
class TextFormat(Format[Union[str, Iterable[str]]]):
    """This format writes the artifact as a single file in text format.
    Optionally, it can compress the data. This is very flexible, but not always the fastest.

    This format can only write strings, or iterable of strings.

    .. tip::
        This format has special support for iterables. If you write an iterator, it will consume the
        iterator. If you read an iterator, it will read the iterator lazily.

        Be aware that if your strings contain newlines, you will read out more strings than you wrote.
        For this reason, it's often advisable to use :class:`JsonFormat` instead. With :class:`JsonFormat`,
        all special characters are escaped, strings are quoted, but it's all still human-readable.
    """

    VERSION = "001"

    def __init__(self, compress: Optional[str] = None):
        """
        :param compress: Name of the compression to use; must be a key of ``_OPEN_FUNCTIONS``.
        :raises ConfigurationError: If ``compress`` is not a known compression name.
        """
        self.logger = logging.getLogger(self.__class__.__name__)
        if compress not in _OPEN_FUNCTIONS:
            raise ConfigurationError(f"The {compress} compression format does not exist.")
        self.compress = compress

    def write(self, artifact: Union[str, Iterable[str]], dir: PathOrStr):
        """
        Write ``artifact`` into ``dir``. Iterators (anything with ``__next__``) are consumed
        and written one item per line; everything else is written with ``str(artifact)``.
        """
        open_method = _OPEN_FUNCTIONS[self.compress]
        if hasattr(artifact, "__next__"):
            filename = self._get_artifact_path(dir, iterator=True)
            with open_method(filename, "wt") as f:
                for item in cast(Iterable, artifact):
                    f.write(str(item))
                    f.write("\n")
        else:
            filename = self._get_artifact_path(dir, iterator=False)
            with open_method(filename, "wt") as f:
                f.write(str(artifact))

    def read(self, dir: PathOrStr) -> Union[str, Iterable[str]]:
        """
        Read the artifact stored in ``dir``.

        If both the single-text file and the line-wise file exist, a warning is logged and
        the single-text file is used.

        :returns: The file's contents, or a lazy :class:`TextFormatIterator` when only the
            line-wise file exists.
        :raises IOError: If neither file exists.
        """
        iterator_filename = self._get_artifact_path(dir, iterator=True)
        non_iterator_filename = self._get_artifact_path(dir, iterator=False)
        iterator_exists = iterator_filename.exists()
        non_iterator_exists = non_iterator_filename.exists()

        if non_iterator_exists:
            if iterator_exists:
                self.logger.warning(
                    "Both %s and %s exist. Ignoring %s.",
                    iterator_filename,
                    non_iterator_filename,
                    iterator_filename,
                )
            open_method = _OPEN_FUNCTIONS[self.compress]
            with open_method(non_iterator_filename, "rt") as f:
                return f.read()

        if iterator_exists:
            return TextFormatIterator(iterator_filename)  # type: ignore

        # Exceptions do not interpolate %-style arguments, and a two-argument IOError would
        # treat the message as ``errno`` and the path as ``strerror``.
        raise IOError(f"Attempting to read non-existing data from {dir}")

    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:
        """Return the path of the text file in ``dir``, including the compression suffix."""
        return Path(dir) / (
            ("texts.txt" if iterator else "text.txt") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]
        )


class TextFormatIterator(Iterator[str]):
    """
    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.TextFormat.read`.
    """

    def __init__(self, filename: PathOrStr):
        self.f: Optional[IO[Any]] = _open_compressed(filename, "rt")

    def __iter__(self) -> Iterator[str]:
        return self

    def __next__(self) -> str:
        """Return the next line without its trailing newline; closes the file once exhausted."""
        if self.f is None:
            raise StopIteration()
        try:
            line = self.f.readline()
            if len(line) <= 0:
                raise EOFError()
            if line.endswith("\n"):
                line = line[:-1]
            return line
        except EOFError:
            self.f.close()
            self.f = None
            raise StopIteration()
:class:`~tango.common.DatasetDict`. It writes those datasets into Sqlite databases. During reading, the advantage is that the dataset can be read lazily. Reading a result that is stored in :class:`SqliteDictFormat` takes milliseconds. No actual reading takes place until you access individual instances. During writing, you have to take some care to take advantage of the same trick. Recall that :class:`~tango.DatasetDict` is basically a map, mapping split names to lists of instances. If you ensure that those lists of instances are of type :class:`~tango.common.sequences.SqliteSparseSequence`, then writing the results in :class:`SqliteDictFormat` can in many cases be instantaneous. Here is an example of the pattern to use to make writing fast: .. code-block:: Python @Step.register("my_step") class MyStep(Step[DatasetDict]): FORMAT: Format = SqliteDictFormat() VERSION = "001" def run(self, ...) -> DatasetDict: result: Dict[str, Sequence] = {} for split_name in my_list_of_splits: output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite") for instance in instances: output_split.append(instance) result[split_name] = output_split metadata = {} return DatasetDict(result, metadata) Observe how for each split, we create a :class:`~tango.common.sequences.SqliteSparseSequence` in the step's work directory (accessible with :meth:`~tango.step.Step.work_dir`). This has the added advantage that if the step fails and you have to re-run it, the previous results that were already written to the :class:`~tango.common.sequences.SqliteSparseSequence` are still there. You could replace the inner ``for`` loop like this to take advantage: .. 
code-block:: Python output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite") for instance in instances[len(output_split):]: # <-- here is the difference output_split.append(instance) result[split_name] = output_split This works because when you re-run the step, the work directory will still be there, so ``output_split`` is not empty when you open it. """ VERSION = "003" def write(self, artifact: DatasetDict, dir: Union[str, PathLike]): dir = Path(dir) with gzip.open(dir / "metadata.dill.gz", "wb") as f: dill.dump(artifact.metadata, f) for split_name, split in artifact.splits.items(): filename = f"{split_name}.sqlite" if not filename_is_safe(filename): raise ValueError(f"{split_name} is not a valid name for a split.") try: (dir / filename).unlink() except FileNotFoundError: pass if isinstance(split, SqliteSparseSequence): split.copy_to(dir / filename) else: sqlite = SqliteSparseSequence(dir / filename) sqlite.extend(split) def read(self, dir: Union[str, PathLike]) -> DatasetDict: dir = Path(dir) with gzip.open(dir / "metadata.dill.gz", "rb") as f: metadata = dill.load(f) splits = { filename.stem: SqliteSparseSequence(filename, read_only=True) for filename in dir.glob("*.sqlite") } return DatasetDict(metadata=metadata, splits=splits) ================================================ FILE: tango/integrations/__init__.py ================================================ """ In :mod:`tango.integrations` we provide many ready-to-use `component <../components/index.html>`_ implementations for leveraging the functionality from popular libraries. .. tip:: All registered components will be registered under a name that starts with the name of the integration module, possibly followed by a double colon ("::") and another identifier if there are multiple registered components of a given type. 
For example, the :class:`~tango.integrations.datasets.LoadDataset` step in the `🤗 Datasets `_ integration is registered under the name "datasets::load", and the :class:`~tango.integrations.torch.TorchFormat` format in the `PyTorch `_ integration is registered under the name "torch". """ ================================================ FILE: tango/integrations/beaker/__init__.py ================================================ """ .. important:: To use this integration you should install ``tango`` with the "beaker" extra (e.g. ``pip install tango[beaker]``) or just install the `beaker-py `_ library after the fact (e.g. ``pip install beaker-py``). Components for Tango integration with `Beaker `_. """ from tango.common.exceptions import IntegrationMissingError try: from beaker import Beaker except (ModuleNotFoundError, ImportError): raise IntegrationMissingError("beaker", dependencies={"beaker-py"}) from .executor import ( BeakerExecutor, BeakerScheduler, ResourceAssignment, ResourceAssignmentError, SimpleBeakerScheduler, UnrecoverableResourceAssignmentError, ) from .step_cache import BeakerStepCache from .workspace import BeakerWorkspace __all__ = [ "BeakerStepCache", "BeakerWorkspace", "BeakerExecutor", "BeakerScheduler", "SimpleBeakerScheduler", "ResourceAssignment", "ResourceAssignmentError", "UnrecoverableResourceAssignmentError", ] ================================================ FILE: tango/integrations/beaker/common.py ================================================ import atexit import json import logging import os.path import tempfile import time import urllib import urllib.parse from pathlib import Path from typing import Any, Dict, Optional, Union from beaker import Beaker from beaker import Dataset as BeakerDataset from beaker import DatasetConflict, DatasetNotFound, Experiment, ExperimentNotFound from tango.common.remote_utils import RemoteConstants from tango.step import Step from tango.step_info import StepInfo from tango.version import VERSION logger 
class BeakerStepLock:
    """
    A lock for a step, implemented as an uncommitted Beaker dataset.

    Creating the dataset acquires the lock and deleting it releases the lock. A
    ``metadata.json`` file inside the dataset records which Beaker experiment (if any)
    took the lock, so that a lock left behind by a finished experiment can be cleared.
    """

    METADATA_FNAME = "metadata.json"

    def __init__(
        self,
        beaker: Beaker,
        step: Union[str, StepInfo, Step],
        current_beaker_experiment: Optional[Experiment] = None,
    ):
        """
        :param beaker: The Beaker client used for all dataset and experiment calls.
        :param step: The step to lock, or its unique ID as a string.
        :param current_beaker_experiment: The experiment this process is running in, if any.
        """
        self._beaker = beaker
        self._step_id = step if isinstance(step, str) else step.unique_id
        self._lock_dataset_name = RemoteConstants.step_lock_artifact_name(step)
        self._lock_dataset: Optional[BeakerDataset] = None
        self._current_beaker_experiment = current_beaker_experiment
        self.lock_dataset_url = dataset_url(beaker, self._lock_dataset_name)

    @property
    def metadata(self) -> Dict[str, Any]:
        """The metadata written into the lock dataset when the lock is acquired."""
        return {
            "beaker_experiment": None
            if not self._current_beaker_experiment
            else self._current_beaker_experiment.id
        }

    def _last_metadata(self) -> Optional[Dict[str, Any]]:
        """Return the metadata stored in the existing lock dataset, or ``None`` if unavailable."""
        try:
            metadata_bytes = self._beaker.dataset.get_file(
                self._lock_dataset_name, self.METADATA_FNAME, quiet=True
            )
            metadata = json.loads(metadata_bytes)
            return metadata
        except (DatasetNotFound, FileNotFoundError):
            return None

    def _acquiring_job_is_done(self) -> bool:
        """
        Return ``True`` when the existing lock was taken by a Beaker experiment that can no
        longer be holding it, i.e. when the lock is stale and safe to delete.
        """
        last_metadata = self._last_metadata()
        if last_metadata is None:
            return False

        last_experiment_id = last_metadata.get("beaker_experiment")
        if last_experiment_id is None:
            return False

        try:
            last_experiment = self._beaker.experiment.get(last_experiment_id)
            if (
                self._current_beaker_experiment is not None
                and self._current_beaker_experiment.id == last_experiment_id
            ):
                # This means a previous job for this experiment was preempted and
                # it didn't clean up after itself.
                return True
            else:
                job = self._beaker.experiment.latest_job(last_experiment)
                return False if job is None else job.is_done
        except ExperimentNotFound:
            # Experiment must have been deleted.
            return True
        except ValueError:
            return False

    def acquire(self, timeout=None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:
        """
        Acquire the lock, waiting for it to become free if necessary.

        Does nothing when this object already holds the lock.

        :param timeout: Maximum number of seconds to wait; ``None`` waits indefinitely.
        :param poll_interval: Seconds to sleep between attempts.
        :param log_interval: Minimum number of seconds between "waiting" log messages.
        :raises TimeoutError: If the lock could not be acquired within ``timeout`` seconds.
        """
        if self._lock_dataset is not None:
            return
        start = time.monotonic()
        last_logged = None
        while timeout is None or (time.monotonic() - start < timeout):
            try:
                self._lock_dataset = self._beaker.dataset.create(
                    self._lock_dataset_name, commit=False
                )

                # Make sure the lock is released even if nobody calls `release()`.
                atexit.register(self.release)

                # Write metadata.
                with tempfile.TemporaryDirectory() as tmp_dir_name:
                    tmp_dir = Path(tmp_dir_name)
                    metadata_path = tmp_dir / self.METADATA_FNAME
                    with open(metadata_path, "w") as f:
                        json.dump(self.metadata, f)
                    self._beaker.dataset.sync(self._lock_dataset, metadata_path, quiet=True)
            except DatasetConflict:
                # Check if existing lock was created from a Beaker experiment.
                # If it was, and the experiment is no-longer running, we can safely
                # delete it.
                if self._acquiring_job_is_done():
                    try:
                        self._beaker.dataset.delete(self._lock_dataset_name)
                    except DatasetNotFound:
                        # Somebody else cleared the stale lock first; just try again.
                        pass
                    continue

                now = time.monotonic()
                if last_logged is None or now - last_logged >= log_interval:
                    logger.warning(
                        "Waiting to acquire lock dataset for step '%s':\n\n%s\n\n"
                        "This probably means the step is being run elsewhere, but if you're sure it isn't "
                        "you can just delete the lock dataset.",
                        self._step_id,
                        self.lock_dataset_url,
                    )
                    last_logged = now
                time.sleep(poll_interval)
                continue
            else:
                break
        else:
            raise TimeoutError(
                f"Timeout error occurred while waiting to acquire dataset lock for step '{self._step_id}':\n\n"
                f"{self.lock_dataset_url}\n\n"
                f"This probably means the step is being run elsewhere, but if you're sure it isn't you can "
                f"just delete the lock dataset."
            )

    def release(self):
        """Release the lock if this object holds it; otherwise do nothing."""
        # `getattr` keeps this safe on an instance whose `__init__` did not run to
        # completion, which matters because `__del__` ends up here.
        lock_dataset = getattr(self, "_lock_dataset", None)
        if lock_dataset is not None:
            try:
                self._beaker.dataset.delete(lock_dataset)
            except DatasetNotFound:
                # Dataset must have been manually deleted.
                pass
            self._lock_dataset = None
            atexit.unregister(self.release)

    def __del__(self):
        self.release()
gh auth setup-git

echo " [TANGO] [2/3] Cloning source code from '$GITHUB_REPO'... "

# Clone the repo and checkout the target commit.
gh repo clone "$GITHUB_REPO" src
cd src
git checkout "$GIT_REF"

echo " [TANGO] [3/3] Reconstructing Python env... "

# Fall back to defaults for anything that wasn't configured.
if [[ -z "$VENV_NAME" ]]; then
    VENV_NAME=venv
fi

if [[ -z "$CONDA_ENV_FILE" ]]; then
    # shellcheck disable=SC2296
    CONDA_ENV_FILE="environment.yml"
fi

if [[ -z "$PIP_REQUIREMENTS_FILE" ]]; then
    # shellcheck disable=SC2296
    PIP_REQUIREMENTS_FILE="requirements.txt"
fi

# "$VENV_NAME" is quoted so that the name is never word-split or glob-expanded.
if conda activate "$VENV_NAME" &>/dev/null; then
    echo "[TANGO] Using existing conda environment '$VENV_NAME'"
    # The virtual environment already exists. Possibly update it based on an environment file.
    if [[ -f "$CONDA_ENV_FILE" ]]; then
        echo "[TANGO] Updating environment from conda env file '$CONDA_ENV_FILE'..."
        conda env update -f "$CONDA_ENV_FILE"
    fi
else
    # The virtual environment doesn't exist yet. Create it.
    if [[ -f "$CONDA_ENV_FILE" ]]; then
        # Create from the environment file.
        echo "[TANGO] Initializing environment from conda env file '$CONDA_ENV_FILE'..."
        conda env create -n "$VENV_NAME" -f "$CONDA_ENV_FILE"
    elif [[ -z "$PYTHON_VERSION" ]]; then
        # Create a new empty environment with whatever the default Python version is.
        echo "[TANGO] Initializing environment with default Python version..."
        conda create -n "$VENV_NAME" pip
    else
        # Create a new empty environment with the specific Python version.
        echo "[TANGO] Initializing environment with Python $PYTHON_VERSION..."
        conda create -n "$VENV_NAME" "python=$PYTHON_VERSION" pip
    fi
    conda activate "$VENV_NAME"
fi

# Every time Beaker changes their APIs, we need to upgrade beaker-py. This happens all the
# time, so we make sure we have the latest.
# We do this when the conda environment is up, but before the requirements, so that
# requirements can request a particular beaker-py version if they want.
pip install --upgrade beaker-py

if [[ -z "$INSTALL_CMD" ]]; then
    # Check for a 'requirements.txt' and/or 'setup.py/pyproject.toml/setup.cfg' file.
    has_project_file=false
    if [[ -f 'setup.py' || -f 'pyproject.toml' || -f 'setup.cfg' ]]; then
        has_project_file=true
    fi

    if [[ "$has_project_file" == true && -f "$PIP_REQUIREMENTS_FILE" ]]; then
        echo "[TANGO] Installing local project and packages from '$PIP_REQUIREMENTS_FILE'..."
        pip install . -r "$PIP_REQUIREMENTS_FILE"
    elif [[ "$has_project_file" == true ]]; then
        echo "[TANGO] Installing local project..."
        pip install .
    elif [[ -f "$PIP_REQUIREMENTS_FILE" ]]; then
        echo "[TANGO] Installing packages from '$PIP_REQUIREMENTS_FILE'..."
        pip install -r "$PIP_REQUIREMENTS_FILE"
    fi
else
    echo "[TANGO] Installing packages with given command: $INSTALL_CMD"
    eval "$INSTALL_CMD"
fi

PYTHONPATH="$(pwd)"
export PYTHONPATH

echo " Environment info: "

# 'command -v' is a shell builtin, so unlike 'which' it is always available.
echo "Using $(python --version) from $(command -v python)"
echo "Packages:"
if command -v sed >/dev/null; then
    pip freeze | sed 's/^/- /'
else
    pip freeze
fi

echo " [TANGO] Setup complete ✓ "

# Execute the arguments to this script as commands themselves.
exec "$@" ================================================ FILE: tango/integrations/beaker/executor.py ================================================ import json import logging import os import threading import time import uuid import warnings from abc import abstractmethod from typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union from beaker import ( Beaker, DataMount, Dataset, DatasetConflict, DatasetNotFound, Digest, EnvVar, Experiment, ExperimentNotFound, ExperimentSpec, JobFailedError, JobTimeoutError, NodeResources, Priority, TaskResources, TaskSpec, TaskStoppedError, ) from git import Git, GitCommandError, InvalidGitRepositoryError, Repo from tango.common.exceptions import ( CancellationError, ConfigurationError, ExecutorError, RunCancelled, ) from tango.common.logging import cli_logger, log_exception from tango.common.registrable import Registrable from tango.executor import ExecutionMetadata, Executor, ExecutorOutput from tango.step import Step from tango.step_graph import StepGraph from tango.step_info import GitMetadata from tango.version import VERSION from tango.workspace import Workspace from .common import Constants, get_client logger = logging.getLogger(__name__) class StepFailedError(ExecutorError): def __init__(self, msg: str, experiment_url: str): super().__init__(msg) self.experiment_url = experiment_url class ResourceAssignmentError(ExecutorError): """ Raised when a scheduler can't find enough free resources at the moment to run a step. """ class UnrecoverableResourceAssignmentError(ExecutorError): """ An unrecoverable version of :class:`ResourceAssignmentError`. Raises this from a :class:`BeakerScheduler` will cause the executor to fail. """ class ResourceAssignment(NamedTuple): """ Resources assigned to a step. """ cluster: Union[str, List[str]] """ The cluster(s) to use to execute the step. """ resources: TaskResources """ The compute resources on the cluster to allocate for execution of the step. 
""" priority: Union[str, Priority] """ The priority to execute the step with. """ class BeakerScheduler(Registrable): """ A :class:`BeakerScheduler` is responsible for determining which resources and priority to assign to the execution of a step. """ default_implementation = "simple" """ The default implementation is :class:`SimpleBeakerScheduler`. """ def __init__(self): self._beaker: Optional[Beaker] = None @property def beaker(self) -> Beaker: if self._beaker is None: raise ValueError("'beaker' client has not be assigned to scheduler yet!") return self._beaker @beaker.setter def beaker(self, beaker: Beaker) -> None: self._beaker = beaker @abstractmethod def schedule(self, step: Step) -> ResourceAssignment: """ Determine the :class:`ResourceAssignment` for a step. :raises ResourceAssignmentError: If the scheduler can't find enough free resources at the moment to run the step. """ raise NotImplementedError() @BeakerScheduler.register("simple") class SimpleBeakerScheduler(BeakerScheduler): """ The :class:`SimpleBeakerScheduler` just searches the given clusters for one with enough resources to match what's specified by the step's required resources. 
""" def __init__(self, clusters: List[str], priority: Union[str, Priority]): super().__init__() self.clusters = clusters self.priority = priority self._node_resources: Optional[Dict[str, List[NodeResources]]] = None if not self.clusters: raise ConfigurationError("At least one cluster is required in 'clusters'") @property def node_resources(self) -> Dict[str, List[NodeResources]]: if self._node_resources is None: node_resources = { cluster: [node.limits for node in self.beaker.cluster.nodes(cluster)] for cluster in self.clusters } self._node_resources = node_resources return node_resources else: return self._node_resources def schedule(self, step: Step) -> ResourceAssignment: step_resources = step.resources task_resources = TaskResources( cpu_count=step_resources.cpu_count, gpu_count=step_resources.gpu_count, memory=step_resources.memory, shared_memory=step_resources.shared_memory, ) clusters = self.clusters if step_resources.gpu_type is not None: clusters = [ cluster for cluster, nodes in self.node_resources.items() if all([node.gpu_type == step_resources.gpu_type for node in nodes]) ] if not clusters: raise UnrecoverableResourceAssignmentError( f"Could not find cluster with nodes that have GPU type '{step_resources.gpu_type}'" ) return ResourceAssignment( cluster=clusters, resources=task_resources, priority=self.priority ) @Executor.register("beaker") class BeakerExecutor(Executor): """ This is a :class:`~tango.executor.Executor` that runs steps on `Beaker`_. Each step is run as its own Beaker experiment. .. tip:: Registered as an :class:`~tango.executor.Executor` under the name "beaker". .. important:: The :class:`BeakerExecutor` requires that you run Tango within a GitHub repository and you push all of your changes prior to each ``tango run`` call. It also requires that you have a `GitHub personal access token `_ with at least the "repo" scope set to the environment variable ``GITHUB_TOKEN`` (you can also set it using the ``github_token`` parameter, see below). 
This is because :class:`BeakerExecutor` has to be able to clone your code from Beaker. .. important:: The :class:`BeakerExecutor` will try to recreate your Python environment on Beaker every time a step is run, so it's important that you specify all of your dependencies in a PIP ``requirements.txt`` file, ``setup.py`` file, or a conda ``environment.yml`` file. Alternatively you could provide the ``install_cmd`` argument. .. important:: The :class:`BeakerExecutor` takes no responsibility for saving the results of steps that it runs on Beaker. That's the job of your workspace. So make sure you're using the right type of workspace or your results will be lost. For example, any "remote" workspace (like the :class:`BeakerWorkspace`) would work, or in some cases you could use a :class:`~tango.workspaces.LocalWorkspace` on an NFS drive. .. important:: If you're running a step that requires special hardware, e.g. a GPU, you should specify that in the ``step_resources`` parameter to the step, or by overriding the step's :meth:`.resources() ` property method. :param workspace: The :class:`~tango.workspace.Workspace` to use. :param clusters: A list of Beaker clusters that the executor may use to run steps. If ``scheduler`` is specified, this argument is ignored. :param include_package: A list of Python packages to import before running steps. :param beaker_workspace: The name or ID of the Beaker workspace to use. :param github_token: You can use this parameter to set a GitHub personal access token instead of using the ``GITHUB_TOKEN`` environment variable. :param google_token: You can use this parameter to set a Google Cloud token instead of using the ``GOOGLE_TOKEN`` environment variable. :param beaker_image: The name or ID of a Beaker image to use for running steps on Beaker. The image must come with bash and `conda `_ installed (Miniconda is okay). This is mutually exclusive with the ``docker_image`` parameter.
If neither ``beaker_image`` nor ``docker_image`` is specified, the :data:`DEFAULT_BEAKER_IMAGE` will be used. :param docker_image: The name of a publicly-available Docker image to use for running steps on Beaker. The image must come with bash and `conda `_ installed (Miniconda is okay). This is mutually exclusive with the ``beaker_image`` parameter. :param datasets: External data sources to mount into the Beaker job for each step. You could use this to mount an NFS drive, for example. :param env_vars: Environment variables to set in the Beaker job for each step. :param venv_name: The name of the conda virtual environment to use or create on the image. If you're using your own image that already has a conda environment you want to use, you should set this variable to the name of that environment. You can also set this to "base" to use the base environment. :param parallelism: Control the maximum number of steps run in parallel on Beaker. :param install_cmd: Override the command used to install your code and its dependencies in each Beaker job. For example, you could set ``install_cmd="pip install .[dev]"``. :param priority: The default task priority to assign to jobs run on Beaker. If ``scheduler`` is specified, this argument is ignored. :param scheduler: A :class:`BeakerScheduler` to use for assigning resources to steps. If not specified the :class:`SimpleBeakerScheduler` is used with the given ``clusters`` and ``priority``. :param allow_dirty: By default, the Beaker Executor requires that your git working directory has no uncommitted changes. If you set this to ``True``, we skip this check. :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() `. .. attention:: Certain parameters should not be included in the :data:`~tango.settings.TangoGlobalSettings.executor` part of your ``tango.yml`` file, namely ``workspace`` and ``include_package``.
Instead use the top-level :data:`~tango.settings.TangoGlobalSettings.workspace` and :data:`~tango.settings.TangoGlobalSettings.include_package` fields, respectively. :examples: **Minimal tango.yml file** You can use this executor by specifying it in your ``tango.yml`` settings file: .. code:: yaml executor: type: beaker beaker_workspace: ai2/my-workspace clusters: - ai2/general-cirrascale **Using GPUs** If you have a step that requires a GPU, there are two things you need to do: 1. First, you'll need to ensure that the :class:`BeakerExecutor` can install your dependencies the right way to support the GPU hardware. There are usually two ways to do this: use a Docker image that comes with a proper installation of your hardware-specific dependencies (e.g. PyTorch), or add a conda ``environment.yml`` file to your project that specifies the proper version of those dependencies. If you go with the first option you don't necessarily need to build your own Docker image. If PyTorch is the only hardware-specific dependency you have, you could just use one of AI2's pre-built PyTorch images. Just add these lines to your ``tango.yml`` file: .. code:: diff executor: type: beaker beaker_workspace: ai2/my-workspace + docker_image: ghcr.io/allenai/pytorch:1.12.0-cuda11.3-python3.9 + venv_name: base clusters: - ai2/general-cirrascale The ``venv_name: base`` line tells the :class:`BeakerExecutor` to use the existing conda environment called "base" on the image instead of creating a new one. Alternatively, you could use the :data:`default image ` and just add a conda ``environment.yml`` file to the root of your project that looks like this: .. code:: yaml name: torch-env channels: - pytorch dependencies: - python=3.9 - cudatoolkit=11.3 - numpy - pytorch - ... 2. And second, you'll need to specify the GPUs required by each step in the config for that step under the :class:`step_resources ` parameter. For example, ..
code:: json "steps": { "train": { "type": "torch::train", "step_resources": { "gpu_count": 1 } } } """ DEFAULT_BEAKER_IMAGE: str = "ai2/conda" """ The default image. Used if neither ``beaker_image`` nor ``docker_image`` are set. """ DEFAULT_NFS_DRIVE = "/net/nfs.cirrascale" RESOURCE_ASSIGNMENT_WARNING_INTERVAL = 60 * 5 def __init__( self, workspace: Workspace, clusters: Optional[List[str]] = None, include_package: Optional[Sequence[str]] = None, beaker_workspace: Optional[str] = None, github_token: Optional[str] = None, google_token: Optional[str] = None, beaker_image: Optional[str] = None, docker_image: Optional[str] = None, datasets: Optional[List[DataMount]] = None, env_vars: Optional[List[EnvVar]] = None, venv_name: Optional[str] = None, parallelism: Optional[int] = None, install_cmd: Optional[str] = None, priority: Optional[Union[str, Priority]] = None, allow_dirty: bool = False, scheduler: Optional[BeakerScheduler] = None, budget: Optional[str] = None, **kwargs, ): # Pre-validate arguments. if beaker_image is None and docker_image is None: beaker_image = self.DEFAULT_BEAKER_IMAGE elif (beaker_image is None) == (docker_image is None): raise ConfigurationError( "Either 'beaker_image' or 'docker_image' must be specified for BeakerExecutor, but not both." ) if budget is None: raise ConfigurationError("You must specify a budget to use the beaker executor.") else: self._budget = budget from tango.workspaces import LocalWorkspace, MemoryWorkspace if isinstance(workspace, MemoryWorkspace): raise ConfigurationError( "You cannot use the `MemoryWorkspace` with the `BeakerExecutor`! " "Please specify a different workspace." ) elif isinstance(workspace, LocalWorkspace): if str(workspace.dir).startswith(self.DEFAULT_NFS_DRIVE): # Mount the NFS drive if not mount already. 
datasets = datasets or [] if not datasets or not any( [ dm.source.host_path is not None and dm.source.host_path.startswith(self.DEFAULT_NFS_DRIVE) for dm in datasets ] ): nfs_mount = DataMount.new( self.DEFAULT_NFS_DRIVE, host_path=self.DEFAULT_NFS_DRIVE ) datasets.append(nfs_mount) else: warnings.warn( "It appears that you're using a `LocalWorkspace` on a directory that is not an NFS drive. " "If the `BeakerExecutor` cannot access this directory from Beaker, your results will be lost.", UserWarning, ) super().__init__(workspace, include_package=include_package, parallelism=parallelism) self.max_thread_workers = self.parallelism or min(32, (os.cpu_count() or 1) + 4) self.beaker = get_client(beaker_workspace=beaker_workspace, **kwargs) self.beaker_image = beaker_image self.docker_image = docker_image self.datasets = datasets self.env_vars = env_vars self.venv_name = venv_name self.install_cmd = install_cmd self.allow_dirty = allow_dirty self.scheduler: BeakerScheduler if scheduler is None: if clusters is None: raise ConfigurationError( "Either 'scheduler' or 'clusters' argument to BeakerExecutor is required" ) self.scheduler = SimpleBeakerScheduler(clusters, priority=priority or Priority.normal) else: if clusters is not None: warnings.warn( "The 'clusters' parameter will be ignored since you specified a 'scheduler'", UserWarning, ) if priority is not None: warnings.warn( "The 'priority' parameter will be ignored since you specified a 'scheduler'", UserWarning, ) self.scheduler = scheduler self.scheduler.beaker = self.beaker self._is_cancelled = threading.Event() self._logged_git_info = False self._last_resource_assignment_warning: Optional[float] = None self._jobs = 0 try: self.github_token: str = github_token or os.environ["GITHUB_TOKEN"] except KeyError: raise ConfigurationError( "A GitHub personal access token with the 'repo' scope is required. 
" "This can be set with the 'github_token' argument to the BeakerExecutor, " "or as the environment variable 'GITHUB_TOKEN'." ) self.google_token = google_token or os.environ.get("GOOGLE_TOKEN") # Check if google auth credentials are in the default location if self.google_token is None and os.path.exists(Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE): self.google_token = Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE # If credentials are provided in the form of a file path, load the credentials # so that they can be used in beaker. Do this only if required, i.e., only if GSWorkspace # is being used. if self.google_token is not None and self.google_token.endswith(".json"): from tango.integrations.gs import GSWorkspace if isinstance(workspace, GSWorkspace): with open(self.google_token) as f: self.google_token = f.read() if self.google_token is None: self.google_token = "default" # Ensure entrypoint dataset exists. self._ensure_entrypoint_dataset() # Get repo info and make sure we're in a GitHub-hosted repository. git = GitMetadata.check_for_repo() if ( git is None or git.commit is None or git.remote is None or "github.com" not in git.remote ): raise ExecutorError( f"Missing git data. " f"BeakerExecutor requires a git repository with a GitHub remote." ) self._github_account, self._github_repo = self._parse_git_remote(git.remote) self._git_commit = git.commit def check_repo_state(self): if not self.allow_dirty: # Make sure repository is clean, if we're in one. try: # Check for uncommitted changes. repo = Repo(".") if repo.is_dirty(): raise ExecutorError( "You have uncommitted changes! Commit your changes or use the 'allow_dirty' option." ) # Check for un-pushed commits. remote_name = repo.remote().name git = Git(".") if git.log([f"{remote_name}..HEAD", "--not", "--remotes", "--oneline"]): raise ExecutorError( "You have unpushed changes! Push your changes or use the 'allow_dirty' option." 
) except InvalidGitRepositoryError: raise ExecutorError( "It appears you're not in a valid git repository. " "The Beaker executor requires a git repository." ) except GitCommandError: pass def execute_step_graph( self, step_graph: StepGraph, run_name: Optional[str] = None ) -> ExecutorOutput: import concurrent.futures self.check_repo_state() self._is_cancelled.clear() # These will hold the final results which we'll update along the way. successful: Dict[str, ExecutionMetadata] = {} failed: Dict[str, ExecutionMetadata] = {} not_run: Dict[str, ExecutionMetadata] = {} # Keeps track of steps that are next up to run on Beaker. steps_to_run: Set[str] = set() # These are steps that have been submitted to Beaker but haven't completed yet. submitted_steps: Set[str] = set() # Futures for tracking the Beaker runs for each step. step_futures: List[concurrent.futures.Future] = [] uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps() # These are all of the steps that still need to be run at some point. steps_left_to_run = uncacheable_leaf_steps | { step for step in step_graph.values() if step.cache_results } def update_steps_to_run(): nonlocal steps_to_run, not_run for step_name, step in step_graph.items(): if ( step_name in submitted_steps or step_name in successful or step_name in failed or step_name in not_run ): # Make sure step is no longer in queue. steps_to_run.discard(step_name) # This does NOT raise KeyError if not found else: # Check dependencies. for dependency in step.dependencies: if dependency.name not in successful and dependency.cache_results: if dependency.name in failed or dependency.name in not_run: # A dependency failed or can't be run, so this step can't be run. not_run[step_name] = ExecutionMetadata() steps_to_run.discard(step_name) steps_left_to_run.discard(step) break else: # Dependencies are OK, so we can run this step now. 
if step.cache_results or step in uncacheable_leaf_steps: steps_to_run.add(step_name) def make_future_done_callback(step_name: str): def future_done_callback(future: concurrent.futures.Future): nonlocal successful, failed, steps_left_to_run self._jobs = max(0, self._jobs - 1) step = step_graph[step_name] try: exc = future.exception() if exc is None: successful[step_name] = ExecutionMetadata( result_location=None if not step.cache_results else self.workspace.step_info(step).result_location, logs_location=future.result(), ) steps_left_to_run.discard(step) elif isinstance(exc, ResourceAssignmentError): submitted_steps.discard(step_name) self._emit_resource_assignment_warning() elif isinstance(exc, StepFailedError): failed[step_name] = ExecutionMetadata(logs_location=exc.experiment_url) steps_left_to_run.discard(step) elif isinstance(exc, (ExecutorError, CancellationError)): failed[step_name] = ExecutionMetadata() steps_left_to_run.discard(step) else: log_exception(exc, logger) failed[step_name] = ExecutionMetadata() steps_left_to_run.discard(step) except concurrent.futures.TimeoutError as exc: log_exception(exc, logger) failed[step_name] = ExecutionMetadata() steps_left_to_run.discard(step) return future_done_callback last_progress_update = time.monotonic() def log_progress(): nonlocal last_progress_update now = time.monotonic() if now - last_progress_update >= 60 * 2: last_progress_update = now waiting_for = [ step_name for step_name in submitted_steps if step_name not in failed and step_name not in successful ] if len(waiting_for) > 5: logger.info( "Waiting for %d steps...", len(waiting_for), ) elif len(waiting_for) > 1: logger.info( "Waiting for %d steps (%s)...", len(waiting_for), "'" + "', '".join(waiting_for) + "'", ) elif len(waiting_for) == 1: logger.info("Waiting for 1 step ('%s')...", list(waiting_for)[0]) still_to_run = [ step.name for step in steps_left_to_run if step.name not in submitted_steps ] if len(still_to_run) > 5: logger.info( "Still waiting to 
submit %d more steps...", len(still_to_run), ) elif len(still_to_run) > 1: logger.info( "Still waiting to submit %d more steps (%s)...", len(still_to_run), "'" + "', '".join(still_to_run) + "'", ) elif len(still_to_run) == 1: logger.info("Still waiting to submit 1 more step ('%s')...", still_to_run[0]) update_steps_to_run() try: with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_workers) as pool: while steps_left_to_run: # Submit steps left to run. for step_name in steps_to_run: future = pool.submit( self._execute_sub_graph_for_step, step_graph, step_name, True ) future.add_done_callback(make_future_done_callback(step_name)) self._jobs += 1 step_futures.append(future) submitted_steps.add(step_name) if step_futures: # Wait for something to happen. _, not_done = concurrent.futures.wait( step_futures, return_when=concurrent.futures.FIRST_COMPLETED, timeout=2.0, ) # Update the list of running futures. step_futures.clear() step_futures = list(not_done) else: time.sleep(2.0) # Update the step queue. update_steps_to_run() log_progress() except (KeyboardInterrupt, CancellationError): if step_futures: cli_logger.warning("Received interrupt, canceling steps...") self._is_cancelled.set() concurrent.futures.wait(step_futures) raise finally: self._is_cancelled.clear() # NOTE: The 'done callback' added to each future is executed in a thread, # and so might not complete before the last 'update_steps_to_run()' is called # in the loop above. Therefore we have to call 'update_steps_to_run()' # one last time here to ensure the 'not_run' set is up-to-date. 
update_steps_to_run() return ExecutorOutput(successful=successful, failed=failed, not_run=not_run) def _emit_resource_assignment_warning(self): if self._last_resource_assignment_warning is None or ( time.monotonic() - self._last_resource_assignment_warning > self.RESOURCE_ASSIGNMENT_WARNING_INTERVAL ): self._last_resource_assignment_warning = time.monotonic() logger.warning( "Some steps can't be run yet - waiting on more Beaker resources " "to become available..." ) def _check_if_cancelled(self): if self._is_cancelled.is_set(): raise RunCancelled def _execute_sub_graph_for_step( self, step_graph: StepGraph, step_name: str, in_thread: bool = False, ) -> Optional[str]: if not in_thread: self._is_cancelled.clear() else: self._check_if_cancelled() step = step_graph[step_name] if step.cache_results and step in self.workspace.step_cache: cli_logger.info( '[green]\N{check mark} Found output for step [bold]"%s"[/] in cache...[/]', step_name, ) return None if step.resources.machine == "local": if step.cache_results: step.ensure_result(self.workspace) else: result = step.result(self.workspace) if hasattr(result, "__next__"): from collections import deque deque(result, maxlen=0) return None experiment: Optional[Experiment] = None experiment_url: Optional[str] = None ephemeral_datasets: List[Dataset] = [] # Try to find any existing experiments for this step that are still running. 
if step.cache_results: for exp in self.beaker.workspace.experiments( match=f"{Constants.STEP_EXPERIMENT_PREFIX}{step.unique_id}-" ): self._check_if_cancelled() try: latest_job = self.beaker.experiment.latest_job(exp) except (ValueError, ExperimentNotFound): continue if latest_job is not None and not latest_job.is_done: experiment = exp experiment_url = self.beaker.experiment.url(exp) cli_logger.info( "[blue]\N{black rightwards arrow} Found existing Beaker experiment [b]%s[/] for " 'step [b]"%s"[/] that is still running...[/]', experiment_url, step_name, ) break # Otherwise we submit a new experiment... if experiment is None: # Initialize experiment and task spec. experiment_name, spec, ephemeral_datasets = self._build_experiment_spec( step_graph, step_name ) self._check_if_cancelled() step.log_starting() # Create experiment. experiment = self.beaker.experiment.create(experiment_name, spec) experiment_url = self.beaker.experiment.url(experiment) cli_logger.info( '[blue]\N{black rightwards arrow} Submitted Beaker experiment [b]%s[/] for step [b]"%s"[/]...[/]', experiment_url, step_name, ) assert experiment is not None assert experiment_url is not None # Follow the experiment until it completes. try: while True: poll_interval = min(60, 5 * min(self._jobs, self.max_thread_workers)) try: self._check_if_cancelled() self.beaker.experiment.wait_for( experiment, strict=True, quiet=True, timeout=poll_interval + 2, poll_interval=poll_interval, ) break except JobTimeoutError: time.sleep(poll_interval) continue except (JobFailedError, TaskStoppedError): cli_logger.error( '[red]\N{ballot x} Step [b]"%s"[/] failed. You can check the logs at [b]%s[/][/]', step_name, experiment_url, ) raise StepFailedError( f'Beaker job for step "{step_name}" failed. 
' f"You can check the logs at {experiment_url}", experiment_url, ) except (KeyboardInterrupt, CancellationError): cli_logger.warning( 'Stopping Beaker experiment [cyan]%s[/] for step [b]"%s"[/] (%s)', experiment_url, step_name, step.unique_id, ) self.beaker.experiment.stop(experiment) raise else: step.log_finished() finally: # Remove ephemeral datasets. result_dataset = self.beaker.experiment.results(experiment) if result_dataset is not None: ephemeral_datasets.append(result_dataset) for dataset in ephemeral_datasets: try: self.beaker.dataset.delete(dataset) except DatasetNotFound: pass return experiment_url @staticmethod def _parse_git_remote(url: str) -> Tuple[str, str]: """ Parse a git remote URL into a GitHub (account, repo) pair. """ account, repo = ( url.split("https://github.com/")[-1] .split("git@github.com:")[-1] .split(".git")[0] .split("/") ) return account, repo def _ensure_entrypoint_dataset(self) -> Dataset: import hashlib from importlib.resources import read_binary import tango.integrations.beaker workspace_id = self.beaker.workspace.get().id # Get hash of the local entrypoint source file. sha256_hash = hashlib.sha256() contents = read_binary(tango.integrations.beaker, Constants.ENTRYPOINT_FILENAME) sha256_hash.update(contents) entrypoint_dataset_name = ( f"{Constants.ENTRYPOINT_DATASET_PREFIX}{workspace_id}-{sha256_hash.hexdigest()[:6]}" ) tmp_entrypoint_dataset_name = ( f"{Constants.ENTRYPOINT_DATASET_PREFIX}{str(uuid.uuid4())}-tmp" ) # Ensure entrypoint dataset exists. entrypoint_dataset: Dataset try: entrypoint_dataset = self.beaker.dataset.get(entrypoint_dataset_name) except DatasetNotFound: # Create it. 
def _ensure_step_graph_dataset(self, step_graph: StepGraph) -> Dataset:
    """
    Upload the given step graph to a fresh, uniquely named Beaker dataset.

    The graph config (with unique IDs included) is serialized as JSON under the
    ``"steps"`` key, written to ``Constants.STEP_GRAPH_FILENAME`` and the dataset is
    then committed.

    :param step_graph: The step graph to serialize and upload.
    :returns: The Beaker dataset that holds the serialized step graph.
    """
    dataset_name = f"{Constants.STEP_GRAPH_ARTIFACT_PREFIX}{str(uuid.uuid4())}"
    try:
        dataset = self.beaker.dataset.create(dataset_name, quiet=True, commit=False)
        serialized_graph = json.dumps(
            {"steps": step_graph.to_config(include_unique_id=True)}
        ).encode()
        self.beaker.dataset.upload(
            dataset, serialized_graph, Constants.STEP_GRAPH_FILENAME, quiet=True
        )
        self.beaker.dataset.commit(dataset)
    except DatasetConflict:
        # Could be in a race with another `tango` process: wait briefly, then pick
        # up the dataset that already exists under this name.
        time.sleep(1.0)
        dataset = self.beaker.dataset.get(dataset_name)
    return dataset
command = [ "tango", "--log-level", "debug", "--called-by-executor", "beaker-executor-run", Constants.INPUT_DIR + "/" + Constants.STEP_GRAPH_FILENAME, step.name, self.workspace.url, ] if self.include_package is not None: for package in self.include_package: command += ["-i", package, "--log-level", TANGO_LOG_LEVEL or "debug"] self._check_if_cancelled() # Ignore the patch version. # E.g. '3.9.7' -> '3.9' python_version = step_info.environment.python python_version = python_version[: python_version.find(".", python_version.find(".") + 1)] # Build task spec. task_spec = ( TaskSpec.new( step.unique_id, beaker_image=self.beaker_image, docker_image=self.docker_image, result_path=Constants.RESULTS_DIR, command=["bash", Constants.ENTRYPOINT_DIR + "/" + Constants.ENTRYPOINT_FILENAME], arguments=command, resources=task_resources, datasets=self.datasets, env_vars=self.env_vars, priority=priority, ) .with_constraint(cluster=[clusters] if isinstance(clusters, str) else clusters) .with_env_var(name="TANGO_VERSION", value=VERSION) .with_env_var(name="GITHUB_TOKEN", secret=Constants.GITHUB_TOKEN_SECRET_NAME) .with_env_var(name="BEAKER_TOKEN", secret=Constants.BEAKER_TOKEN_SECRET_NAME) .with_env_var(name="GOOGLE_TOKEN", secret=Constants.GOOGLE_TOKEN_SECRET_NAME) .with_env_var(name="GITHUB_REPO", value=f"{github_account}/{github_repo}") .with_env_var(name="GIT_REF", value=git_ref) .with_env_var(name="PYTHON_VERSION", value=python_version) .with_env_var(name="BEAKER_EXPERIMENT_NAME", value=experiment_name) .with_dataset(Constants.ENTRYPOINT_DIR, beaker=entrypoint_dataset.id) .with_dataset(Constants.INPUT_DIR, beaker=step_graph_dataset.id) ) if self.venv_name is not None: task_spec = task_spec.with_env_var(name="VENV_NAME", value=self.venv_name) if self.install_cmd is not None: task_spec = task_spec.with_env_var(name="INSTALL_CMD", value=self.install_cmd) return ( experiment_name, ExperimentSpec( tasks=[task_spec], description=f'Tango step "{step_name}" ({step.unique_id})', 
@StepCache.register("beaker")
class BeakerStepCache(RemoteStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`BeakerWorkspace`.
    It stores the results of steps on Beaker as datasets.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
    step's result subsequent times should be fast.

    .. tip::
        Registered as a :class:`~tango.step_cache.StepCache` under the name "beaker".

    :param beaker_workspace: The name or ID of the Beaker workspace to use.
    :param beaker: The Beaker client to use.
    """

    Constants = Constants

    def __init__(self, beaker_workspace: Optional[str] = None, beaker: Optional[Beaker] = None):
        self.beaker: Beaker
        if beaker is None:
            # No client given: build one that targets the requested workspace.
            self.beaker = get_client(beaker_workspace=beaker_workspace)
        else:
            self.beaker = beaker
            if beaker_workspace is not None:
                # Point the provided client at the requested workspace.
                self.beaker.config.default_workspace = beaker_workspace
                self.beaker.workspace.ensure(beaker_workspace)

        default_workspace = self.beaker.config.default_workspace
        if default_workspace is None:
            raise ConfigurationError("Beaker default workspace must be set")

        # The local copy of the cache lives in a per-workspace directory.
        local_cache_dir = (
            tango_cache_dir() / "beaker_cache" / make_safe_filename(default_workspace)
        )
        super().__init__(local_cache_dir)

    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[BeakerDataset]:
        """
        Returns the `BeakerDataset` that holds the step's result, or ``None`` when the
        dataset does not exist or has not been finalized (committed) yet.
        """
        artifact_name = self.Constants.step_artifact_name(step)
        try:
            dataset = self.beaker.dataset.get(artifact_name)
        except DatasetNotFound:
            return None
        if dataset.committed is None:
            return None
        return dataset

    def _upload_step_remote(self, step: Step, objects_dir: Path) -> BeakerDataset:
        """
        Uploads the contents of ``objects_dir`` as the step's result dataset, commits it,
        and returns the dataset.
        """
        artifact_name = self.Constants.step_artifact_name(step)

        # The dataset may already exist; that is not an error.
        try:
            self.beaker.dataset.create(artifact_name, commit=False)
        except DatasetConflict:
            pass

        # A write error is deliberately ignored here; the dataset is returned as it is.
        try:
            self.beaker.dataset.sync(artifact_name, objects_dir, quiet=True)
            self.beaker.dataset.commit(artifact_name)
        except DatasetWriteError:
            pass

        return self.beaker.dataset.get(artifact_name)

    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """
        Downloads the step's result dataset into ``target_dir``.

        :raises RemoteNotFoundError: If the dataset cannot be found on Beaker.
        """
        try:
            self.beaker.dataset.fetch(step_result, target_dir, quiet=True)
        except DatasetNotFound:
            raise RemoteNotFoundError()

    def __len__(self):
        """
        Returns the number of committed step result datasets in the Beaker workspace.
        """
        prefix = self.Constants.STEP_ARTIFACT_PREFIX
        count = 0
        for ds in self.beaker.workspace.iter_datasets(
            match=prefix, uncommitted=False, results=False
        ):
            if ds.name is None or not ds.name.startswith(prefix):
                continue
            # NOTE: lock datasets should not count here.
            if ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX):
                continue
            count += 1
        return count
@classmethod
def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
    """
    Build a workspace from an already-parsed ``beaker://`` URL.

    Accepted forms are ``beaker://org/my-workspace`` and ``beaker://my-workspace``.
    A trailing slash is ignored.

    :param parsed_url: The parsed workspace URL.
    :returns: A workspace instance for the Beaker workspace named by the URL.
    :raises ValueError: If the URL has no network location component.
    """
    if not parsed_url.netloc:
        raise ValueError(f"Bad URL for Beaker workspace '{parsed_url}'")
    # A trailing "/" (e.g. "beaker://ai2/my-workspace/") is not part of the workspace
    # name, so strip it instead of passing "ai2/my-workspace/" to the constructor.
    # With an empty path this reduces to the bare "beaker://my-workspace" form.
    workspace: str = parsed_url.netloc + parsed_url.path.rstrip("/")
    return cls(workspace)
def _get_object_from_cache(self, digest: Digest, o_type: Type[U]) -> Optional[U]:
    """
    Look up a cached object by digest, first in memory and then on disk.

    :param digest: The digest that identifies the cached object.
    :param o_type: The expected type of the object. It is also used to deserialize
        the on-disk JSON via ``o_type.from_json_dict()``.
    :returns: The cached object, or ``None`` if it is not cached, if the in-memory
        entry has a different type, or if the on-disk entry could not be loaded.
    """
    if digest in self._mem_cache:
        # Mark the entry as most recently used.
        self._mem_cache.move_to_end(digest)
        cached = self._mem_cache[digest]
        return cached if isinstance(cached, o_type) else None

    cache_path = self._disk_cache_dir / make_safe_filename(str(digest))
    if not cache_path.is_file():
        return None

    try:
        # The file is only read, so open it read-only. Opening it for update would
        # fail on a cache that is readable but not writable.
        with cache_path.open("rt") as f:
            cached = o_type.from_json_dict(json.load(f))
    except Exception as exc:
        logger.warning("Error while loading object from workspace cache: %s", str(exc))
        # Drop the unreadable entry so it gets re-created on the next write.
        try:
            os.remove(cache_path)
        except FileNotFoundError:
            pass
        return None

    # Add to in-memory cache, evicting the least recently used entries if needed.
    self._mem_cache[digest] = cached
    while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:
        self._mem_cache.popitem(last=False)
    return cached  # type: ignore
def registered_runs(self) -> Dict[str, Run]:
    """
    Return all runs registered in the Beaker workspace, keyed by run name.

    Every dataset matching the run artifact prefix (committed or not) is converted
    to a :class:`Run` on a thread pool. Datasets that do not yield a run are skipped.
    """
    import concurrent.futures

    found: Dict[str, Run] = {}
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=self.NUM_CONCURRENT_WORKERS,
        thread_name_prefix="BeakerWorkspace.registered_runs()-",
    ) as executor:
        # Submit one lookup per candidate dataset.
        pending = [
            executor.submit(self._get_run_from_dataset, dataset)
            for dataset in self.beaker.workspace.iter_datasets(
                match=self.Constants.RUN_ARTIFACT_PREFIX, uncommitted=True, results=False
            )
        ]
        # Collect results in completion order.
        for finished in concurrent.futures.as_completed(pending):
            run = finished.result()
            if run is None:
                continue
            found[run.name] = run
    return found
def search_step_info(
    self,
    *,
    sort_by: Optional[StepInfoSort] = None,
    sort_descending: bool = True,
    match: Optional[str] = None,
    state: Optional[StepState] = None,
    start: int = 0,
    stop: Optional[int] = None,
) -> List[StepInfo]:
    """
    Search the step datasets in the Beaker workspace and return their step info.

    :param sort_by: Sort by start time (the default) or by unique ID.
    :param sort_descending: Whether to sort in descending order.
    :param match: Text appended to the step artifact prefix to narrow the search.
    :param state: Not supported; must be ``None``.
    :param start: Offset of the first dataset to consider.
    :param stop: Offset one past the last dataset to consider, or ``None`` for no limit.
    :raises NotImplementedError: If ``state`` is given or ``sort_by`` is unsupported.
    """
    if state is not None:
        raise NotImplementedError(
            f"{self.__class__.__name__} cannot filter steps efficiently by state"
        )

    prefix = Constants.STEP_ARTIFACT_PREFIX
    name_filter = prefix if match is None else prefix + match

    sort: Optional[DatasetSort]
    if sort_by is None or sort_by == StepInfoSort.START_TIME:
        sort = DatasetSort.created
    elif sort_by == StepInfoSort.UNIQUE_ID:
        sort = DatasetSort.dataset_name
    else:
        raise NotImplementedError

    offset = start or 0
    page_size = None if stop is None else stop - offset

    found: List[StepInfo] = []
    for dataset in self.beaker.workspace.iter_datasets(
        match=name_filter,
        results=False,
        cursor=offset,
        limit=page_size,
        sort_by=sort or DatasetSort.created,
        descending=sort_descending,
    ):
        # Datasets whose step info cannot be read are skipped.
        try:
            info = self._get_step_info_from_dataset(dataset)
        except (DatasetNotFound, FileNotFoundError):
            continue
        found.append(info)
    return found
def registered_run(self, name: str) -> Run:
    """
    Return the run registered under ``name``.

    :param name: The name of the run.
    :raises KeyError: If there is no dataset for the run, if the dataset belongs to
        a different Beaker workspace, or if no run could be read from it.
    """
    not_found_msg = f"Run '{name}' not found in workspace"
    try:
        run_dataset = self.beaker.dataset.get(self.Constants.run_artifact_name(name))
        # Make sure the run is in our workspace.
        our_workspace_id = self.beaker.workspace.get().id
        if run_dataset.workspace_ref.id != our_workspace_id:  # type: ignore # TODO
            raise DatasetNotFound
    except DatasetNotFound:
        raise KeyError(not_found_msg)

    run = self._get_run_from_dataset(run_dataset)
    if run is not None:
        return run
    raise KeyError(not_found_msg)
def _remove_step_info(self, step_info: StepInfo) -> None:
    """
    Delete the Beaker dataset that backs the given step, if it exists.

    A dataset that is already gone is treated as successfully removed.

    :param step_info: The step info whose dataset should be deleted.
    """
    dataset_name = self.Constants.step_artifact_name(step_info)
    try:
        # A missing dataset is reported via `DatasetNotFound` rather than a `None`
        # return value, so the not-found case has to be handled as an exception.
        step_dataset = self.beaker.dataset.get(dataset_name)
        if step_dataset is not None:
            self.beaker.dataset.delete(step_dataset)
    except DatasetNotFound:
        # Nothing to remove, e.g. another process deleted it first.
        return
@overload
def convert_to_tango_dataset_dict(hf_dataset_dict: ds.DatasetDict) -> DatasetDict:
    ...


@overload
def convert_to_tango_dataset_dict(hf_dataset_dict: ds.IterableDatasetDict) -> IterableDatasetDict:  # type: ignore
    ...


def convert_to_tango_dataset_dict(hf_dataset_dict):
    """
    Convert a HuggingFace :class:`~datasets.DatasetDict` or
    :class:`~datasets.IterableDatasetDict` into the corresponding native Tango
    :class:`~tango.common.DatasetDict` or :class:`~tango.common.IterableDatasetDict`.

    This is important to do when your dataset dict is input to another step for
    caching reasons.

    :param hf_dataset_dict: The HuggingFace dataset dict to convert.
    :returns: A Tango dataset dict whose ``splits`` is the given HuggingFace object.
    """
    # Streaming dataset dicts map to the iterable Tango variant, everything else to
    # the regular one.
    is_streaming = isinstance(hf_dataset_dict, ds.IterableDatasetDict)
    tango_type = IterableDatasetDict if is_streaming else DatasetDict
    return tango_type(splits=hf_dataset_dict)
@Step.register("datasets::load")
class LoadDataset(Step):
    """
    This step loads a (non-streaming) HuggingFace dataset.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::load".

    .. important::
        If you are loading an :class:`~datasets.IterableDataset` or
        :class:`~datasets.IterableDatasetDict` you need to use the
        :class:`LoadStreamingDataset` step instead.
    """

    DETERMINISTIC = True
    VERSION = "001"
    # Even though HuggingFace datasets has its own caching mechanism, it can still be
    # worth caching this step with tango's mechanism since some datasets take a really
    # long time to query from HuggingFace ("bigscience/P3", for example). Tango's
    # caching mechanism circumvents that issue.
    CACHEABLE = True
    FORMAT = DatasetsFormat()

    def run(self, path: str, **kwargs) -> Union[ds.DatasetDict, ds.Dataset]:  # type: ignore
        """
        Load the HuggingFace dataset specified by ``path``.

        ``path`` is the canonical name or path to the dataset. Additional key word
        arguments are passed as-is to :func:`datasets.load_dataset()`.

        :raises ConfigurationError: If the loaded object is not a
            :class:`~datasets.Dataset` or :class:`~datasets.DatasetDict`.
        """
        loaded = ds.load_dataset(path, **kwargs)
        if isinstance(loaded, (ds.Dataset, ds.DatasetDict)):
            return loaded
        raise ConfigurationError(
            f"{self.__class__.__name__} can only be used with non-streaming datasets. "
            f"For streaming datasets, use the 'LoadStreamingDataset' ('datasets::load_streaming') step instead."
        )
@Step.register("datasets::interleave")
class InterleaveDatasets(Step):
    """
    This step interleaves multiple datasets using :func:`~datasets.interleave_datasets()`.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::interleave".
    """

    DETERMINISTIC = True
    VERSION = "001"
    CACHEABLE = False  # Not worth caching

    def run(  # type: ignore[override]
        self,
        datasets: List[DatasetType],
        probabilities: Optional[List[float]] = None,
        seed: Optional[int] = None,
    ) -> DatasetType:
        """
        Interleave the list of datasets.

        :param datasets: The datasets to interleave.
        :param probabilities: Passed through as ``probabilities`` to
            :func:`~datasets.interleave_datasets()`.
        :param seed: Passed through as ``seed`` to
            :func:`~datasets.interleave_datasets()`.
        :returns: The interleaved dataset.
        """
        interleave_options = {"probabilities": probabilities, "seed": seed}
        return ds.interleave_datasets(datasets, **interleave_options)
@Step.register("datasets::dataset_remix")
class DatasetRemixStep(Step):
    """
    This step can remix splits in a :class:`~datasets.DatasetDict` into new splits.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "datasets::dataset_remix".

    Examples
    --------

    .. testcode::
        :hide:

        from tango.common.logging import initialize_logging
        initialize_logging(enable_cli_logs=True)
        import datasets

    .. testcode::

        input = datasets.load_dataset("lhoestq/test")
        new_splits = {
            "all": "train + validation",
            "crossval_train": "train[:1] + validation[1:]",
            "crossval_test": "train[1:] + validation[:1]",
        }
        step = DatasetRemixStep()
        remixed_dataset = step.run(input=input, new_splits=new_splits)

    .. testoutput::
        :hide:
        :options: +ELLIPSIS

        ...

    """

    DETERMINISTIC = True
    CACHEABLE = True
    VERSION = "001"

    def run(  # type: ignore
        self,
        input: ds.DatasetDict,
        new_splits: Dict[str, str],
        keep_old_splits: bool = True,
        shuffle_before: bool = False,
        shuffle_after: bool = False,
        random_seed: int = 1532637578,
    ) -> ds.DatasetDict:
        """
        Remixes and shuffles a dataset. This is done eagerly with native 🤗 Datasets features.

        :param input:
            The input dataset that will be remixed.
        :param new_splits:
            Specifies the new splits that the output dataset should have. Keys are the name of the
            new splits. Values refer to the original splits. You can refer to original splits in
            the following ways:

            * Mention the original split name to copy it to a new name.
            * Mention the original split name with Python's slicing syntax to select part of the
              original split's instances. For example, ``"train[:1000]"`` selects the first 1000
              instances from the ``"train"`` split.
            * ``"instances + instances"`` concatenates the instances into one split.

            You can combine these possibilities.
        :param keep_old_splits:
            Whether to keep the splits from the input dataset in addition to the new ones given
            by ``new_splits``.
        :param shuffle_before:
            Whether to shuffle the input splits before creating the new ones. If you need
            shuffled instances and you're not sure the input is properly shuffled, use this.
        :param shuffle_after:
            Whether to shuffle the resulting splits after creating the new ones. If you need
            shuffled instances and you're slicing or concatenating splits, use this. If you want
            to be on the safe side, shuffle both before and after.
        :param random_seed:
            Random seed, affects shuffling

        :returns:
            Returns a new dataset that is appropriately remixed.
        """
        if shuffle_before:
            input = input.shuffle(random_seed)

        def select_instances(reference: str) -> ds.Dataset:
            # Resolve one reference such as "train" or "train[10:20]" to a dataset.
            sliced = re.match(r"(.*)\[(-?[0-9]*:-?[0-9]*)\]", reference)
            if sliced is None:
                return input[reference]
            source_split = input[sliced[1]]
            # An empty bound (as in "[:10]") means "open ended".
            bounds = [int(bound) if bound else None for bound in sliced[2].split(":")]
            first, last, stride = slice(*bounds).indices(len(source_split))
            return source_split.select(range(first, last, stride))

        def build_split(spec: str):
            # A spec is one or more references joined by "+".
            pieces = [select_instances(reference.strip()) for reference in spec.split("+")]
            if len(pieces) == 1:
                return pieces[0]
            return ds.concatenate_datasets(pieces)

        remixed = ds.DatasetDict(input.items()) if keep_old_splits else ds.DatasetDict()
        created = {}
        for split_name, spec in new_splits.items():
            created[split_name] = build_split(spec)
        remixed.update(created)

        if shuffle_after:
            remixed = remixed.shuffle(random_seed)
        return remixed
Check `pytorch.org/get-started/locally/ `_ for more details. Components for Tango integration with `FairScale `_. Overview -------- FairScale is a PyTorch library for large scale training. Among other things, it implements the main memory-savings techniques for distributed data-parallel training (DDP) that came from the paper `ZeRO: Memory Optimization Towards Training A Trillion Parameter Models `_. The main part of this Tango integration is the :class:`FairScaleTrainingEngine`. This is a :class:`~tango.integrations.torch.TrainingEngine` implementation that utilizes FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` (FSDP) for substantial memory savings during distributed training. For the best performance you should also use :func:`with_wrapped_modules()` to wrap the inner modules of your :class:`~tango.integrations.torch.Model`. When used with FSDP this will dramatically reduce the memory required to load your model. """ from tango.common.exceptions import IntegrationMissingError try: import fairscale except ModuleNotFoundError: raise IntegrationMissingError("fairscale") __all__ = [ "FairScaleTrainingEngine", "FSDPConfig", "with_wrapped_modules", ] from .fsdp_config import FSDPConfig from .module_wrapper import with_wrapped_modules from .training_engine import FairScaleTrainingEngine ================================================ FILE: tango/integrations/fairscale/fsdp_config.py ================================================ from dataclasses import asdict, dataclass from typing import Any, Dict, Optional import torch from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP from tango.common import FromParams @dataclass class FSDPConfig(FromParams): """ Defines all of the configurable options for FairScale's :class:`~fairscale.nn.FullyShardedDataParallel`. .. seealso:: `Best practices for FullyShardedDataParallel `_ from the FairScale docs. 
""" # noqa: E501 reshard_after_forward: bool = True """ See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. """ move_params_to_cpu: bool = False """ See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. """ move_grads_to_cpu: Optional[bool] = None """ See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. .. seealso:: :data:`move_params_to_cpu` .. warning:: At the moment we recommend that you don't mess with this parameter, or only explicitly set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None`` (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale. Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``. We're tracking `fairscale#918 `_, which may be related. """ mixed_precision: bool = False """ See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`. .. important:: We recommend setting this to the same value as the ``amp`` parameter in :class:`FairScaleTrainingEngine`. Based on our experiments, if you're training with AMP enabled (``amp=True``) you might see a small additional speedup in training time along with a small additional decrease in GPU memory utilization without any performance penalty (with respect to convergence) by setting this to ``True``. But if you're *not* training with AMP, setting this ``True`` could impact the model's ability to converge. """ def as_kwargs(self) -> Dict[str, Any]: """ Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`. """ return asdict(self) def wrap(self, module: torch.nn.Module): """ A convenience method for wrapping a module in :class:`~fairscale.nn.FullyShardedDataParallel` with all of the options defined in this class. .. seealso:: Internally this is what :func:`with_wrapped_modules()` calls. 
@Model.register("fairscale::with_wrapped_modules")  # type: ignore[arg-type]
def with_wrapped_modules(
    model: Model,
    modules_to_wrap: Set[str],
    fsdp_config: Optional[FSDPConfig] = None,
    activation_checkpointing: bool = False,
) -> Model:
    """
    A :class:`~tango.integrations.torch.Model` wrapper that can be used to easily wrap
    inner modules of a model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` wrapper
    and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.

    .. tip::
        Registered as a :class:`~tango.integrations.torch.Model` constructor under the name
        "fairscale::with_wrapped_modules".

    .. important::
        This is meant to be used with the :class:`FairScaleTrainingEngine`.

    :param model:
        The model to wrap. It is modified in place and returned.
    :param modules_to_wrap:
        The names of submodules to wrap. Can be regular expressions; each pattern must
        fully match the dotted name of at least one submodule.
    :param fsdp_config:
        The ``FullyShardedDataParallel`` configuration to use when wrapping the modules.
        If not specified, the modules will NOT be wrapped with FSDP. FSDP wrapping is also
        skipped when ``torch.distributed`` has not been initialized.
    :param activation_checkpointing:
        Whether to wrap the modules with FairScale's
        :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.

    :returns:
        The same ``model`` instance, with the matching submodules replaced by their wrapped versions.

    :raises ValueError:
        If one or more patterns in ``modules_to_wrap`` do not match any submodule name.

    Examples
    --------

    You can use this as a :class:`~tango.integrations.torch.Model` constructor from a config/params
    like this:

    .. testcode::

        import torch.nn as nn
        from tango.integrations.torch import Model


        class FeedForward(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(4, 4)
                self.activation = nn.ReLU()

            def forward(self, x):
                return self.activation(self.linear(x))


        @Model.register("simple_regression_model")
        class SimpleRegressionModel(Model):
            def __init__(self):
                super().__init__()
                self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])
                self.regression_head = nn.Linear(4, 1)
                self.loss_fcn = nn.MSELoss()

            def forward(self, x, y):
                output = self.blocks(x)
                output = self.regression_head(output)
                loss = self.loss_fcn(output, y)
                return {"loss": loss}


        model = Model.from_params({
            "type": "fairscale::with_wrapped_modules",
            "model": {
                "type": "simple_regression_model",
            },
            "modules_to_wrap": [r"blocks\\.[0-9]+", "regression_head"],
            "activation_checkpointing": True,
        })

    """

    def wrap_module(
        module: nn.Module,
    ) -> nn.Module:
        # Activation checkpointing is applied first so that FSDP (if used) is the outermost wrapper.
        if activation_checkpointing:
            module = checkpoint_wrapper(module, offload_to_cpu=True)
        if fsdp_config is not None and torch.distributed.is_initialized():
            module = fsdp_config.wrap(module)
        return module

    # The root module is reported with an empty name; it is never a wrapping candidate.
    all_module_names: Set[str] = {name for name, _ in model.named_modules() if name}
    actual_modules_to_wrap: Set[str] = set()
    unmatched_patterns: Set[str] = set(modules_to_wrap)
    for module_name in all_module_names:
        for pattern in modules_to_wrap:
            if re.fullmatch(pattern, module_name):
                actual_modules_to_wrap.add(module_name)
                unmatched_patterns.discard(pattern)

    if unmatched_patterns:
        raise ValueError(
            f"Some patterns in 'modules_to_wrap' did not match actual module names ({unmatched_patterns})"
        )

    # Wrap the most deeply nested modules first. If a module and one of its descendants both
    # match, the descendant must be replaced inside the *unwrapped* parent; otherwise it would
    # be re-attached to the wrapper. Sorting also makes the order deterministic, whereas plain
    # set iteration order can differ from one worker process to another.
    wrap_order = sorted(actual_modules_to_wrap, key=lambda name: (-name.count("."), name))
    for qualified_name in wrap_order:
        parent_path, _, child_name = qualified_name.rpartition(".")
        parent_module = model.get_submodule(parent_path) if parent_path else model
        wrapped = wrap_module(parent_module.get_submodule(child_name))
        parent_module.add_module(child_name, wrapped)

    return model
when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`. For maximum memory savings, we recommend training with AMP enabled and the following :class:`FSDPConfig`: .. testcode:: from tango.integrations.fairscale import FSDPConfig fsdp_config = FSDPConfig( reshard_after_forward=True, move_params_to_cpu=True, move_grads_to_cpu=True, mixed_precision=True, ) For maximum training *speed*, we recommend training with AMP enabled and the following :class:`FSDPConfig`: .. testcode:: from tango.integrations.fairscale import FSDPConfig fsdp_config = FSDPConfig( reshard_after_forward=False, move_params_to_cpu=False, move_grads_to_cpu=False, mixed_precision=True, ) :param amp: Use automatic mixed precision (AMP). Default is ``False``. :param max_grad_norm: If set, gradients will be clipped to have this max norm. Default is ``None``. :param amp_use_bfloat16: Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training. Only applicable when ``amp=True``. If not specified, the default behavior will be to use ``bfloat16`` when training with AMP on CPU, otherwise not. :param fsdp_config: The options for :class:`~fairscale.nn.FullyShardedDataParallel`. If not specified, the default options will be used. 
def save_complete_weights_from_checkpoint(
    self, checkpoint_dir: Path, weights_path: Path
) -> None:
    """
    Consolidate the per-worker model shards of a checkpoint into a single state dict
    and save it.

    Every file in ``checkpoint_dir`` matching ``worker*_model.pt`` is loaded onto the CPU.
    Each file is expected to hold a dictionary with the keys ``"weights"`` and ``"metadata"``,
    the layout produced by :meth:`get_model_state()`. The shards are then merged with
    :meth:`~fairscale.nn.FullyShardedDataParallel.consolidate_shard_weights`.

    :param checkpoint_dir: The directory containing the sharded ``worker*_model.pt`` files.
    :param weights_path: Where to write the consolidated weights (via :func:`torch.save`).
    """
    # Local import: this module does not import ``re`` at the top level.
    import re

    def shard_sort_key(path: Path):
        # Shards are combined positionally, so they have to be read in worker-rank order.
        # ``glob()`` gives no ordering guarantee, and a plain lexical sort would put
        # "worker10" before "worker2", hence the numeric key.
        # NOTE(review): assumes the wildcard part of the file name is the integer worker
        # rank; files that don't fit that shape are kept, ordered by name, after the rest.
        match = re.fullmatch(r"worker(\d+)_model\.pt", path.name)
        if match is None:
            return (1, 0, path.name)
        return (0, int(match[1]), path.name)

    self.logger.info("Consolidating sharded checkpoint weights...")
    sharded_weights: List[Dict[str, torch.Tensor]] = []
    sharded_metadata: List[Dict[str, Any]] = []
    shard_paths = sorted(checkpoint_dir.resolve().glob("worker*_model.pt"), key=shard_sort_key)
    for path in shard_paths:
        sharded_state = torch.load(path, map_location="cpu")
        sharded_weights.append(sharded_state["weights"])
        sharded_metadata.append(sharded_state["metadata"])
    full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata)
    # Drop the shard copies before serializing to keep peak memory down.
    del sharded_weights
    del sharded_metadata
    torch.save(full_state, weights_path)
def __call__(self, rng: jax._src.random.KeyArrayLike, do_distributed: bool):
    """
    Generate the batches for one pass over the dataset.

    :param rng: The PRNG key used to permute the example indices when shuffling is enabled.
    :param do_distributed: If ``True``, each batch is passed through ``shard()`` before
        being yielded.

    Yields ``dataset_size // batch_size`` batches, each obtained by indexing the dataset
    with an array of ``batch_size`` row indices. Trailing examples that don't fill a
    complete batch are left out.
    """
    num_batches = self.dataset_size // self.batch_size

    if not self.shuffle:
        order = np.arange(self.dataset_size)
    else:
        # using jax arrays for indexing is a bottleneck on TPUs.
        order = np.asarray(jax.random.permutation(rng, self.dataset_size))

    self.logger.info("Skipping last incomplete batch")
    usable = num_batches * self.batch_size  # Skip incomplete batch.
    index_groups = order[:usable].reshape((num_batches, self.batch_size))

    for indices in index_groups:
        batch = self.dataset[indices]
        yield shard(batch) if do_distributed else batch
def run(  # type: ignore[override]
    self,
    state: TrainState,
    dataset: DatasetDictBase,
    dataloader: Lazy[FlaxDataLoader],
    wrapper: FlaxWrapper,
    test_split: str = "test",
    seed: int = 42,
    log_every: int = 1,
    do_distributed: bool = False,
    eval_steps: Optional[int] = None,
    metric_names: Sequence[str] = ("loss",),
    auto_aggregate_metrics: bool = True,
    callbacks: Optional[List[Lazy[EvalCallback]]] = None,
) -> Dict[str, float]:
    """
    Evaluate the ``model``.

    :param state:
        The state of the model to evaluate. This contains the parameters.
    :param dataset:
        Should contain the test data.
    :param dataloader:
        The data loader that generates test batches. The batches should be :class:`dict`
        objects.
    :param wrapper:
        The wrapper should define :meth:`eval_metrics`.
    :param test_split:
        The name of the data split used for evaluation in the ``dataset_dict``.
        Default is "test".
    :param seed:
        Used to set the PRNG states at the beginning of the evaluation loop.
    :param log_every:
        Log every this many steps. Default is ``1``.
    :param do_distributed:
        Whether to evaluate on multiple devices. Requesting this with a single device is an
        error; it is switched on automatically whenever more than one device is available.
    :param eval_steps:
        The number of steps to evaluate for. If not specified evaluation will
        stop after a complete iteration through the ``dataloader``.
    :param metric_names:
        The names of the metrics to track and aggregate. Default is ``("loss",)``.
    :param auto_aggregate_metrics:
        If ``True`` (the default), the metrics will be averaged across batches.
        This may not be the correct behavior for some metrics (such as F1),
        in which you should set this to ``False`` and handle the aggregation
        internally in your model or with an :class:`EvalCallback`
        (using :meth:`EvalCallback.post_batch()`).
    :param callbacks:
        A list of :class:`EvalCallback`.

    :returns:
        The aggregated metrics, keyed by metric name.

    :raises ConfigurationError:
        If ``do_distributed`` is set but only one device is available, or if the number
        of batches can't be determined and ``eval_steps`` was not given.
    """
    logger = logging.getLogger(FlaxEvalStep.__name__)

    # construct dataloader
    # Hugging Face's ``Dataset.set_format()`` works in place and returns ``None``, so its
    # return value must not be handed to the data loader. A non-``None`` return is still
    # honoured in case a dataset implementation returns the formatted dataset instead.
    test_dataset = dataset[test_split]
    formatted = test_dataset.set_format("numpy")
    if formatted is not None:
        test_dataset = formatted
    eval_dataloader: FlaxDataLoader = dataloader.construct(dataset=test_dataset)

    steps: int
    try:
        # The loader yields ``dataset_size // batch_size`` batches, so that is the number
        # of steps available (not the number of rows).
        num_batches = eval_dataloader.dataset_size // eval_dataloader.batch_size
        steps = num_batches if eval_steps is None else min(num_batches, eval_steps)
    except TypeError:
        if eval_steps is None:
            raise ConfigurationError(
                "You must set 'eval_steps' for streaming/iterable datasets"
            )
        else:
            steps = eval_steps

    if do_distributed and len(jax.devices()) <= 1:
        raise ConfigurationError(
            "You have set distributed training=True but there is only one device."
        )

    # Initialize callbacks
    # ``EvalCallback.__init__`` requires the workspace, so it has to be supplied here along
    # with the other automatically-set arguments.
    eval_callbacks: List[EvalCallback] = [
        callback.construct(
            workspace=self.workspace,
            step_id=self.unique_id,
            work_dir=self.work_dir,
            dataset_dict=dataset,
            dataloader=eval_dataloader,
        )
        for callback in (callbacks or [])
    ]

    for callback in eval_callbacks:
        callback.pre_eval_loop()

    rng = get_PRNGkey(seed)

    # Multiple devices always means distributed evaluation, whatever the caller asked for.
    if len(jax.devices()) > 1:
        do_distributed = True

    def eval_step(state, batch):
        labels = batch.pop("labels")
        logits = state.apply_fn(**batch, params=state.params, train=False)[0]
        metrics = wrapper.eval_metrics(batch, logits, labels)
        if do_distributed:
            metrics = jax.lax.pmean(metrics, axis_name="batch")
        return logits, metrics

    if do_distributed:
        state = jax_utils.replicate(state)
        parallel_eval_step = jax.pmap(eval_step, axis_name="batch")

    eval_batches = enumerate(islice(eval_dataloader(rng, do_distributed), steps))

    running_metrics: Dict[str, float] = defaultdict(float)
    aggregated_metrics: Dict[str, float] = defaultdict(float)

    with Tqdm.tqdm(eval_batches, desc="Evaluating", total=steps) as batch_iter:
        for step, batch in batch_iter:
            should_log_this_step = step % log_every == 0 or step == steps - 1
            for callback in eval_callbacks:
                callback.pre_batch(step, batch)

            if do_distributed:
                logits, metrics = parallel_eval_step(state, batch)
                metrics = jax_utils.unreplicate(metrics)
            else:
                logits, metrics = eval_step(state, batch)

            for callback in eval_callbacks:
                callback.post_batch(step, logits)

            if auto_aggregate_metrics:
                # Running mean over the batches seen so far.
                for key, val in metrics.items():
                    if key in metric_names:
                        running_metrics[key] += val.item()
                        aggregated_metrics[key] = running_metrics[key] / (step + 1)
            else:
                aggregated_metrics.update(metrics)

            if should_log_this_step:
                batch_iter.set_postfix(**aggregated_metrics)

            del batch

    logger.info("Evaluation Metrics:")
    for key, val in aggregated_metrics.items():
        # The first argument to ``logger.info`` is a format string; the values to
        # interpolate go in the remaining arguments.
        logger.info("%s: %s", key, val)

    for callback in eval_callbacks:
        callback.post_eval_loop(aggregated_metrics)

    return aggregated_metrics
class EvalCallback(Registrable):
    """
    A :class:`~tango.common.Registrable` hook for customizing the evaluation loop of
    :class:`FlaxEvalStep`, in the same way that :class:`TrainCallback` customizes the
    training loop.

    Subclasses override whichever of the hook methods they need; every hook is a no-op
    in this base class.

    .. tip::
        The constructor arguments of this base class are filled in automatically when the
        callback is constructed, so you shouldn't include them in your config for your
        callbacks.

    :ivar Workspace workspace: The tango workspace being used.
    :ivar str step_id: The unique ID of the step.
    :ivar pathlib.Path work_dir: The working directory of the step
    :ivar DatasetDictBase dataset_dict: The dataset dict containing the evaluation split.
    :ivar DataLoader dataloader: The data loader used to load the evaluation split data.
    """

    def __init__(
        self,
        workspace: Workspace,
        step_id: str,
        work_dir: Path,
        dataset_dict: DatasetDictBase,
        dataloader: FlaxDataLoader,
    ) -> None:
        # Step-level context.
        self.workspace = workspace
        self.step_id = step_id
        self.work_dir = work_dir
        # Data being evaluated.
        self.dataset_dict = dataset_dict
        self.dataloader = dataloader

    def pre_eval_loop(self) -> None:
        """
        Hook invoked once, just before the first batch is processed.
        """

    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:
        """
        Hook invoked once the evaluation loop has finished, receiving the final
        aggregated metrics.

        No other hook runs after this one, which makes it the place for any cleanup.
        """

    def pre_batch(self, step: int, batch: Dict[str, Any]) -> None:
        """
        Hook invoked immediately before a batch is processed.
        """

    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:
        """
        Hook invoked immediately after a batch is processed, receiving that batch's outputs.

        .. tip::
            ``batch_outputs`` may be modified in place here. That is useful when metrics
            need to be aggregated in some way other than a simple average; in that case
            also set ``auto_aggregate_metrics`` to ``False`` in :class:`FlaxEvalStep`.
        """
def optimizer_factory(optim_method: Callable) -> Type[Callable]:
    """
    Wrap an Optax optimizer function in an :class:`Optimizer`.

    :param optim_method: The Optax function that builds the gradient transformation.
    :returns: An :class:`Optimizer` instance holding ``optim_method``. Calling that
        instance with keyword arguments forwards them to ``optim_method``.

    Note that an instance is returned, not a class, regardless of the declared return type.
    """
    return Optimizer(optim_method)
for name, cls in optax.schedules.__dict__.items(): if isfunction(cls) and not name.startswith("_") and cls.__annotations__: factory_func = scheduler_factory(cls) LRScheduler.register("optax::" + name)(factory_func) # TODO: Handle inject_hyperparams. # Refer: https://optax.readthedocs.io/en/latest/api.html?highlight=inject%20hyperparam ================================================ FILE: tango/integrations/flax/train.py ================================================ import logging import time from collections import defaultdict from pathlib import Path from typing import Any, DefaultDict, Dict, List, Optional import jax import jax.numpy as jnp from flax import jax_utils from flax.training import checkpoints from flax.training.train_state import TrainState from tango.common.dataset_dict import DatasetDictBase from tango.common.exceptions import ConfigurationError from tango.common.lazy import Lazy from tango.common.tqdm import Tqdm from tango.format import Format from tango.step import Step from tango.workspace import Workspace from .data import FlaxDataLoader from .format import FlaxFormat from .model import Model from .optim import LRScheduler, Optimizer from .train_callback import TrainCallback from .train_config import TrainConfig from .util import get_multiple_keys, get_PRNGkey from .wrapper import FlaxWrapper PyTree = Any @Step.register("flax::train") class FlaxTrainStep(Step): """ A Flax training step that supports distributed training with configurable dataloaders, callbacks, optimizer. .. tip:: Registered as a :class:`~tango.step.Step` under the name "flax::train". .. important:: To train on GPUs and TPUs, installation of jax[cuda] or jax[tpu] is required. Follow the instructions here: https://github.com/google/jax to set up jax for GPUs and TPUs. Note: CUDA and cuDNN installation is required to run jax on NVidia GPUs. It is recommended to install cuDNN in your conda environment using: ``conda install -c anaconda cudnn``. 
Distributed data parallel training is activated when the ``device_count`` is greater than 1. You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1`` (and ``device_count`` to 2). .. warning:: During validation, the validation metric (specified by the ``val_metric_name`` parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is "loss" or "accuracy", for example, but may not be correct for other metrics like "F1". If this is not correct for your metric you will need to handle the aggregation internally in your model or with a :class:`TrainCallback` using the :meth:`TrainCallback.post_val_batch()` method. Then set the parameter ``auto_aggregate_val_metric`` to ``False``. Jax pre-allocates 90% of GPU memory. If you run into out-of-memory (OOM) issues, please refer to this: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html. 
""" DETERMINISTIC = True CACHEABLE = True FORMAT: Format = FlaxFormat() SKIP_ID_ARGUMENTS = {"log_every"} METADATA = {"artifact_kind": "model"} def run( # type: ignore[override] self, model: Model, dataset: DatasetDictBase, optimizer: Lazy[Optimizer], train_dataloader: Lazy[FlaxDataLoader], *, wrapper: FlaxWrapper, seed: int = 42, keep_checkpoints: int = 5, lr_scheduler: Optional[Lazy[LRScheduler]] = None, train_split: str = "train", validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None, validation_split: Optional[str] = None, train_steps: Optional[int] = None, train_epoch: Optional[int] = None, validation_steps: Optional[int] = None, log_every: int = 10, checkpoint_every: int = 100, validate_every: Optional[int] = None, val_metric_name: str = "loss", minimize_val_metric: bool = True, auto_aggregate_val_metric: bool = True, callbacks: Optional[List[Lazy[TrainCallback]]] = None, remove_stale_checkpoints: bool = True, ) -> PyTree: """ Run a basic training loop to train the ``model``. :param model: The flax model to train. It should define ``__call__()``. Defining ``setup()`` is Optional. :param dataset: The train and optional validation dataset. :param optimizer: The name of the optax Optimizer to use for training. :param train_dataloader: The dataloader object that generates training batches. :param wrapper: A Wrapper class that defines ``loss_fn()``, ``eval_fn()`` and ``compute_metrics()`` :param seed: Used to set the PRNG state. By default, ``seed=42`` :param keep_checkpoints: An integer which denotes how many previous checkpoints should be stored while training. By default, ``keep_checkpoints=5`` :param lr_scheduler: The name of the learning rate scheduler. :param train_split: The name of the data split used for training in the ``dataset_dict``. Default is "train". :param validation_dataloader: An optional data loader for generating validation batches. The batches should be :class:`dict` objects. 
If not specified, but ``validation_split`` is given, the validation ``DataLoader`` will be constructed from the same parameters as the train ``DataLoader``. :param validation_split: Optional name of the validation split in the ``dataset_dict``. Default is ``None``, which means no validation. :param train_steps: The number of steps to train for. If not specified training will stop after a complete iteration through the ``train_dataloader``. :param train_epoch: The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs`` at the same time. :param validation_steps: The number of steps to validate for. If not specified validation will stop after a complete iteration through the ``validation_dataloader``. :param log_every: Log every this many steps. :param checkpoint_every: Save a checkpoint every this many steps. :param validate_every: Run the validation loop every this many steps. :param val_metric_name: The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is "loss". :param minimize_val_metric: Whether the validation metric is meant to be minimized (such as the loss). Default is ``True``. When using a metric such as accuracy, you should set this to ``False``. :param auto_aggregate_val_metric: If ``True`` (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which you should set this to ``False`` and handle the aggregation internally in your model or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). :param callbacks: A list of :class: `TrainCallback`. :param remove_stale_checkpoints: If ``True`` (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept. :returns: The trained model with the last checkpoint loaded. 
def _train(
    self,
    model: Model,
    optimizer: Lazy[Optimizer],
    dataset: DatasetDictBase,
    train_dataloader: Lazy[FlaxDataLoader],
    *,
    train_wrapper: FlaxWrapper,
    seed: int = 42,
    keep_checkpoints: int = 5,
    lr_scheduler: Optional[Lazy[LRScheduler]],
    train_split: str = "train",
    validation_split: Optional[str] = None,
    validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,
    train_steps: Optional[int] = None,
    train_epochs: Optional[int] = None,
    validation_steps: Optional[int] = None,
    log_every: int = 10,
    checkpoint_every: int = 100,
    validate_every: Optional[int] = None,
    val_metric_name: str = "loss",
    minimize_val_metric: bool = True,
    auto_aggregate_val_metric: bool = True,
    callbacks: Optional[List[Lazy[TrainCallback]]] = None,
    remove_stale_checkpoints: bool = True,
) -> PyTree:
    """
    Validate the training options, bundle them into a :class:`TrainConfig`,
    construct the lazy optimizer (and learning rate scheduler, when one is given)
    and hand everything over to :meth:`train_helper`.

    :raises ConfigurationError:
        If ``validate_every`` is set without a ``validation_split``, if both
        ``train_steps`` and ``train_epochs`` are set, or if ``train_split`` is
        ``None`` for a :class:`DatasetDictBase` dataset.

    :returns:
        Whatever :meth:`train_helper` returns (asserted to be not ``None``).
    """
    wants_periodic_validation = validate_every is not None
    if wants_periodic_validation and validation_split is None:
        raise ConfigurationError(
            "You have set a validation interval, but no validation split. "
            "That's probably unintentional."
        )

    # NOTE(review): only the "both given" case is rejected here; passing neither
    # value is not caught by this method even though the message mentions it.
    both_budgets_given = not (train_steps is None or train_epochs is None)
    if both_budgets_given:
        raise ConfigurationError(
            "One of 'train_steps' or 'train_epochs' needs to be specified, but not both."
        )

    if train_split is None and isinstance(dataset, DatasetDictBase):
        raise ConfigurationError("Specify the train split for Datasets object.")

    config = TrainConfig(
        self.unique_id,
        self.work_dir,
        step_name=self.name,
        train_split=train_split,
        validation_split=validation_split,
        seed=seed,
        train_steps=train_steps,
        train_epochs=train_epochs,
        log_every=log_every,
        checkpoint_every=checkpoint_every,
        validate_every=validate_every,
        validation_steps=validation_steps,
        val_metric_name=val_metric_name,
        minimize_val_metric=minimize_val_metric,
        auto_aggregate_val_metric=auto_aggregate_val_metric,
        remove_stale_checkpoints=remove_stale_checkpoints,
    )

    # Turn the lazy components into concrete objects before training starts.
    built_optimizer = self._construct_optimizer(optimizer)
    built_scheduler: Optional[LRScheduler] = (
        self._construct_lr_scheduler(lr_scheduler) if lr_scheduler is not None else None
    )

    trained_model: Model = self.train_helper(
        self.workspace,
        config,
        model,
        built_optimizer,
        keep_checkpoints,
        built_scheduler,
        train_wrapper,
        dataset,
        train_dataloader,
        validation_dataloader,
        callbacks,
    )
    assert trained_model is not None
    return trained_model
) logger = logging.getLogger(FlaxTrainStep.__name__) # construct data loaders validation_dataloader_: Optional[FlaxDataLoader] = None if config.validation_split is not None: validation_dataset = dataset[config.validation_split] validation_dataset.set_format("numpy") if validation_dataloader is not None: validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset) else: validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset) validation_dataloader = validation_dataloader_ train_dataset = dataset[config.train_split] train_dataset.set_format("numpy") # type:ignore train_dataloader: FlaxDataLoader = train_dataloader.construct(dataset=train_dataset) devices = self._get_devices() do_distributed: bool = False if len(devices) > 1: do_distributed = True if validation_dataloader is not None: validation_dataloader.batch_size *= len(devices) train_dataloader.batch_size *= len(devices) rng = get_PRNGkey(config.seed) if hasattr(model, "params"): params = model.params else: # TODO: Find better way to init the shape shape = list(train_dataset["x"].shape) shape[0] = 1 x = jnp.ones(shape) params = model.init(rng, x)["params"] state = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer) initial_state: Optional[Dict[str, Any]] = None if config.state_path.exists(): logger.info("Recovering from previous run at %s" % config.state_path) state = self.load_checkpoint(config.state_path, state) if config.train_epochs is None: assert config.train_steps is not None try: train_epochs = len(train_dataloader.dataset) // train_dataloader.batch_size except TypeError: raise ConfigurationError( "You must set train_epochs for streaming/iterable datasets" ) config.train_epochs = train_epochs assert config.train_epochs is not None if validation_dataloader is not None: if config.validation_steps is None: try: config.validation_steps = len(validation_dataloader.dataset) except TypeError: raise ConfigurationError( "You must set 
'validation_steps' for streaming/iterable datasets" ) val_metric: Optional[float] = None best_val_metric: Optional[float] = None start_step: int = 0 if initial_state is not None: val_metric = initial_state[f"val_{config.val_metric_name}"] best_val_metric = initial_state[f"best_{config.val_metric_name}"] start_step = initial_state["training_epochs"] # Initialize callbacks callbacks: List[TrainCallback] = [ callback.construct( workspace=workspace, train_config=config, dataset=dataset, train_dataloader=train_dataloader, model=model, optimizer=optimizer, validation_dataloader=validation_dataloader, ) for callback in (callbacks or []) ] if initial_state: for callback, state in zip(callbacks, initial_state["callbacks"]): callback.load_state_dict(state) del initial_state if start_step > 0: with Tqdm.tqdm( train_dataloader, desc=f"Catching dataloader up to step {start_step}", total=start_step - 1, ) as batch_iter: for step, batch in enumerate(batch_iter): del batch if step >= start_step - 1: break def train_step(state, batch, dropout_rng): # if transformer model labels = batch.pop("labels") dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) grad_fn = jax.value_and_grad(train_wrapper.train_loss) loss, grad = grad_fn(state.params, state, batch, dropout_rng, labels) if do_distributed: grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) other_metrics = train_wrapper.train_metrics(state, batch, labels=labels) metrics = {"loss": loss} metrics.update(other_metrics) if do_distributed: metrics = jax.lax.pmean(metrics, axis_name="batch") return new_state, metrics, new_dropout_rng def val_step(state, batch): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=state.params, train=False)[0] metrics = train_wrapper.val_metrics(batch, logits, labels) if do_distributed: metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics if do_distributed: # NOTE: The trainer currently handles only data parallelism. 
state = jax_utils.replicate(state) dropout_rngs = get_multiple_keys(rng, jax.local_device_count()) parallel_train_step = jax.pmap(train_step, axis_name="batch") parallel_val_step = jax.pmap(val_step, axis_name="batch") step_per_epoch = train_dataloader.dataset_size // train_dataloader.batch_size config.train_steps = step_per_epoch * config.train_epochs assert config.train_steps is not None # for mypy for callback in callbacks: callback.pre_train_loop() logger.info("***** Running training *****") logger.info(f" Num examples = {train_dataloader.dataset_size}") logger.info(f" Num Epochs = {config.train_epochs}") logger.info( f" Total train batch size (w. parallel & distributed) = {train_dataloader.batch_size}" ) logger.info(f" Total optimization steps = {config.train_steps}") step = start_step epochs = Tqdm.tqdm( range(config.train_epochs), desc=f"Epoch (1/{config.train_epochs})", position=0 ) for epoch in epochs: start = time.time() train_metrics = [] for callback in callbacks: callback.pre_epoch(step, epoch) train_loader = train_dataloader(rng, do_distributed) for _ in Tqdm.tqdm(range(step_per_epoch), desc="Training", position=1): batch = next(train_loader) for callback in callbacks: callback.pre_batch(step, epoch, batch) if do_distributed: state, train_metric, dropout_rngs = parallel_train_step( state, batch, dropout_rngs ) else: state, train_metric, rng = train_step(state, batch, rng) train_metrics.append(train_metric) for callback in callbacks: callback.post_batch(step, epoch, train_metric) if config.should_log_this_step(step): for callback in callbacks: callback.log_batch(step, epoch, train_metric) if config.should_checkpoint_this_step(step): self.save_checkpoint(config.state_path, state, step, keep_checkpoints) step += 1 # check if we need to do validation if config.validation_split is None: # If we can't validate, we don't. should_validate = False elif step == config.train_steps - 1: # If we're at the end of the training run, we always validate. 
should_validate = True elif config.validate_every is not None and step % config.validate_every == 0: # If validate_every is given, we use that to decide. should_validate = True else: # Otherwise, we don't validate. should_validate = False if should_validate: assert validation_dataloader is not None assert config.validation_steps is not None val_metrics: DefaultDict = defaultdict(list) epoch_eval_metrics: DefaultDict = defaultdict(float) val_dataloader = validation_dataloader(rng, do_distributed) valid_step = 0 total_val_steps = len(validation_dataset) // validation_dataloader.batch_size for callback in callbacks: callback.pre_val_loop(step, valid_step, state) for _ in Tqdm.tqdm(range(total_val_steps), desc="Evaluating", position=2): batch = next(val_dataloader) for callback in callbacks: callback.pre_val_batch(step, valid_step, epoch, batch) if do_distributed: metrics = parallel_val_step(state, batch) metrics = jax_utils.unreplicate(metrics) else: metrics = val_step(state, batch) for key, value in metrics.items(): val_metrics[key].append(value.item()) for callback in callbacks: callback.post_val_batch(step, valid_step, epoch, val_metrics) valid_step += 1 for key, value in val_metrics.items(): if config.auto_aggregate_val_metric: epoch_eval_metrics[key] = jax.tree_map( jnp.mean, jnp.array(value) ).item() else: epoch_eval_metrics[key] = metrics[key].item() for key, value in epoch_eval_metrics.items(): print("Validation %s : %.5f" % (key, value)) val_metric = epoch_eval_metrics[config.val_metric_name] assert val_metric is not None if best_val_metric is None: best_val_metric = val_metric elif config.minimize_val_metric and val_metric <= best_val_metric: best_val_metric = val_metric elif not config.minimize_val_metric and val_metric >= best_val_metric: best_val_metric = val_metric for callback in callbacks: callback.post_val_loop(step, epoch, val_metric, best_val_metric) if do_distributed: train_metric = jax_utils.unreplicate(train_metric) for key, value in 
def _get_devices(self) -> List[Any]:
    """
    Look up the JAX devices visible to this process and remember them on the step.

    The device list is stored on ``self.devices`` and a one-line summary with the
    device count and the default backend name is printed to stdout.

    :returns:
        The list produced by ``jax.devices()``.
    """
    backend_name = jax.default_backend()
    available = jax.devices()
    self.devices = available
    print("Training on %d %s" % (len(available), backend_name))
    return self.devices
tip:: All of the parameters to this base class will be automatically set within the training loop, so you shouldn't include them in your config for your callbacks. .. tip:: You can access the model being trained through :attr:`self.model `. .. important:: The ``step`` argument to callback methods is the total/overall number of training steps so far, independent of the current epoch. .. seealso:: See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example implementation. :ivar Workspace workspace: The tango workspace being used. :ivar TrainConfig train_config: The training config. :ivar tango.common.DatasetDictBase dataset_dict: The dataset dict containing train and optional validation splits. :ivar DataLoader train_dataloader: The dataloader used for the training split. :ivar Model model: The flax model being trained. :ivar Optimizer optimizer: The optimizer being used for training. :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split. """ def __init__( self, workspace: Workspace, train_config: TrainConfig, dataset: DatasetDictBase, train_dataloader: DataLoader, model: Model, optimizer: Optimizer, validation_dataloader: Optional[DataLoader] = None, ) -> None: self.workspace = workspace self.train_config = train_config self.dataset = dataset self.train_dataloader = train_dataloader self.model = model self.optimizer = optimizer self.validation_dataloader = validation_dataloader self.logger = logging.getLogger(self.__class__.__name__) @property def step_id(self) -> str: """ The unique ID of the current :class:`~tango.Step`. """ return self.train_config.step_id @property def step_name(self) -> Optional[str]: """ The name of the current:class:`~tango.Step`. """ return self.train_config.step_name @property def work_dir(self) -> Path: """ The working directory of the current train step """ return self.train_config.work_dir def state_dict(self) -> Dict[str, Any]: """ Return any state that needs to be kept after a restart. 
Some callbacks need to maintain state across restarts. This is the callback's opportunity to save it's state. It will be restored using :meth:`load_state_dict`. """ return {} def load_state_dict(self, state_dict: Dict[str, Any]): """ Load the state on a restart. Some callbacks need to maintain state across restarts. This is the callback's opportunity to restore it's state. It gets saved using :meth:`state_dict`. """ pass def pre_train_loop(self) -> None: """ Called right before the first batch is processed, or after a restart """ pass def post_train_loop(self, step: int, epoch: int) -> None: """ Called after the training loop completes. This is the last method that is called, so any cleanup can be done in this method. """ pass def pre_epoch(self, step: int, epoch: int) -> None: """ Called before start of an epoch. Epochs start at 0. """ pass def post_epoch(self, step: int, epoch: int) -> None: """ Called after an epoch is completed. Epochs start at 0. """ pass def pre_batch(self, step: int, epoch: int, batch) -> None: """ Called directly before processing a batch. """ def post_batch(self, step: int, epoch: int, train_metrics: Dict) -> None: """ Called directly after processing a batch, but before unscaling gradients, clipping gradients, and taking an optimizer step. .. note:: The ``train_metrics`` here is the dictionary with train metrics of the current batch. If doing, distributed training, use `jax_utils.unreplicate(train_metrics)` before using train_metrics. If you need the average loss, use :meth:`log_batch()`. """ pass def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None: """ Called after the optimizer step. Here ``train_metrics`` is the average metrics across all distributed workers. If doing, distributed training, use `jax_utils.unreplicate(train_metrics)` before using train_metrics. .. note:: This callback method is not necessarily called on every step. 
@dataclass
class TrainConfig:
    """
    Bundles the options of a :class:`FlaxTrainStep` run.

    A single instance carries all of the training options and is what gets passed
    to every :class:`TrainCallback`.
    """

    step_id: str
    """
    Unique ID of the step that is being run.
    """

    work_dir: Path
    """
    Working directory of the training run. Checkpoint paths are derived from it.
    """

    step_name: Optional[str] = None
    """
    Name of the step that is being run, if it has one.
    """

    train_split: str = "train"
    """
    Name of the split used for training.
    """

    validation_split: Optional[str] = None
    """
    Name of the split used for validation, or ``None`` for no validation.
    """

    seed: int = 42
    """
    The random seed for the training run.
    """

    train_steps: Optional[int] = None
    """
    Total number of training steps.
    """

    train_epochs: Optional[int] = None
    """
    Number of epochs to train for. You cannot specify `train_steps` and
    `train_epochs` at the same time.
    """

    validation_steps: Optional[int] = None
    """
    Number of validation steps.
    """

    log_every: int = 10
    """
    How often (in steps) log updates are produced.
    """

    checkpoint_every: int = 100
    """
    How often (in steps) a checkpoint is written.
    """

    validate_every: Optional[int] = None
    """
    How often (in steps) the validation loop is run.
    """

    is_distributed: bool = False
    """
    Whether or not the training job is distributed.
    """

    val_metric_name: str = "loss"
    """
    Name of the validation metric that is tracked.
    """

    minimize_val_metric: bool = True
    """
    ``True`` when a smaller value of the tracked validation metric is better.
    """

    auto_aggregate_val_metric: bool = True
    """
    Whether the validation metric is aggregated automatically.
    """

    remove_stale_checkpoints: bool = True
    """
    Whether stale checkpoints are removed.
    """

    @property
    def state_path(self) -> Path:
        """
        Location of the latest state checkpoint inside ``work_dir``.
        """
        latest_name = "checkpoint_state_latest"
        return self.work_dir / latest_name

    @property
    def best_state_path(self) -> Path:
        """
        Location of the best state checkpoint inside ``work_dir``, where "best" is
        judged by the validation metric or by the training loss (if no validation
        split is given).
        """
        best_name = "checkpoint_state_best"
        return self.work_dir / best_name

    def should_log_this_step(self, step: int) -> bool:
        """
        Tell whether ``step`` is a logging step: the very first step, every
        ``log_every``-th step, and the final training step.
        """
        assert self.train_steps is not None
        if step == 0:
            return True
        if (step + 1) % self.log_every == 0:
            return True
        return step == self.train_steps - 1

    def should_checkpoint_this_step(self, step: int) -> bool:
        """
        Tell whether a checkpoint is due at ``step``: every ``checkpoint_every``-th
        step and the final training step.
        """
        assert self.train_steps is not None
        if (step + 1) % self.checkpoint_every == 0:
            return True
        return step == self.train_steps - 1

    def should_log_this_val_step(self, val_step: int) -> bool:
        """
        Tell whether ``val_step`` is a logging step of the validation loop: every
        step that is a multiple of ``log_every`` and the final validation step.
        """
        assert self.validation_steps is not None
        if val_step % self.log_every == 0:
            return True
        return val_step == self.validation_steps - 1

    def as_dict(self) -> Dict[str, Any]:
        """
        Return the fields as a plain dictionary, leaving out any key that starts
        with an underscore.
        """
        public_fields: Dict[str, Any] = {}
        for field_name, field_value in asdict(self).items():
            if field_name.startswith("_"):
                continue
            public_fields[field_name] = field_value
        return public_fields
""" # return empty dict if no other metrics to compute return {} @abstractmethod def train_loss(self, params, state, batch, dropout_rng, labels): """ This function performs the forward pass and computes loss. The function should return the loss for the batch as a jax device array. The gradient of this function is used for training. """ raise NotImplementedError() @abstractmethod def val_metrics(self, batch, logits, labels) -> Dict: """ Returns the validation metrics as Dict. """ raise NotImplementedError() @abstractmethod def eval_metrics(self, batch, logits, labels) -> Dict: """ Returns the evaluation metrics as Dict. """ raise NotImplementedError() ================================================ FILE: tango/integrations/gs/__init__.py ================================================ """ .. important:: To use this integration you should install ``tango`` with the "gs" extra (e.g. ``pip install tango[gs]``) or just install the `gcsfs `_ library after the fact (e.g. ``pip install gcsfs``). Components for Tango integration with `GS `_. """ from tango.common.exceptions import IntegrationMissingError try: from google.cloud import datastore, storage except (ModuleNotFoundError, ImportError): raise IntegrationMissingError("gs", dependencies={"google-cloud-storage"}) from .step_cache import GSStepCache from .workspace import GSWorkspace __all__ = [ "GSStepCache", "GSWorkspace", ] ================================================ FILE: tango/integrations/gs/common.py ================================================ """ Classes and utility functions for GSWorkspace and GSStepCache. 
def get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:
    """
    Break a folder name such as ``bucket/sub/folder`` into the bucket name (the
    part before the first ``/``) and the subfolder path (everything after it).
    The subfolder part is the empty string when there is nothing after the bucket.
    """
    bucket_name, _separator, subfolder = folder_name.partition("/")
    return bucket_name, subfolder
def join_path(*args) -> str:
    """
    Glue path components together with ``/`` and drop any slashes at the very
    start or end of the result. Needed because `os.path.join` is not suitable
    for cloud storage paths.
    """
    combined = "/".join(args)
    return combined.strip("/")
def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:
    """
    Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.

    The blob whose name ends with the placeholder file name supplies the artifact's
    name and creation time (if several match, the last one wins). A blob whose name
    ends with the uncommitted file name marks the artifact as not committed.

    :param blobs:
        The blobs found under the artifact's folder.

    :raises AssertionError:
        If none of the blobs is a placeholder file, i.e. the folder is not an
        artifact.

    :returns:
        The artifact described by the blobs. Its ``artifact_path`` equals its
        name and does not contain bucket information.
    """
    placeholder_blob = None
    committed = True
    for blob in blobs:
        if blob.name.endswith(self.placeholder_file):
            placeholder_blob = blob
        elif blob.name.endswith(self.uncommitted_file):
            committed = False

    # Previously the locals were only annotated, so a missing placeholder surfaced
    # as an UnboundLocalError instead of this intended, descriptive failure.
    if placeholder_blob is None:
        raise AssertionError("Folder is not a GSArtifact, should not have happened.")

    name = placeholder_blob.name
    # Strip the placeholder only as a trailing path component and the client prefix
    # only as a leading one, so identical substrings elsewhere in the name survive.
    placeholder_suffix = "/" + self.placeholder_file
    if name.endswith(placeholder_suffix):
        name = name[: -len(placeholder_suffix)]
    if self.prefix:
        leading_prefix = self.prefix + "/"
        if name.startswith(leading_prefix):
            name = name[len(leading_prefix) :]

    return GSArtifact(name, name, placeholder_blob.time_created, committed)
def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):
    """
    Writes the contents of objects_dir to the remote artifact location.

    If ``objects_dir`` is a directory, every file below it is uploaded concurrently
    and keeps its path relative to ``objects_dir``. Otherwise ``objects_dir`` is
    treated as a single file and uploaded under its base name.

    :param artifact:
        The artifact, or the artifact path, to upload into.
    :param objects_dir:
        Local directory (or single file) to upload.

    :raises GSArtifactWriteError:
        If anything goes wrong while uploading. The original exception is chained
        as the cause.
    """
    import concurrent.futures

    folder_path = artifact if isinstance(artifact, str) else artifact.artifact_path
    source_path = str(objects_dir)

    def _sync_blob(source_file_path: str, target_file_path: str):
        blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))
        blob.upload_from_filename(source_file_path)

    try:
        # TODO: google-cloud-storage==2.7.0 has added a preview feature called `transfer_manager`
        # which allows for concurrent uploads and downloads. We should upgrade to this once
        # it is more robustly supported. Also update in `download()`.
        if os.path.isdir(source_path):
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix="GSClient.upload()-"
            ) as executor:
                upload_futures = []
                for dirpath, _, filenames in os.walk(source_path):
                    for filename in filenames:
                        source_file_path = os.path.join(dirpath, filename)
                        # `relpath` removes exactly the leading source directory,
                        # whereas a plain string replace would also drop repeated
                        # segments and miss trailing or non-"/" separators.
                        relative_path = os.path.relpath(source_file_path, source_path)
                        target_file_path = join_path(
                            folder_path, *relative_path.split(os.sep)
                        )
                        upload_futures.append(
                            executor.submit(_sync_blob, source_file_path, target_file_path)
                        )
                # Surface the first failure raised by any of the worker threads.
                for future in concurrent.futures.as_completed(upload_futures):
                    future.result()
        else:
            source_file_path = source_path
            target_file_path = join_path(folder_path, os.path.basename(source_file_path))
            _sync_blob(source_file_path, target_file_path)
    except Exception as exc:
        raise GSArtifactWriteError(
            f"Failed to upload '{source_path}' to '{folder_path}': {exc}"
        ) from exc
def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:
    """
    Lists the artifacts within the remote storage whose folder name starts with `prefix`.
    These can include steps and runs.

    :param prefix: The artifact name prefix to look for (relative to the client's folder).
    :param uncommitted: If `False`, artifacts that have not been committed are left out.
    :returns: One `GSArtifact` per matching folder, ordered by folder name.
    """
    gs_prefix = self._gs_path(prefix)
    folder_listing = self.storage.list_blobs(self.bucket_name, prefix=gs_prefix, delimiter="/")
    # The folder-like "prefixes" are only filled in on the iterator as its pages are
    # consumed, so exhaust it first. This covers every page of results and simply leaves
    # `prefixes` empty when nothing matches.
    for _ in folder_listing:
        pass

    list_of_artifacts = []
    # `prefixes` is unordered; sort it so results come back in name order.
    for folder_name in sorted(folder_listing.prefixes):
        artifact = self._convert_blobs_to_artifact(
            list(self.storage.list_blobs(self.bucket_name, prefix=folder_name))
        )
        if uncommitted or artifact.committed:
            list_of_artifacts.append(artifact)
    return list_of_artifacts
def get_client(
    folder_name: str,
    credentials: Optional[Union[str, Credentials]] = None,
    project: Optional[str] = None,
) -> GSClient:
    """
    Builds a `GSClient` object for a google cloud bucket folder.

    :param folder_name: The bucket folder the client should operate on.
    :param credentials: Anything accepted by `get_credentials()`; it is resolved there
        before being handed to the client.
    :param project: Passed through to `GSClient` unchanged.
    """
    resolved_credentials = get_credentials(credentials)
    client = GSClient(folder_name, credentials=resolved_credentials, project=project)
    return client
def acquire(self, timeout=None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:
    """
    Acquires the lock by creating the lock artifact, polling until that succeeds.

    Does nothing if this object already holds the lock.

    :param timeout: Maximum number of seconds to keep trying, or `None` to wait forever.
    :param poll_interval: Seconds to sleep between attempts.
    :param log_interval: Minimum number of seconds between two "still waiting" warnings.
    :raises TimeoutError: If the lock could not be acquired within `timeout` seconds.
    """
    if self._lock_artifact is not None:
        return
    start = time.monotonic()
    last_logged: Optional[float] = None
    while timeout is None or (time.monotonic() - start < timeout):
        try:
            self._lock_artifact = self._client.create(self._lock_artifact_name)
        except GSArtifactConflict:
            now = time.monotonic()
            # Throttle on the time since the *previous warning*. Measuring from `start`
            # (as `last_logged - start`) yields a value that never changes after the
            # first warning, so the reminder would either never repeat or repeat on
            # every single poll.
            if last_logged is None or now - last_logged >= log_interval:
                logger.warning(
                    "Waiting to acquire lock artifact for step '%s':\n\n%s\n\n"
                    "This probably means the step is being run elsewhere, but if you're sure it isn't "
                    "you can just delete the lock artifact, using the command: \n`gsutil rm -r %s`",
                    self._step_id,
                    self.lock_artifact_url,
                    self.lock_artifact_url,
                )
                last_logged = now
            time.sleep(poll_interval)
            continue
        else:
            # Only once the lock is really ours: make sure it gets released at exit.
            atexit.register(self.release)
            break
    else:
        raise TimeoutError(
            f"Timeout error occurred while waiting to acquire artifact lock for step '{self._step_id}':\n\n"
            f"{self.lock_artifact_url}\n\n"
            f"This probably means the step is being run elsewhere, but if you're sure it isn't you can "
            f"just delete the lock, using the command: \n`gsutil rm -r {self.lock_artifact_url}`"
        )
@StepCache.register("gs")
class GSStepCache(RemoteStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`GSWorkspace`.
    It stores the results of steps on Google cloud buckets as blobs.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
    step's result subsequent times should be fast.

    .. tip::
        Registered as a :class:`~tango.step_cache.StepCache` under the name "gs".

    :param folder_name: The name of the google cloud bucket folder to use.
    :param client: The google cloud storage client to use.
    """

    Constants = Constants

    def __init__(self, folder_name: str, client: Optional[GSClient] = None):
        # Record the folder regardless of how the client is obtained, so the attribute
        # exists on every instance (it used to be set only when a client was passed in).
        self.folder_name = folder_name
        if client is not None:
            bucket_name, _ = get_bucket_and_prefix(folder_name)
            assert (
                bucket_name == client.bucket_name
            ), "Assert that bucket name is same as client bucket until we do better"
            self._client = client
        else:
            self._client = GSClient(folder_name)
        super().__init__(tango_cache_dir() / "gs_cache" / make_safe_filename(folder_name))

    @property
    def client(self):
        """The `GSClient` used to talk to the remote storage."""
        return self._client

    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[GSArtifact]:
        """
        Returns a `GSArtifact` object containing the details of the step.
        This only returns if the step has been finalized (committed); otherwise,
        or if the artifact does not exist, `None` is returned.
        """
        try:
            artifact = self.client.get(self.Constants.step_artifact_name(step))
        except GSArtifactNotFound:
            return None
        return artifact if artifact.committed else None

    def _upload_step_remote(self, step: Step, objects_dir: Path) -> GSArtifact:
        """
        Uploads the step's output to remote location and returns the refreshed artifact.

        Upload failures are logged rather than raised, so the returned artifact may be
        uncommitted.
        """
        artifact_name = self.Constants.step_artifact_name(step)
        try:
            self.client.create(artifact_name)
        except GSArtifactConflict:
            # The artifact already exists remotely; upload into it.
            pass
        try:
            self.client.upload(artifact_name, objects_dir)
            self.client.commit(artifact_name)
        except GSArtifactWriteError:
            # Best effort: don't fail the step, but don't hide the problem either.
            logger.warning(
                "Failed to upload step output to remote artifact '%s'",
                artifact_name,
                exc_info=True,
            )
        return self.client.get(artifact_name)

    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """
        Downloads the step's output from remote location.

        :raises RemoteNotFoundError: If the client reports the artifact as not found.
        """
        try:
            self.client.download(step_result, target_dir)
        except GSArtifactNotFound:
            raise RemoteNotFoundError()

    def __len__(self):
        """
        Returns the number of committed step outputs present in the remote location.
        """
        # NOTE: lock files should not count here.
        return sum(
            1
            for ds in self.client.artifacts(
                prefix=self.Constants.STEP_ARTIFACT_PREFIX, uncommitted=False
            )
            if ds.name is not None
            and ds.name.startswith(self.Constants.STEP_ARTIFACT_PREFIX)
            and not ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX)
        )
@classmethod
def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:
    """
    Builds the workspace from a parsed `gs://` URL.

    The workspace name is the URL's network location followed by its path (when present).

    :raises ValueError: If the URL has no network location part.
    """
    bucket = parsed_url.netloc
    if not bucket:
        raise ValueError(f"Bad URL for GS workspace '{parsed_url}'")

    folder = parsed_url.path
    # "gs://ai2/my-workspace" -> "ai2/my-workspace"; "gs://my-workspace" -> "my-workspace"
    workspace: str = bucket + folder if folder else bucket
    return cls(workspace)
def registered_runs(self) -> Dict[str, Run]:
    """
    Returns every run stored in the datastore, keyed by run name.

    Each run entity is converted by `_get_run_from_entity()` on a thread pool; entities
    for which that returns `None` are left out.
    """
    from concurrent.futures import ThreadPoolExecutor, as_completed

    collected: Dict[str, Run] = {}
    pool = ThreadPoolExecutor(
        max_workers=self.NUM_CONCURRENT_WORKERS,
        thread_name_prefix="GSWorkspace.registered_runs()-",
    )
    with pool as executor:
        pending = [
            executor.submit(self._get_run_from_entity, entity)
            for entity in self._ds.query(kind=self._run_key).fetch()
        ]
        for finished in as_completed(pending):
            maybe_run = finished.result()
            if maybe_run is None:
                continue
            collected[maybe_run.name] = maybe_run
    return collected
def num_registered_runs(self, *, match: Optional[str] = None) -> int:
    """
    Counts the run entities yielded by `_fetch_run_entities()` for the given `match`.
    """
    return sum(1 for _ in self._fetch_run_entities(match=match))
def _fetch_step_info_entities(
    self,
    *,
    sort_by: Optional[StepInfoSort] = None,
    sort_descending: bool = True,
    match: Optional[str] = None,
    state: Optional[StepState] = None,
    start: int = 0,
    stop: Optional[int] = None,
) -> Generator[datastore.Entity, None, None]:
    """
    Yields the step info entities from the datastore that satisfy the given criteria.

    :param sort_by: Field to sort by, or `None` for no particular order.
    :param sort_descending: Whether sorting is descending.
    :param match: Only yield entities whose `step_id` starts with this string.
        `None` and the empty string both mean "no restriction".
    :param state: Only yield entities in this state.
    :param start: Index of the first result to yield.
    :param stop: Index after the last result to yield, or `None` for no limit.
    :raises NotImplementedError: If `sort_by` is not a supported sort key.
    """
    from itertools import islice

    # An empty prefix matches everything; treating it as "no match" also avoids
    # indexing into an empty string when building the range filter below.
    if not match:
        match = None

    # Note: we can't query or order by multiple fields without a suitable
    # composite index. So in that case we have to apply remaining filters
    # or slice and order locally. We'll default to using 'match' in the query.
    # But if 'match' is null, we'll use 'state' to filter in the query.
    # If 'state' is also null, we can sort with the query.
    sort_locally = sort_by is not None and (match is not None or state is not None)
    filter_locally = state is not None and match is not None
    slice_locally = sort_locally or filter_locally

    sort_field: Optional[str] = None
    if sort_by is not None:
        if sort_by == StepInfoSort.START_TIME:
            sort_field = "start_time"
        elif sort_by == StepInfoSort.UNIQUE_ID:
            sort_field = "step_id"
        else:
            raise NotImplementedError(sort_by)

    order: List[str] = []
    if sort_field is not None and not sort_locally:
        order = [sort_field if not sort_descending else f"-{sort_field}"]

    query = self._ds.query(kind=self._stepinfo_key, order=order)

    if match is not None:
        # HACK: Datastore has no direct string matching functionality,
        # but this comparison is equivalent to checking if 'step_id' starts with 'match'.
        # The upper bound must be exclusive: it is the first string that does *not*
        # start with 'match', so "<=" would wrongly let that exact ID through.
        query.add_filter("step_id", ">=", match)
        query.add_filter("step_id", "<", match[:-1] + chr(ord(match[-1]) + 1))
    elif state is not None and not filter_locally:
        query.add_filter("state", "=", str(state.value))

    entity_iter: Iterable[datastore.Entity] = query.fetch(
        offset=0 if slice_locally else start,
        limit=None if (stop is None or slice_locally) else stop - start,
    )

    if state is not None and filter_locally:
        # Compare against the same string form used by the query filter above.
        wanted_state = str(state.value)
        entity_iter = filter(lambda entity: entity["state"] == wanted_state, entity_iter)

    if sort_field is not None and sort_locally:
        # Sort missing values (e.g. a `start_time` of None) before everything else
        # instead of failing on a None-vs-value comparison.
        # NOTE(review): this is meant to mirror Datastore's own null-first ordering — confirm.
        entity_iter = sorted(
            entity_iter,
            key=lambda entity: (entity[sort_field] is not None, entity[sort_field]),
            reverse=sort_descending,
        )

    if slice_locally:
        entity_iter = islice(entity_iter, start, stop)

    for entity in entity_iter:
        yield entity
def _step_info_multiple(
    self, step_or_unique_ids: Union[List[Step], List[str]]
) -> List[StepInfo]:
    """
    Looks up the step info for several steps with a single `get_multi()` call to the
    datastore, instead of one lookup per step.

    Steps that have no entity yet get a fresh `StepInfo`, which is also saved to the
    datastore. The result lists the infos that were found first, followed by the newly
    created ones.

    :raises KeyError: If an entry given as a plain unique ID has no stored step info.
    """

    def _unique_id_of(step_or_unique_id) -> str:
        if isinstance(step_or_unique_id, str):
            return step_or_unique_id
        return step_or_unique_id.unique_id

    all_unique_id_keys = [
        self._ds.key(self._stepinfo_key, _unique_id_of(step_or_unique_id))
        for step_or_unique_id in step_or_unique_ids
    ]

    # `get_multi()` fills `missing` with placeholder entities for keys it could not find.
    missing: List = []
    step_info_entities = self._ds.get_multi(keys=all_unique_id_keys, missing=missing)
    # A set keeps the membership test below constant-time.
    missing_steps = {entity.key.name for entity in missing}

    step_infos = [
        StepInfo.from_json_dict(json.loads(step_info_entity["step_info_dict"]))
        for step_info_entity in step_info_entities
    ]

    for step_or_unique_id in step_or_unique_ids:
        if _unique_id_of(step_or_unique_id) not in missing_steps:
            continue
        # A bare ID carries nothing to build a new StepInfo from.
        if not isinstance(step_or_unique_id, Step):
            raise KeyError(step_or_unique_id)
        step_info = StepInfo.new_from_step(step_or_unique_id)
        self._update_step_info(step_info)
        step_infos.append(step_info)

    return step_infos
def _remove_step_info(self, step_info: StepInfo) -> None:
    """
    Removes a step's artifact from the bucket and its entity from the datastore.

    :param step_info: The step to remove; its `unique_id` identifies the datastore entity.
    """
    # remove dir from bucket
    step_artifact = self.client.get(self.Constants.step_artifact_name(step_info))
    if step_artifact is not None:
        self.client.delete(step_artifact)

    # remove datastore entities
    # Use the same kind the entity was stored under (`_stepinfo_key`, which includes the
    # workspace prefix). A hard-coded "stepinfo" only matches when there is no prefix,
    # so the entity would otherwise be left behind.
    self._ds.delete(key=self._ds.key(self._stepinfo_key, step_info.unique_id))
Check `pytorch.org/get-started/locally/ <https://pytorch.org/get-started/locally/>`_ for more details. Components for Tango integration with `PyTorch <https://pytorch.org/>`_. These include a training loop :class:`~tango.step.Step` and registrable versions of many ``torch`` classes, such as :class:`torch.optim.Optimizer` and :class:`torch.utils.data.DataLoader`. Example: training a model ------------------------- Let's look at a simple example of training a model. We'll make a basic regression model and generate some fake data to train on. First, the setup: .. testcode:: import torch import torch.nn as nn from tango.common.dataset_dict import DatasetDict from tango.step import Step from tango.integrations.torch import Model Now let's build and register our model: .. testcode:: @Model.register("basic_regression") class BasicRegression(Model): def __init__(self): super().__init__() self.linear = nn.Linear(10, 1) self.sigmoid = nn.Sigmoid() self.mse = nn.MSELoss() def forward(self, x, y=None): pred = self.sigmoid(self.linear(x)) out = {"pred": pred} if y is not None: out["loss"] = self.mse(pred, y) return out def _to_params(self): return {} Lastly, we'll need a step to generate data: .. testcode:: @Step.register("generate_data") class GenerateData(Step): DETERMINISTIC = True CACHEABLE = False def run(self) -> DatasetDict: torch.manual_seed(1) return DatasetDict( { "train": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(64)], "validation": [{"x": torch.rand(10), "y": torch.rand(1)} for _ in range(32)], } ) You could then run this experiment with a config that looks like this: .. literalinclude:: ../../../../test_fixtures/integrations/torch/train.jsonnet .. testcode:: :hide: from tango.common.testing import run_experiment from tango.common.registrable import Registrable # Pickling the model fails because the class is defined ad hoc, not in a module. # So we put in this hack to pickle a 0 instead of the Model.
def _return_zero(self): return (int, (0,)) BasicRegression.__reduce__ = _return_zero with run_experiment( "test_fixtures/integrations/torch/train.jsonnet", name="boss-alien" ) as run_dir: assert (run_dir / "train").is_dir(), "Output for the 'train' step was not produced." # Restore state of registry. del Registrable._registry[Step]["generate_data"] del Registrable._registry[Model]["basic_regression"] For example, .. code-block:: tango run train.jsonnet -i my_package -d /tmp/train would produce the following output: .. testoutput:: :options: +ELLIPSIS Starting new run boss-alien ● Starting step "data" (needed by "train")... ✓ Finished step "data" ● Starting step "train"... ✓ Finished step "train" ✓ Finished run boss-alien ... Tips ---- Debugging ~~~~~~~~~ When debugging a training loop that's causing errors on a GPU, you should set the environment variable ``CUDA_LAUNCH_BLOCKING=1``. This will ensure that the stack trace shows where the error actually happened. You could also use a custom :class:`TrainCallback` to log each batch before it is passed into the model so that you can see the exact inputs that are causing the issue. Stopping early ~~~~~~~~~~~~~~ You can stop the "torch::train" step early using a custom :class:`TrainCallback`. Your callback just needs to raise the :class:`StopEarly` exception.
""" from tango.common.exceptions import IntegrationMissingError try: import torch except ModuleNotFoundError: raise IntegrationMissingError("torch") __all__ = [ "TorchFormat", "TorchTrainStep", "TorchEvalStep", "Optimizer", "LRScheduler", "Model", "DataLoader", "DataCollator", "Sampler", "ConcatTensorDictsCollator", "TrainCallback", "EvalCallback", "TrainConfig", "StopEarlyCallback", "StopEarly", "TrainingEngine", "TorchTrainingEngine", ] from .data import ConcatTensorDictsCollator, DataCollator, DataLoader, Sampler from .eval import TorchEvalStep from .eval_callback import EvalCallback from .exceptions import StopEarly from .format import TorchFormat from .model import Model from .optim import LRScheduler, Optimizer from .train import TorchTrainStep from .train_callback import StopEarlyCallback, TrainCallback from .train_config import TrainConfig from .training_engine import TorchTrainingEngine, TrainingEngine ================================================ FILE: tango/integrations/torch/data.py ================================================ from typing import Any, Dict, Generic, List, Optional, TypeVar, Union import torch from tango.common.lazy import Lazy from tango.common.registrable import Registrable T = TypeVar("T") class DataCollator(Generic[T], Registrable): """ A :class:`~tango.common.Registrable` version of a ``collate_fn`` for a ``DataLoader``. Subclasses just need to implement :meth:`__call__()`. """ default_implementation = "concat_tensor_dicts" """ The default implementation is :class:`ConcatTensorDictsCollator`. """ def __call__(self, items: List[T]) -> Dict[str, Any]: """ Takes a list of items from a dataset and combines them into a batch. """ raise NotADirectoryError @DataCollator.register("concat_tensor_dicts") class ConcatTensorDictsCollator(DataCollator[Dict[str, Any]]): """ A simple ``collate_fn`` that expects items to be dictionaries of tensors. The tensors are just concatenated together. .. 
def __call__(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Collate a list of dictionaries into a single batch dictionary.

    How a field is batched is decided by looking at the first item only:
    tensors get a new leading dimension and are concatenated along it,
    ``int``/``float`` values are packed into a new tensor, and anything
    else is gathered into a plain list.
    """
    first = items[0]
    batch: Dict[str, Any] = {}
    for field, sample in first.items():
        if isinstance(sample, torch.Tensor):
            collated = torch.cat([item[field].unsqueeze(0) for item in items])
        elif isinstance(sample, (int, float)):
            collated = torch.tensor([item[field] for item in items])
        else:
            collated = [item[field] for item in items]
        batch[field] = collated
    return batch
""" default_implementation = "default" def __init__( self, dataset: torch.utils.data.Dataset, collate_fn: Optional[DataCollator] = ConcatTensorDictsCollator(), sampler: Optional[Union[Lazy[Sampler], Sampler]] = None, **kwargs, ): super().__init__( dataset, collate_fn=collate_fn, sampler=sampler.construct(data_source=dataset, dataset=dataset) if isinstance(sampler, Lazy) else sampler, **kwargs, ) DataLoader.register("default")(DataLoader) ================================================ FILE: tango/integrations/torch/eval.py ================================================ from collections import defaultdict from itertools import islice from typing import Dict, List, Optional, Sequence import torch from tango.common.dataset_dict import DatasetDictBase from tango.common.exceptions import ConfigurationError from tango.common.lazy import Lazy from tango.common.tqdm import Tqdm from tango.format import Format, JsonFormat from tango.step import Step, StepResources from .data import DataLoader from .eval_callback import EvalCallback from .model import Model from .util import check_dataset, move_to_device, resolve_device, set_seed_all @Step.register("torch::eval") class TorchEvalStep(Step): """ A PyTorch evaluation loop that pairs well with :class:`TorchTrainStep`. .. tip:: Registered as a :class:`~tango.step.Step` under the name "torch::eval". .. important:: The evaluation loop will use a GPU automatically if one is available. You can control which GPU it uses with the environment variable ``CUDA_VISIBLE_DEVICES``. For example, set ``CUDA_VISIBLE_DEVICES=1`` to force ``TorchEvalStep`` to only use the GPU with ID 1. .. warning:: By default the metrics specified by the ``metric_names`` parameter are aggregated by simply averaging across batches. This behavior is usually correct for metrics like "loss" or "accuracy", for example, but may not be correct for other metrics like "F1". 
def run(  # type: ignore[override]
    self,
    model: Model,
    dataset_dict: DatasetDictBase,
    dataloader: Lazy[DataLoader],
    test_split: str = "test",
    seed: int = 42,
    eval_steps: Optional[int] = None,
    log_every: int = 1,
    metric_names: Sequence[str] = ("loss",),
    auto_aggregate_metrics: bool = True,
    callbacks: Optional[List[Lazy[EvalCallback]]] = None,
) -> Dict[str, float]:
    """
    Run the evaluation loop over one split and return the tracked metrics.

    :param model:
        The model to evaluate. Its ``forward()`` method is called with each batch as
        keyword arguments and should return a ``dict`` containing every metric
        listed in ``metric_names``.
    :param dataset_dict:
        Should contain the test data.
    :param dataloader:
        The data loader that generates test batches. The batches should be
        :class:`dict` objects.
    :param test_split:
        The name of the split in ``dataset_dict`` to evaluate on. Default is "test".
    :param seed:
        Used to set the RNG states at the beginning of the evaluation loop.
    :param eval_steps:
        The number of steps to evaluate for. If not specified, evaluation stops after
        one complete pass through the ``dataloader``. Required when the data loader
        has no length (streaming/iterable datasets).
    :param log_every:
        Update the progress bar with the current metrics every this many steps.
        Default is ``1``.
    :param metric_names:
        The names of the metrics to track and aggregate. Default is ``("loss",)``.
    :param auto_aggregate_metrics:
        If ``True`` (the default), each metric is averaged across batches. That may not
        be the right behavior for some metrics (such as F1), in which case you should set
        this to ``False`` and handle the aggregation in your model or in an
        :class:`EvalCallback` (using :meth:`EvalCallback.post_batch()`). When ``False``,
        the returned values are simply those of the last batch.
    :param callbacks:
        A list of :class:`EvalCallback`.

    :returns: A dictionary mapping each tracked metric name to its final value.
    :raises ConfigurationError: If the data loader has no length and ``eval_steps``
        was not given.
    """
    set_seed_all(seed)
    check_dataset(dataset_dict, test_split)

    # Put the model in eval mode on the resolved device.
    device = resolve_device()
    model = model.eval().to(device)

    eval_dataloader: DataLoader = dataloader.construct(dataset=dataset_dict[test_split])

    # Work out how many batches to run. A data loader without a length raises
    # `TypeError` from `len()`, in which case `eval_steps` is mandatory.
    total_steps: int
    try:
        num_batches = len(eval_dataloader)
        total_steps = num_batches if eval_steps is None else min(num_batches, eval_steps)
    except TypeError:
        if eval_steps is None:
            raise ConfigurationError(
                "You must set 'eval_steps' for streaming/iterable datasets"
            )
        total_steps = eval_steps

    # Instantiate the lazy callbacks now that the model and data loader exist.
    eval_callbacks: List[EvalCallback] = [
        lazy_callback.construct(
            workspace=self.workspace,
            step_id=self.unique_id,
            work_dir=self.work_dir,
            model=model,
            dataset_dict=dataset_dict,
            dataloader=eval_dataloader,
        )
        for lazy_callback in (callbacks or [])
    ]
    for cb in eval_callbacks:
        cb.pre_eval_loop()

    numbered_batches = enumerate(islice(eval_dataloader, total_steps))

    # Per-metric running sums, and the values that are reported and returned.
    metric_totals: Dict[str, float] = defaultdict(float)
    final_metrics: Dict[str, float] = {}

    with Tqdm.tqdm(numbered_batches, desc="Evaluating", total=total_steps) as progress:
        for step, batch in progress:
            log_now = step % log_every == 0 or step == total_steps - 1

            # Callbacks see the batch before it is moved to the device.
            for cb in eval_callbacks:
                cb.pre_batch(step, batch)

            batch = move_to_device(batch, device)
            with torch.inference_mode():
                outputs = model(**batch)

            for cb in eval_callbacks:
                cb.post_batch(step, outputs)

            # Pull out the tracked metrics, turning tensors into Python numbers.
            batch_metrics = {}
            for metric in metric_names:
                if isinstance(outputs[metric], torch.Tensor):
                    batch_metrics[metric] = outputs[metric].item()
                else:
                    batch_metrics[metric] = outputs[metric]

            if auto_aggregate_metrics:
                # Running mean over the batches seen so far.
                for metric, value in batch_metrics.items():
                    metric_totals[metric] += value
                    final_metrics[metric] = metric_totals[metric] / (step + 1)
            else:
                final_metrics.update(batch_metrics)

            if log_now:
                progress.set_postfix(**final_metrics)

            # Drop references early to help the garbage collector.
            del batch, outputs, batch_metrics

    for cb in eval_callbacks:
        cb.post_eval_loop(final_metrics)

    return final_metrics
""" def __init__( self, workspace: Workspace, step_id: str, work_dir: Path, model: Model, dataset_dict: DatasetDictBase, dataloader: DataLoader, ) -> None: self.workspace = workspace self.step_id = step_id self.work_dir = work_dir self.model = model self.dataset_dict = dataset_dict self.dataloader = dataloader def pre_eval_loop(self) -> None: """ Called right before the first batch is processed. """ pass def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None: """ Called after the evaluation loop completes with the final aggregated metrics. This is the last method that is called, so any cleanup can be done in this method. """ pass def pre_batch(self, step: int, batch: Dict[str, Any]) -> None: """ Called directly before processing a batch. """ pass def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None: """ Called directly after processing a batch with the outputs of the batch. .. tip:: This method can be used to modify ``batch_outputs`` in place, which is useful in scenarios where you might need to aggregate metrics in a special way other than a simple average. If that's the case, make sure to set ``auto_aggregate_metrics`` to ``False`` in :class:`TorchEvalStep`. """ pass ================================================ FILE: tango/integrations/torch/exceptions.py ================================================ from tango.common.exceptions import TangoError class StopEarly(TangoError): """ Callbacks can raise this exception to stop training early without crashing. .. important:: During distributed training all workers must raise this exception at the same point in the training loop, otherwise there will be a deadlock. 
""" ================================================ FILE: tango/integrations/torch/format.py ================================================ from pathlib import Path from typing import Generic, TypeVar import dill import torch from tango.common.aliases import PathOrStr from tango.format import Format T = TypeVar("T") @Format.register("torch") class TorchFormat(Format[T], Generic[T]): """ This format writes the artifact using ``torch.save()``. Unlike :class:`tango.format.DillFormat`, this has no special support for iterators. .. tip:: Registered as a :class:`~tango.format.Format` under the name "torch". """ VERSION = "002" def write(self, artifact: T, dir: PathOrStr): filename = Path(dir) / "data.pt" with open(filename, "wb") as f: torch.save((self.VERSION, artifact), f, pickle_module=dill) def read(self, dir: PathOrStr) -> T: filename = Path(dir) / "data.pt" with open(filename, "rb") as f: version, artifact = torch.load(f, pickle_module=dill, map_location=torch.device("cpu")) if version > self.VERSION: raise ValueError( f"File {filename} is too recent for this version of {self.__class__}." ) return artifact ================================================ FILE: tango/integrations/torch/model.py ================================================ import torch from tango.common.registrable import Registrable class Model(torch.nn.Module, Registrable): """ This is a :class:`~tango.common.Registrable` mixin class that inherits from :class:`torch.nn.Module`. Its :meth:`~torch.nn.Module.forward()` method should return a :class:`dict` that includes the ``loss`` during training and any tracked metrics during validation. 
""" ================================================ FILE: tango/integrations/torch/optim.py ================================================ from typing import Type import torch from tango.common.registrable import Registrable class Optimizer(torch.optim.Optimizer, Registrable): """ A :class:`~tango.common.Registrable` version of a PyTorch :class:`torch.optim.Optimizer`. All `built-in PyTorch optimizers `_ are registered according to their class name (e.g. "torch::Adam"). .. tip:: You can see a list of all available optimizers by running .. testcode:: from tango.integrations.torch import Optimizer for name in sorted(Optimizer.list_available()): print(name) .. testoutput:: :options: +ELLIPSIS torch::ASGD torch::Adadelta torch::Adagrad torch::Adam torch::AdamW ... """ class LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable): """ A :class:`~tango.common.Registrable` version of a PyTorch learning rate scheduler. All `built-in PyTorch learning rate schedulers `_ are registered according to their class name (e.g. "torch::StepLR"). .. tip:: You can see a list of all available schedulers by running .. testcode:: from tango.integrations.torch import LRScheduler for name in sorted(LRScheduler.list_available()): print(name) .. testoutput:: :options: +ELLIPSIS torch::ChainedScheduler torch::ConstantLR torch::CosineAnnealingLR ... """ # Register all optimizers. for name, cls in torch.optim.__dict__.items(): if ( isinstance(cls, type) and issubclass(cls, torch.optim.Optimizer) and not cls == torch.optim.Optimizer ): Optimizer.register("torch::" + name)(cls) # Note: This is a hack. Remove after we upgrade the torch version. base_class: Type try: base_class = torch.optim.lr_scheduler.LRScheduler except AttributeError: base_class = torch.optim.lr_scheduler._LRScheduler # Register all learning rate schedulers. 
for name, cls in torch.optim.lr_scheduler.__dict__.items(): if isinstance(cls, type) and issubclass(cls, base_class) and not cls == base_class: LRScheduler.register("torch::" + name)(cls) ================================================ FILE: tango/integrations/torch/train.py ================================================ import logging import math import os import shutil from itertools import islice from typing import Any, Dict, List, Optional, Set, Union, cast import more_itertools import torch import torch.distributed as dist from more_itertools import chunked from torch.utils.data import DistributedSampler from tango.common.dataset_dict import DatasetDictBase from tango.common.exceptions import ConfigurationError from tango.common.lazy import Lazy from tango.common.tqdm import Tqdm from tango.common.util import get_extra_imported_modules, import_extra_module from tango.format import Format from tango.step import Step, StepResources from tango.workspace import Workspace from .data import DataLoader from .exceptions import StopEarly from .format import TorchFormat from .model import Model from .train_callback import TrainCallback from .train_config import TrainConfig from .training_engine import TrainingEngine from .util import check_dataloader, check_dataset, set_seed_all @Step.register("torch::train") class TorchTrainStep(Step): """ A PyTorch training loop step that supports gradient accumulation, distributed training, and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler. .. tip:: Registered as a :class:`~tango.step.Step` under the name "torch::train". .. important:: The training loop will use GPU(s) automatically when available, as long as at least ``device_count`` CUDA devices are available. Distributed data parallel training is activated when the ``device_count`` is greater than 1. You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. 
@property
def resources(self) -> StepResources:
    """
    The resources this step needs: ``self.step_resources`` when that is set,
    otherwise one GPU per requested device (the ``device_count`` argument).
    """
    if self.step_resources:
        return self.step_resources
    return StepResources(gpu_count=self.kwargs["device_count"])
the ``model``. :param model: The model to train. It should return a ``dict`` that includes the ``loss`` during training and the ``val_metric_name`` during validation. :param training_engine: The :class:`TrainingEngine` to use to train the model. :param dataset_dict: The train and optional validation data. :param train_dataloader: The data loader that generates training batches. The batches should be :class:`dict` objects that will be used as ``kwargs`` for the model's ``forward()`` method. :param train_split: The name of the data split used for training in the ``dataset_dict``. Default is "train". :param validation_split: Optional name of the validation split in the ``dataset_dict``. Default is ``None``, which means no validation. :param validation_dataloader: An optional data loader for generating validation batches. The batches should be :class:`dict` objects. If not specified, but ``validation_split`` is given, the validation ``DataLoader`` will be constructed from the same parameters as the train ``DataLoader``. :param seed: Used to set the RNG states at the beginning of training. :param train_steps: The number of steps to train for. If not specified, ``train_epochs`` must be given instead. :param train_epochs: The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs`` at the same time. :param validation_steps: The number of steps to validate for. If not specified validation will stop after a complete iteration through the ``validation_dataloader``. :param grad_accum: The number of gradient accumulation steps. Defaults to 1. .. note:: This parameter - in conjunction with the settings of your data loader and the number of distributed workers - determines the *effective batch size* of your training run. :param log_every: Log every this many steps. :param checkpoint_every: Save a checkpoint every this many steps. :param validate_every: Run the validation loop every this many steps.
:param device_count: The number of devices to train on, i.e. the number of distributed data parallel workers. :param distributed_port: The port of the distributed process group. Default = "54761". :param val_metric_name: The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is "loss". :param minimize_val_metric: Whether the validation metric is meant to be minimized (such as the loss). Default is ``True``. When using a metric such as accuracy, you should set this to ``False``. :param auto_aggregate_val_metric: If ``True`` (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which you should set this to ``False`` and handle the aggregation internally in your model or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). :param callbacks: A list of :class:`TrainCallback`. :param remove_stale_checkpoints: If ``True`` (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept. :returns: The trained model on CPU with the weights from the best checkpoint loaded. 
""" devices = self._get_devices(device_count) return self._train( model=model, training_engine=training_engine, dataset_dict=dataset_dict, train_dataloader=train_dataloader, train_split=train_split, validation_split=validation_split, validation_dataloader=validation_dataloader, seed=seed, train_steps=train_steps, train_epochs=train_epochs, validation_steps=validation_steps, grad_accum=grad_accum, log_every=log_every, checkpoint_every=checkpoint_every, validate_every=validate_every, devices=devices, distributed_port=distributed_port, val_metric_name=val_metric_name, minimize_val_metric=minimize_val_metric, auto_aggregate_val_metric=auto_aggregate_val_metric, callbacks=callbacks, remove_stale_checkpoints=remove_stale_checkpoints, ) def _get_devices(self, device_count: int) -> List[int]: """ Validates the device count, and returns the list of devices. """ # Validate device(s). if device_count <= 0: raise ConfigurationError("Invalid value for 'device_count'. Must be at least 1.") devices: List[int] if torch.cuda.is_available() and torch.cuda.device_count() >= device_count: devices = list(range(device_count)) self.logger.info("Training on %d GPU%s", device_count, "s" if device_count > 1 else "") else: devices = [-1] * device_count self.logger.info( "Training on CPU with %d worker%s", device_count, "s" if device_count > 1 else "" ) return devices def _train( self, model: Union[Model, Lazy[Model]], training_engine: Lazy[TrainingEngine], dataset_dict: DatasetDictBase, train_dataloader: Lazy[DataLoader], *, train_split: str = "train", validation_split: Optional[str] = None, validation_dataloader: Optional[Lazy[DataLoader]] = None, seed: int = 42, train_steps: Optional[int] = None, train_epochs: Optional[int] = None, validation_steps: Optional[int] = None, grad_accum: int = 1, log_every: int = 10, checkpoint_every: int = 100, validate_every: Optional[int] = None, devices: Optional[List[int]] = None, distributed_port: int = 54761, val_metric_name: str = "loss", 
minimize_val_metric: bool = True, auto_aggregate_val_metric: bool = True, callbacks: Optional[List[Lazy[TrainCallback]]] = None, remove_stale_checkpoints: bool = True, ) -> Model: is_distributed = False num_workers = 1 if devices and len(devices) > 1: is_distributed = True num_workers = len(devices) if validate_every is not None and validation_split is None: raise ConfigurationError( "You have set a validation interval, but no validation split. " "That's probably unintentional." ) if (train_steps is not None) == (train_epochs is not None): raise ConfigurationError( "One of 'train_steps' or 'train_epochs' needs to be specified, but not both." ) if validate_every is not None and checkpoint_every is not None: if checkpoint_every % validate_every != 0 and validate_every % checkpoint_every != 0: raise ConfigurationError( "'checkpoint_every' needs to be multiple of 'validate_every' or vice versa" ) config = TrainConfig( self.unique_id, self.work_dir, step_name=self.name, train_split=train_split, validation_split=validation_split, seed=seed, train_steps=train_steps, train_epochs=train_epochs, grad_accum=grad_accum, log_every=log_every, checkpoint_every=checkpoint_every, validate_every=validate_every, validation_steps=validation_steps, is_distributed=is_distributed, devices=devices, distributed_port=distributed_port, val_metric_name=val_metric_name, minimize_val_metric=minimize_val_metric, auto_aggregate_val_metric=auto_aggregate_val_metric, remove_stale_checkpoints=remove_stale_checkpoints, world_size=num_workers, ) final_model: Model if is_distributed: import torch.multiprocessing as mp mp.spawn( _train, args=( self.workspace, config, model, training_engine, dataset_dict, train_dataloader, validation_dataloader, callbacks, get_extra_imported_modules(), ), nprocs=num_workers, ) self.logger.info("Constructing final model") if isinstance(model, Lazy): final_model = model.construct() else: final_model = model else: final_model = _train( # type: ignore[assignment] 0, 
self.workspace, config, model, training_engine, dataset_dict, train_dataloader, validation_dataloader=validation_dataloader, callbacks=callbacks, ) assert final_model is not None final_model = final_model.cpu() # Load best checkpoint before returning model. if config.final_weights_path.is_file(): self.logger.info( f"Loading best weights from {str(config.final_weights_path.resolve())}" ) state = torch.load(config.final_weights_path, map_location="cpu") # We use `strict=False` because there might be missing keys due to weight tying. final_model.load_state_dict(state, strict=False) return final_model def _train( worker_id: int, workspace: Workspace, config: TrainConfig, model: Union[Model, Lazy[Model]], training_engine: Lazy[TrainingEngine], dataset_dict: DatasetDictBase, train_dataloader: Lazy[DataLoader], validation_dataloader: Optional[Lazy[DataLoader]] = None, callbacks: Optional[List[Lazy[TrainCallback]]] = None, include_package: Optional[Set[str]] = None, ) -> Optional[Model]: # Set random seeds. set_seed_all(config.seed) config.worker_id = worker_id if config.is_distributed and include_package: # During distributed training we need to import `include_package` modules again # in order to initialize the lazy objects. for package_name in include_package: import_extra_module(package_name) if config.is_distributed: import tango.common.logging as common_logging common_logging.initialize_worker_logging(config.worker_id) logger = logging.getLogger(TorchTrainStep.__name__) training_engine: TrainingEngine = training_engine.construct( train_config=config, model=model, ) # Check working directory to see if we should recover from a previous run. initial_state: Optional[Dict[str, Any]] = None if config.state_path.exists(): if config.is_local_main_process: logger.info(f"Recovering from previous run at {str(config.state_path.resolve())}") initial_state = training_engine.load_checkpoint(config.state_path) device = config.worker_local_default_device # Construct data loaders. 
validation_dataloader_: Optional[DataLoader] = None if config.validation_split is not None: validation_dataset = dataset_dict[config.validation_split] check_dataset(validation_dataset, config.validation_split) if validation_dataloader is not None: validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset) else: validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset) validation_dataloader: Optional[DataLoader] = validation_dataloader_ train_dataset = dataset_dict[config.train_split] check_dataset(train_dataset, config.train_split) train_dataloader: DataLoader = train_dataloader.construct(dataset=train_dataset) if config.train_steps is None: assert config.train_epochs is not None try: steps_per_epoch = len(train_dataloader) except TypeError: raise ConfigurationError("You must set 'train_steps' for streaming/iterable datasets") config.train_steps = math.ceil( steps_per_epoch * (config.train_epochs or 1) / config.grad_accum ) assert config.train_steps is not None # for mypy if validation_dataloader is not None: if config.validation_steps is None: try: config.validation_steps = len(validation_dataloader) except TypeError: raise ConfigurationError( "You must set 'validation_steps' for streaming/iterable datasets" ) # Make sure we're using a DistributedSampler during distributed training. if config.is_distributed: check_dataloader(train_dataloader) if validation_dataloader is not None: check_dataloader(validation_dataloader) # The (training) loss for each batch, updated every training batch. batch_loss: float = 0.0 # The value of the validation metric (could be loss), updated after every validation loop. val_metric: Optional[float] = None # The best validation metric over all validation set passes. best_val_metric: Optional[float] = None # The best validation metric over all validation set passes that correspond to a checkpoint. # Could be different from `best_val_metric` if `checkpoint_every` > `validate_every`. 
best_val_metric_checkpointed: Optional[float] = None # The step to start training from. start_step: int = 0 # The current training step. step: int = start_step # If we should do a validation pass after the current training batch. should_validate_this_step: bool = False # Load state from checkpoint. if initial_state is not None: val_metric = initial_state[f"val_{config.val_metric_name}"] best_val_metric = initial_state[f"best_{config.val_metric_name}"] best_val_metric_checkpointed = initial_state[f"best_{config.val_metric_name}_checkpointed"] start_step = initial_state["training_steps"] # Initialize callbacks. callbacks: List[TrainCallback] = [ callback.construct( workspace=workspace, train_config=config, training_engine=training_engine, dataset_dict=dataset_dict, train_dataloader=train_dataloader, validation_dataloader=validation_dataloader, ) for callback in (callbacks or []) ] if initial_state: for callback, state in zip(callbacks, initial_state["callbacks"]): callback.load_state_dict(state) del initial_state training_engine.model.train() training_batches = enumerate( islice( _cycle_through_epochs(train_dataloader, config.is_distributed, config.grad_accum), config.train_steps, ) ) def is_best_checkpoint() -> bool: """ A closure that we'll call when saving checkpoints to check if we should link the best checkpoint path to the current checkpoint file. """ if val_metric is not None: if best_val_metric_checkpointed is not None: return ( config.minimize_val_metric and val_metric <= best_val_metric_checkpointed ) or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed) else: return False else: # Without a validation loop we always treat the most recent checkpoint as the best. return True def save_state(step: int): """ A closure that we'll call every `checkpoint_every` steps in the train loop to save model and training state. """ # Update best loss/metric trackers. 
nonlocal best_val_metric_checkpointed if should_validate_this_step and val_metric is not None: if ( best_val_metric_checkpointed is None or (config.minimize_val_metric and val_metric <= best_val_metric_checkpointed) or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed) ): best_val_metric_checkpointed = val_metric train_state = { "training_steps": step + 1, f"val_{config.val_metric_name}": val_metric, f"best_{config.val_metric_name}": best_val_metric, f"best_{config.val_metric_name}_checkpointed": best_val_metric_checkpointed, "callbacks": [ callback.state_dict() for callback in callbacks # type: ignore[union-attr] ], } # For reason mypy can't figure out that `training_engine` is a `TrainingEngine` in this closure, # and not a `Lazy[TrainingEngine]`. cast(TrainingEngine, training_engine).save_checkpoint( config.state_path_for_step(step), train_state ) # Link to most recent state path. # NOTE: While hard linking would be preferable to creating symlinks, some train engines # require a whole directory to save their state instead of a single file, which # means state_path_for_step will be a directory, so a hard link won't work. if config.is_local_main_process: if config.state_path.is_symlink(): config.state_path.unlink() config.state_path.symlink_to( config.state_path_for_step(step).relative_to(config.work_dir) ) # Link to best state path. if is_best_checkpoint(): if config.best_state_path.is_symlink(): config.best_state_path.unlink() config.best_state_path.symlink_to( config.state_path_for_step(step).relative_to(config.work_dir) ) # Clean up stale checkpoints. 
if config.remove_stale_checkpoints: checkpoints_to_keep = { config.best_state_path.resolve(), config.state_path.resolve(), } for path in config.work_dir.glob("checkpoint_state_step*"): path = path.resolve() if path not in checkpoints_to_keep: if path.is_file(): path.unlink() else: shutil.rmtree(path) if config.is_distributed: dist.barrier() # Catch data loader up to where we left off before. current_epoch: int = -1 if start_step > 0: with Tqdm.tqdm( training_batches, desc=f"Catching dataloader up to step {start_step}", total=start_step - 1, disable=not config.is_local_main_process, ) as batch_iter: for step, (current_epoch, batch) in batch_iter: del batch if step >= start_step - 1: break if config.is_distributed: dist.barrier() for callback in callbacks: callback.pre_train_loop() train_batch_iterator_tqdm = Tqdm.tqdm( training_batches, desc="Training", initial=start_step, total=config.train_steps, disable=not config.is_local_main_process, ) train_batch_iterator = more_itertools.peekable(train_batch_iterator_tqdm) try: for step, (epoch, batch) in train_batch_iterator: if epoch != current_epoch: # Start of new epoch. if epoch > 0: # Call post-epoch callbacks for the last epoch. for callback in callbacks: callback.post_epoch(step, current_epoch) for callback in callbacks: callback.pre_epoch(step, epoch) current_epoch = epoch # Pre-batch callback. for callback in callbacks: callback.pre_batch(step, current_epoch, batch) batch_loss = 0.0 batch_outputs = [] for micro_batch_idx, micro_batch in enumerate(batch): # Get loss. micro_batch_loss, micro_batch_outputs = training_engine.forward_train( micro_batch, micro_batch_idx, len(batch) ) if torch.isnan(micro_batch_loss): raise ValueError("nan loss encountered") batch_loss += micro_batch_loss.detach().item() batch_outputs.append( { key: output.detach() if isinstance(output, torch.Tensor) else output for key, output in micro_batch_outputs.items() } ) # Calculate gradients. 
training_engine.backward(micro_batch_loss) # Clean up in case it saves memory. del micro_batch del micro_batch_loss del micro_batch_outputs # Post-batch callback. for callback in callbacks: callback.post_batch(step, current_epoch, batch_loss, batch_outputs) del batch training_engine.step() # Find out whether we should validate if config.validation_split is None: # If we can't validate, we don't. should_validate_this_step = False elif step == config.train_steps - 1: # If we're at the end of the training run, we always validate. should_validate_this_step = True elif config.validate_every is not None and (step + 1) % config.validate_every == 0: # If validate_every is given, we use that to decide. should_validate_this_step = True elif config.validate_every is None and epoch != train_batch_iterator.peek()[1][0]: # If validate_every is not given, we validate at the end of the epoch. should_validate_this_step = True else: # Otherwise, we don't validate. should_validate_this_step = False # Gather average loss across all workers. if ( config.should_log_this_step(step) or should_validate_this_step ) and config.is_distributed: batch_loss_tensor = torch.tensor(batch_loss, device=device) dist.all_reduce(batch_loss_tensor) batch_loss = batch_loss_tensor.item() / config.world_size if config.should_log_this_step(step): # Callbacks. for callback in callbacks: callback.log_batch(step, current_epoch, batch_loss, batch_outputs) # Update progress bar. metrics_to_log: Dict[str, float] = {"batch_loss": batch_loss} if val_metric is not None: metrics_to_log[f"val_{config.val_metric_name}"] = val_metric if best_val_metric is not None: metrics_to_log[f"best_val_{config.val_metric_name}"] = best_val_metric if config.is_local_main_process: train_batch_iterator_tqdm.set_postfix(**metrics_to_log) # Validate. if should_validate_this_step: assert validation_dataloader is not None assert config.validation_steps is not None # Prepare model for validation. 
training_engine.model.eval() running_metric = 0.0 with Tqdm.tqdm( islice(validation_dataloader, config.validation_steps), desc="Validating", total=config.validation_steps, leave=False, disable=not config.is_local_main_process, ) as val_batch_iterator: for val_step, val_batch in enumerate(val_batch_iterator): for callback in callbacks: callback.pre_val_batch(step, val_step, current_epoch, val_batch) # Get metric. outputs = training_engine.forward_eval(val_batch) for callback in callbacks: callback.post_val_batch(step, val_step, current_epoch, outputs) metric = outputs[config.val_metric_name] if config.auto_aggregate_val_metric: running_metric += metric if isinstance(metric, float) else metric.item() val_metric = running_metric / (val_step + 1) else: val_metric = metric if isinstance(metric, float) else metric.item() # Average metric across all workers. if ( config.is_distributed and config.should_log_this_val_step(val_step) and config.auto_aggregate_val_metric ): val_metric_tensor = torch.tensor(val_metric, device=device) dist.all_reduce(val_metric_tensor) val_metric = val_metric_tensor.item() / config.world_size # Update progress bar. if config.is_local_main_process and config.should_log_this_val_step( val_step ): val_batch_iterator.set_postfix(**{config.val_metric_name: val_metric}) # Clean up. del val_batch del outputs del metric assert val_metric is not None # Reset model to train mode. training_engine.model.train() if ( best_val_metric is None or (config.minimize_val_metric and val_metric <= best_val_metric) or (not config.minimize_val_metric and val_metric >= best_val_metric) ): best_val_metric = val_metric # Checkpoint. if config.should_checkpoint_this_step(step): save_state(step) # Post validation callback. for callback in callbacks: callback.post_val_loop(step, current_epoch, val_metric, best_val_metric) # Reset model to train mode again in case the callbacks messed with it. if callbacks: training_engine.model.train() # Update progress bar again. 
def _cycle_through_epochs(dataloader: DataLoader, is_distributed: bool, grad_accum: int):
    """
    Iterate over ``dataloader`` epoch after epoch, yielding ``(epoch, batch)`` pairs.

    Each ``batch`` is a list of up to ``grad_accum`` consecutive micro-batches taken
    from the dataloader (the last list of an epoch may be shorter). The generator never
    terminates on its own; the caller is expected to bound it, e.g. with ``islice``.

    :param dataloader: The dataloader to cycle through.
    :param is_distributed: Whether this is a distributed training job.
    :param grad_accum: The number of micro-batches to group into each yielded batch.

    :raises ValueError: If a full pass over ``dataloader`` produces no batches at all.
        Without this check an empty dataloader would make this generator spin forever
        without ever yielding.
    """
    epoch = 0
    while True:
        # Tell the distributed sampler which epoch we are starting.
        if is_distributed and isinstance(dataloader.sampler, DistributedSampler):
            dataloader.sampler.set_epoch(epoch)

        yielded_any = False
        for batch in chunked(dataloader, grad_accum):
            yielded_any = True
            yield epoch, batch

        if not yielded_any:
            raise ValueError(
                f"The dataloader did not produce any batches during epoch {epoch}, "
                "so training cannot make progress."
            )
        epoch += 1
def __init__(
    self,
    workspace: Workspace,
    train_config: TrainConfig,
    training_engine: TrainingEngine,
    dataset_dict: DatasetDictBase,
    train_dataloader: DataLoader,
    validation_dataloader: Optional[DataLoader] = None,
) -> None:
    """
    Store the training-loop objects on the callback and create its logger.

    Every argument is kept as an attribute of the same name. The logger is named
    after the concrete callback class.
    """
    # Logger first, so it is named after the concrete subclass.
    self.logger = logging.getLogger(type(self).__name__)

    # Run-level context.
    self.workspace = workspace
    self.train_config = train_config
    self.training_engine = training_engine

    # Data.
    self.dataset_dict = dataset_dict
    self.train_dataloader = train_dataloader
    self.validation_dataloader = validation_dataloader
def state_dict(self) -> Dict[str, Any]:
    """
    Return any state that needs to be kept after a restart.

    Some callbacks need to maintain state across restarts. This is the callback's opportunity to
    save its state. It will be restored using :meth:`load_state_dict`.

    The base implementation has no state to keep and returns an empty dict.
    """
    return {}
def post_batch(
    self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
) -> None:
    """
    Called directly after processing a batch, but before unscaling gradients,
    clipping gradients, and taking an optimizer step.

    .. note::
        The ``batch_loss`` here is the loss local to the current worker, not the
        overall (average) batch loss across distributed workers.

        If you need the average loss, use :meth:`log_batch()`.

    .. note::
        The type of ``batch_outputs`` is a list because with gradient accumulation there will
        be more than one "micro batch" in the batch.

    The base implementation does nothing.
    """
    pass
@TrainCallback.register("torch::stop_early")
class StopEarlyCallback(TrainCallback):
    """
    A :class:`TrainCallback` for early stopping. Training is stopped early after
    ``patience`` steps without an improvement to the validation metric.

    Whether a lower or a higher metric value counts as an improvement is decided by
    ``train_config.minimize_val_metric``.

    .. tip::

        Registered as a :class:`TrainCallback` under the name "torch::stop_early".

    :param patience: The number of training steps to wait for an improvement of the
        validation metric before raising :class:`StopEarly`.
    """

    def __init__(self, *args, patience: int = 10000, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.patience = patience
        # The training step at which the best validation metric was observed.
        self.best_step = 0
        # The best validation metric observed so far (``None`` until the first validation pass).
        self.best_val_metric: Optional[float] = None

    def _is_improvement(self, val_metric: float) -> bool:
        """
        Return ``True`` if ``val_metric`` is strictly better than the best value seen so far.
        """
        if self.best_val_metric is None:
            return True
        if self.train_config.minimize_val_metric:
            return val_metric < self.best_val_metric
        return val_metric > self.best_val_metric

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """
        Track the best validation metric and stop training when it has not improved
        for more than ``patience`` steps.

        :raises StopEarly: If ``step`` is more than ``patience`` steps past the step of the
            last improvement.
        """
        # We can't rely on the best_val_metric parameter, because then we can't detect when the metric stays
        # the same for many steps. The comparison is strict for the same reason.
        if self._is_improvement(val_metric):
            self.best_step = step
            self.best_val_metric = val_metric
        elif step > self.best_step + self.patience:
            raise StopEarly

    def state_dict(self) -> Dict[str, Any]:
        """
        Return any state that needs to be kept after a restart.
        """
        return {
            "patience": self.patience,
            "best_step": self.best_step,
            "best_val_metric": self.best_val_metric,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Load the state on a restart.
        """
        self.patience = state_dict["patience"]
        self.best_step = state_dict["best_step"]
        self.best_val_metric = state_dict["best_val_metric"]
""" step_id: str """ The unique ID of the current step. """ work_dir: Path """ The working directory for the training run. """ step_name: Optional[str] = None """ The name of the current step. .. note:: The same step can be run under different names. """ worker_id: int = 0 """ The ID of the distributed worker. """ train_split: str = "train" """ The name of the training split. """ validation_split: Optional[str] = None """ The name of the validation split. """ seed: int = 42 """ The random seed. """ train_steps: Optional[int] = None """ The number of steps to train for. """ train_epochs: Optional[int] = None """ The number of epochs to train for. You cannot specify `train_steps` and `train_epochs` at the same time. """ validation_steps: Optional[int] = None """ The number of validation steps. The default is to validate on the entire validation set. """ grad_accum: int = 1 """ The number of micro-batches per gradient accumulation mini-batch. """ log_every: int = 10 """ Controls the frequency of log updates, in number of optimizer steps """ checkpoint_every: int = 100 """ Controls the frequency of checkpoints, in number of optimizer steps """ validate_every: Optional[int] = None """ Controls the frequency of the validation loop, in number of optimizer steps """ is_distributed: bool = False """ Whether or not the training job is distributed. """ devices: Optional[List[int]] = None """ The devices used (for distributed jobs). """ distributed_address: str = "127.0.0.1" """ The IP address of the main distributed process. """ distributed_port: int = 54761 """ The port of the main distributed process. """ val_metric_name: str = "loss" """ The name of the validation metric to track. """ minimize_val_metric: bool = True """ Should be ``True`` when the validation metric being tracked should be minimized. """ auto_aggregate_val_metric: bool = True """ Controls automatic aggregation of validation metric. 
""" remove_stale_checkpoints: bool = True """ Controls removal of stale checkpoints. """ world_size: int = 1 """ The number of distributed workers. """ _worker_local_default_device: Optional[torch.device] = None _device_type: Optional[str] = None # either "cuda" or "cpu" @property def worker_local_default_device(self) -> torch.device: """ The default ``torch`` device for the current worker. """ if self._worker_local_default_device is not None: return self._worker_local_default_device else: if self.devices: device_id = self.devices[self.worker_id] if device_id >= 0: device = torch.device(f"cuda:{device_id}") else: device = torch.device("cpu") elif torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") self._worker_local_default_device = device return device @property def device_type(self) -> str: if self._device_type is None: device_type = ( "cpu" if self.worker_local_default_device == torch.device("cpu") else "cuda" ) self._device_type = device_type return device_type else: return self._device_type @property def is_local_main_process(self) -> bool: """ Whether the local process is the main distributed worker. """ return self.worker_id == 0 @property def state_path(self) -> Path: """ The path to the latest state checkpoint file. """ return self.work_dir / "checkpoint_state_latest" @property def best_state_path(self) -> Path: """ The path to the best state checkpoint file according to the validation metric or training loss (if no validation split is given). 
""" return self.work_dir / "checkpoint_state_best" def state_path_for_step(self, step: int) -> Path: return self.work_dir / f"checkpoint_state_step{step + 1}" @property def final_weights_path(self) -> Path: return self.work_dir / "weights.pt" def should_log_this_step(self, step: int) -> bool: assert self.train_steps is not None return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1 def should_checkpoint_this_step(self, step: int) -> bool: assert self.train_steps is not None return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1 def should_log_this_val_step(self, val_step: int) -> bool: assert self.validation_steps is not None return val_step % self.log_every == 0 or val_step == self.validation_steps - 1 def as_dict(self) -> Dict[str, Any]: return {k: v for k, v in asdict(self).items() if not k.startswith("_")} ================================================ FILE: tango/integrations/torch/training_engine.py ================================================ import os import tempfile from abc import abstractmethod from pathlib import Path from typing import Any, Dict, Optional, Tuple, Union, cast import torch import torch.distributed as dist import torch.nn as nn from tango.common import Lazy, Registrable, Tqdm from .model import Model from .optim import LRScheduler, Optimizer from .train_config import TrainConfig from .util import move_to_device class TrainingEngine(Registrable): """ A :class:`TrainingEngine` defines and drives the strategy for training a model in :class:`TorchTrainStep`. :ivar TrainConfig train_config: The training config. :ivar Model model: The model being trained. :ivar Optimizer optimizer: The optimizer being used to train the model. :ivar LRScheduler lr_scheduler: The optional learning rate scheduler. """ default_implementation = "torch" """ The default implementation is :class:`TorchTrainingEngine`. 
def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
    """
    Construct the model if it was given as a :class:`Lazy` object and move it to
    this worker's default device.

    :returns: The result of calling ``.to()`` on the (constructed) model.
    """
    built = model.construct() if isinstance(model, Lazy) else model
    target_device = self.train_config.worker_local_default_device
    return built.to(target_device)
""" raise NotImplementedError @abstractmethod def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]: """ Load a checkpoint to resume training. Should return the same ``client_state`` saved in :meth:`save_checkpoint()`. """ raise NotImplementedError @abstractmethod def save_complete_weights_from_checkpoint( self, checkpoint_dir: Path, weights_path: Path ) -> None: """ Gather the final weights from the best checkpoint and save to the file at ``weights_path``. """ raise NotImplementedError @TrainingEngine.register("torch") class TorchTrainingEngine(TrainingEngine): """ This train engine only uses native PyTorch functionality to provide vanilla distributed data parallel training and AMP. .. tip:: Registered as a :class:`TrainingEngine` under the name "torch". .. important:: Only the parameters listed below should be defined in a configuration file. The other parameters will be automatically passed to the constructor within :class:`TorchTrainStep`. :param amp: Use automatic mixed precision. Default is ``False``. :param max_grad_norm: If set, gradients will be clipped to have this max norm. Default is ``None``. :param amp_use_bfloat16: Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training. Only applicable when ``amp=True``. If not specified, the default behavior will be to use ``bfloat16`` when training with AMP on CPU, otherwise not. 
def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:
    """
    Construct the model if it was given as a :class:`Lazy` object, move it to this
    worker's default device, and wrap it in ``DistributedDataParallel`` when the
    training job is distributed.
    """
    module = model.construct() if isinstance(model, Lazy) else model

    # The return value of `.to()` is deliberately not used here; the same object is
    # wrapped / returned below.
    module.to(self.train_config.worker_local_default_device)

    if not self.train_config.is_distributed:
        return module

    # Distributed training: hand back the DDP wrapper instead of the bare model.
    return cast(Model, nn.parallel.DistributedDataParallel(module))
def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run a forward evaluation pass on the model.

    The batch is moved to this engine's device and the model is called with the batch
    entries as keyword arguments, under autocast (enabled only when AMP is on) and
    inference mode.

    :returns: Whatever the model returns for the batch.
    """
    # Move tensors to right device.
    device_batch = move_to_device(batch, self.device)

    autocast_context = torch.autocast(
        self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype
    )
    with autocast_context, torch.inference_mode():
        return self.model(**device_batch)
def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:
    """
    Save a training checkpoint to ``checkpoint_dir``.

    One file per piece of state is written, named ``worker{worker_id}_{name}.pt``, for
    the model, the optimizer, the LR scheduler (if there is one), the gradient scaler
    (if there is one), and finally ``client_state`` (under the name ``trainer``).

    Each file is first written to a temporary file inside ``checkpoint_dir`` and then
    moved into place with ``os.replace``, so a partially written file never appears
    under the final name.

    :param checkpoint_dir: The directory to write to. It is created if it doesn't exist.
    :param client_state: Arbitrary training state to save alongside the engine's own state.
    """
    checkpoint_dir.mkdir(exist_ok=True)

    def save_state(state: Dict[str, Any], name: str):
        temp_state_file = tempfile.NamedTemporaryFile(
            "w+b", dir=checkpoint_dir, delete=False, suffix=".pt"
        )
        try:
            with Tqdm.wrapattr(
                temp_state_file,
                "write",
                desc=f"Saving {name} state",
                leave=False,
                disable=not self.train_config.is_local_main_process,
            ) as f:
                torch.save(state, f)
            # The file has to be closed (and thereby flushed) before it is moved into place.
            temp_state_file.close()
            os.replace(
                temp_state_file.name,
                checkpoint_dir / f"worker{self.train_config.worker_id}_{name}.pt",
            )
        finally:
            # Make sure the handle is released even when saving failed part-way through.
            # Closing an already closed file is a no-op.
            temp_state_file.close()
            # After a successful `os.replace` the temporary name no longer exists, so this
            # only removes leftovers from a failed save.
            if os.path.exists(temp_state_file.name):
                os.remove(temp_state_file.name)

    save_state(self.get_model_state(), "model")
    save_state(self.optimizer.state_dict(), "optimizer")
    if self.lr_scheduler is not None:
        save_state(self.lr_scheduler.state_dict(), "lr_scheduler")
    if self.grad_scaler is not None:
        save_state(self.grad_scaler.state_dict(), "grad_scaler")
    save_state(client_state, "trainer")
f"worker{self.train_config.worker_id}_grad_scaler.pt") ) return torch.load(checkpoint_dir / f"worker{self.train_config.worker_id}_trainer.pt") def save_complete_weights_from_checkpoint( self, checkpoint_dir: Path, weights_path: Path ) -> None: os.link(checkpoint_dir.resolve() / "worker0_model.pt", weights_path) ================================================ FILE: tango/integrations/torch/util.py ================================================ import random import warnings from collections import UserDict from typing import Dict, Optional, TypeVar, Union import torch import torch.distributed as dist from torch.utils.data import DistributedSampler, IterableDataset from .data import DataLoader T = TypeVar("T") def move_to_device(o: T, device: torch.device) -> T: if isinstance(o, torch.Tensor): return o.to(device) # type: ignore[return-value] elif isinstance(o, dict) or isinstance(o, UserDict): return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value] elif isinstance(o, list): return [move_to_device(x, device) for x in o] # type: ignore[return-value] elif isinstance(o, tuple): return tuple((move_to_device(x, device) for x in o)) # type: ignore[return-value] else: return o def check_dataset(dataset, split: str): try: len(dataset) except TypeError: if not isinstance(dataset, IterableDataset): warnings.warn( f"Dataset for {split} split appears to be a streaming/iterable dataset, " "but is not an instance of 'torch.utils.data.IterableDataset'. This could cause issues " "within the DataLoader.", UserWarning, ) def check_dataloader(dataloader: DataLoader): # If using a regular dataset and not streaming/iterable dataset, we # should probably be using a `DistributedSampler`. 
def peak_gpu_memory(reset: bool = False) -> Dict[int, int]:
    """
    Get the peak GPU memory usage in MiB by distributed worker rank.

    :param reset: If ``True``, reset the peak memory statistics after they
        have been read.

    :returns:
        Keys are rank ids as integers (from 0 up to world size - 1).
        Values are memory usage as integers in MiB.
        Returns an empty `dict` if GPUs are not available.
    """
    if not torch.cuda.is_available():
        return {}

    bytes_per_mib = 1024 * 1024
    device = torch.device("cuda")

    results_dict: Dict[int, int] = {}
    if dist.is_available() and dist.is_initialized():
        # If the backend is not 'nccl', we're training on CPU.
        if dist.get_backend() != "nccl":
            return {}

        global_rank = dist.get_rank()
        world_size = dist.get_world_size()
        peak_mb = torch.cuda.max_memory_allocated(device) // bytes_per_mib
        peak_mb_tensor = torch.tensor([global_rank, peak_mb], device=device)
        # All of these tensors will be gathered into this list.
        gather_results = [torch.tensor([0, 0], device=device) for _ in range(world_size)]

        dist.all_gather(gather_results, peak_mb_tensor)

        for gathered in gather_results:
            results_dict[int(gathered[0])] = int(gathered[1])
    else:
        # `max_memory_allocated` reports bytes; convert to MiB so the
        # single-process result uses the same unit as the distributed one.
        results_dict = {0: torch.cuda.max_memory_allocated(device) // bytes_per_mib}

    if reset:
        # Reset peak stats. `reset_max_memory_allocated` is deprecated and
        # simply forwards to `reset_peak_memory_stats`.
        torch.cuda.reset_peak_memory_stats(device)

    return results_dict
testcode:: from tango.integrations.torch import Model from tango.integrations.transformers import * available_models = [] for name in sorted(Model.list_available()): if name.startswith("transformers::AutoModel"): available_models.append(name) - :class:`~tango.integrations.torch.Optimizer`: All optimizers from transformers are registered according to their class names (e.g. "transformers::AdaFactor"). .. tip:: You can see a list of all of the available optimizers from transformers by running: .. testcode:: from tango.integrations.torch import Optimizer from tango.integrations.transformers import * for name in sorted(Optimizer.list_available()): if name.startswith("transformers::"): print(name) .. testoutput:: transformers::Adafactor transformers::AdamW transformers::LayerWiseDummyOptimizer - :class:`~tango.integrations.torch.LRScheduler`: All learning rate scheduler function from transformers are registered according to their type name (e.g. "transformers::linear"). .. tip:: You can see a list of all of the available scheduler functions from transformers by running: .. testcode:: from tango.integrations.torch import LRScheduler from tango.integrations.transformers import * for name in sorted(LRScheduler.list_available()): if name.startswith("transformers::"): print(name) .. testoutput:: transformers::constant transformers::constant_with_warmup transformers::cosine transformers::cosine_with_min_lr transformers::cosine_with_restarts transformers::inverse_sqrt transformers::linear transformers::polynomial transformers::reduce_lr_on_plateau - :class:`~tango.integrations.torch.DataCollator`: All data collators from transformers are registered according to their class name (e.g. "transformers::DefaultDataCollator"). You can instantiate any of these from a config / params like so: .. 
testcode:: from tango.integrations.torch import DataCollator collator = DataCollator.from_params({ "type": "transformers::DataCollatorWithPadding", "tokenizer": { "pretrained_model_name_or_path": "epwalsh/bert-xsmall-dummy", }, }) .. tip:: You can see a list of all of the available data collators from transformers by running: .. testcode:: from tango.integrations.torch import DataCollator from tango.integrations.transformers import * for name in sorted(DataCollator.list_available()): if name.startswith("transformers::"): print(name) .. testoutput:: transformers::DataCollatorForLanguageModeling transformers::DataCollatorForPermutationLanguageModeling transformers::DataCollatorForSOP transformers::DataCollatorForSeq2Seq transformers::DataCollatorForTokenClassification transformers::DataCollatorForWholeWordMask transformers::DataCollatorWithPadding transformers::DefaultDataCollator """ from tango.common.exceptions import IntegrationMissingError try: import transformers except ModuleNotFoundError: raise IntegrationMissingError("transformers") __all__ = [ "RunGeneration", "RunGenerationDataset", "Tokenizer", "Config", "add_soft_prompt", "FinetuneWrapper", "FinetuneStep", "TokenizeText2TextData", ] from .config import Config from .data import * # noqa: F403 from .finetune import FinetuneStep, FinetuneWrapper, TokenizeText2TextData from .model import * # noqa: F403 from .optim import * # noqa: F403 from .run_generation import RunGeneration, RunGenerationDataset from .soft_prompt import add_soft_prompt from .tokenizer import Tokenizer ================================================ FILE: tango/integrations/transformers/config.py ================================================ from transformers import AutoConfig, PretrainedConfig from tango.common import Registrable class Config(PretrainedConfig, Registrable): """ A :class:`~tango.common.Registrable` version of transformers' :class:`~transformers.PretrainedConfig`. 
""" default_implementation = "auto" """ The default registered implementation just calls :meth:`transformers.AutoConfig.from_pretrained()`. """ Config.register("auto", constructor="from_pretrained")(AutoConfig) ================================================ FILE: tango/integrations/transformers/data.py ================================================ from dataclasses import fields, is_dataclass from typing import Callable from transformers.data import data_collator as transformers_data_collator from tango.integrations.torch.data import DataCollator from .tokenizer import Tokenizer # Some data collators take a tokenizer, so in order to instantiate those collators from params, # we need to use a factory function that takes our registrable version of a tokenizer as # an argument. def data_collator_with_tokenizer_factory(cls) -> Callable[..., DataCollator]: def factory(tokenizer: Tokenizer, **kwargs) -> DataCollator: return cls(tokenizer=tokenizer, **kwargs) return factory for name, cls in transformers_data_collator.__dict__.items(): if ( isinstance(cls, type) and is_dataclass(cls) and "DataCollator" in name and hasattr(cls, "__call__") ): for field in fields(cls): if field.name == "tokenizer": factory_func = data_collator_with_tokenizer_factory(cls) DataCollator.register("transformers::" + name)(factory_func) # type: ignore break else: DataCollator.register("transformers::" + name)(cls) ================================================ FILE: tango/integrations/transformers/finetune.py ================================================ import logging from os import PathLike from typing import List, Optional, Union, cast import datasets as ds from transformers import ( AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, PreTrainedModel, ) from tango.common import Lazy, Params from tango.format import Format from tango.integrations.datasets import DatasetsFormat, convert_to_tango_dataset_dict from tango.integrations.torch import ( 
def tokenize_data(
    data: ds.DatasetDict,
    tokenizer: Tokenizer,
    num_workers: int = 1,
    source_field: str = "source",
    target_field: str = "target",
    max_source_length: Optional[int] = 1024,
    max_target_length: Optional[int] = 1024,
    pad_to_max_length: bool = False,
    ignore_pad_token_for_loss: bool = True,
    concat_source_target: bool = False,
) -> ds.DatasetDict:
    """
    Returns a `DatasetDict` with tokenized source and target fields.

    Missing special tokens (pad, sep, eos) are added to ``tokenizer`` in place
    before tokenizing. All original columns are dropped from the result, which
    contains the tokenizer outputs plus a ``labels`` column.

    :param data:
        The original dataset dict containing the source and target fields.
    :param tokenizer:
        The tokenizer to use.
    :param num_workers:
        The number of workers to use for processing the data.
    :param source_field:
        The string name of the field containing the source sequence.
    :param target_field:
        The string name of the field containing the target sequence.
    :param max_source_length:
        The maximum number of tokens in the source sequence.
    :param max_target_length:
        The maximum number of tokens in the target sequence.
    :param pad_to_max_length:
        Whether to pad to the maximum length when tokenizing.
    :param ignore_pad_token_for_loss:
        Whether to ignore the padded tokens for calculating loss.
        If set to True, all the pad tokens in the labels are replaced
        by -100, which is ignored by the loss function.
    :param concat_source_target:
        If the downstream model is decoder-only, like "gpt2", the source
        and target sequences need to be concatenated and fed to the model
        together.
    """
    padding = "max_length" if pad_to_max_length else False

    _add_special_tokens(tokenizer)

    def preprocess_function(examples):
        # remove pairs where at least one record is None
        inputs, targets = [], []
        for source, target in zip(examples[source_field], examples[target_field]):
            if source is None or target is None:
                continue
            if concat_source_target:
                # Decoder-only layout: "<source><sep><target><eos>" serves as
                # both the model input and the basis for the labels.
                text = source + tokenizer.sep_token + target + tokenizer.eos_token
                inputs.append(text)
                targets.append(text)
            else:
                inputs.append(source)
                targets.append(target)

        model_inputs = tokenizer(
            inputs, max_length=max_source_length, padding=padding, truncation=True
        )

        if not concat_source_target:
            # Setup the tokenizer for targets
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets, max_length=max_target_length, padding=padding, truncation=True
                )
        else:
            labels = {"input_ids": []}
            for input_ids in model_inputs["input_ids"]:
                # Only tokens after the separator contribute to the loss; the
                # source part and the separator itself are masked with -100.
                # NOTE: `index` raises ValueError if truncation removed the
                # separator from the sequence.
                label_start_idx = input_ids.index(tokenizer.sep_token_id)
                label_ids = [-100] * len(input_ids)
                label_ids[label_start_idx + 1 :] = input_ids[label_start_idx + 1 :]
                labels["input_ids"].append(label_ids)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
        # when we want to ignore padding in the loss.
        if padding == "max_length" and ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(lb if lb != tokenizer.pad_token_id else -100) for lb in label]
                for label in labels["input_ids"]
            ]

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    data = data.map(
        preprocess_function,
        batched=True,
        num_proc=num_workers,
        remove_columns=list(data.column_names.values())[0],  # remove all old columns
        desc="Tokenizing dataset",
    )

    return data
tip:: Registered as a :class:`~tango.step.Step` under the name "transformers::tokenize_text2text". """ DETERMINISTIC = True CACHEABLE = True FORMAT = DatasetsFormat() def run( # type: ignore[override] self, data: ds.DatasetDict, tokenizer: Tokenizer, num_workers: int = 1, source_field: str = "source", target_field: str = "target", max_source_length: Optional[int] = 1024, max_target_length: Optional[int] = 1024, pad_to_max_length: bool = False, ignore_pad_token_for_loss: bool = True, concat_source_target: bool = False, ) -> ds.DatasetDict: """ Returns a `DatasetDict` with tokenized source and target fields. :param data: The original dataset dict containing the source and target fields. :param tokenizer: The tokenizer to use. :param num_workers: The number of workers to use for processing the data. :param source_field: The string name of the field containing the source sequence. :param target_field: The string name of the field containing the target sequence. :param max_source_length: The maximum number of tokens in the source sequence. :param max_target_length: The maximum number of tokens in the target sequence. :param pad_to_max_length: Whether to pad to the maximum length when tokenizing. :param ignore_pad_token_for_loss: Whether to ignore the padded tokens for calculating loss. If set to True, all the pad tokens in the labels are replaced by -100, which is ignored by the loss function. :param concat_source_target: If the downstream model is decoder-only, like "gpt2", the source and target sequences need to be concatenated and fed to the model together. .. tip:: If concat_source_target is set to True, we pad all sequences to max length here. Otherwise, we leave it to the appropriate :class:`~tango.integrations.torch.DataCollator` object. 
""" return tokenize_data( data, tokenizer=tokenizer, num_workers=num_workers, source_field=source_field, target_field=target_field, max_source_length=max_source_length, max_target_length=max_target_length, pad_to_max_length=pad_to_max_length, ignore_pad_token_for_loss=ignore_pad_token_for_loss, concat_source_target=concat_source_target, ) @Step.register("transformers::finetune") class FinetuneStep(TorchTrainStep): """ Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional preprocessing for data. .. tip:: Registered as a :class:`~tango.step.Step` under the name "transformers::finetune". .. important:: The training loop will use GPU(s) automatically when available, as long as at least ``device_count`` CUDA devices are available. Distributed data parallel training is activated when the ``device_count`` is greater than 1. You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1`` (and ``device_count`` to 2). .. warning:: During validation, the validation metric (specified by the ``val_metric_name`` parameter) is aggregated by simply averaging across validation batches and distributed processes. This behavior is usually correct when your validation metric is "loss" or "accuracy", for example, but may not be correct for other metrics like "F1". If this is not correct for your metric you will need to handle the aggregation internally in your model or with a :class:`TrainCallback` using the :meth:`TrainCallback.post_val_batch()` method. Then set the parameter ``auto_aggregate_val_metric`` to ``False``. Note that correctly aggregating your metric during distributed training will involve distributed communication. 
""" DETERMINISTIC = True CACHEABLE = True FORMAT: Format = TorchFormat() SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"} def run( # type: ignore[override] self, model: Lazy[Model], tokenizer: Tokenizer, training_engine: Lazy[TrainingEngine], dataset_dict: ds.DatasetDict, train_dataloader: Lazy[DataLoader], *, train_split: str = "train", validation_split: Optional[str] = None, validation_dataloader: Optional[Lazy[DataLoader]] = None, source_field: str = "source", target_field: str = "target", max_source_length: Optional[int] = 1024, max_target_length: Optional[int] = 1024, seed: int = 42, train_steps: Optional[int] = None, train_epochs: Optional[int] = None, validation_steps: Optional[int] = None, grad_accum: int = 1, log_every: int = 10, checkpoint_every: int = 100, validate_every: Optional[int] = None, device_count: int = 1, distributed_port: int = 54761, val_metric_name: str = "loss", minimize_val_metric: bool = True, auto_aggregate_val_metric: bool = True, callbacks: Optional[List[Lazy[TrainCallback]]] = None, remove_stale_checkpoints: bool = True, ) -> Model: """ Run a basic training loop to train the ``model``. :param model: The model to train. It should return a ``dict`` that includes the ``loss`` during training and the ``val_metric_name`` during validation. :param tokenizer: The tokenizer to use for tokenizing source and target sequences. :param training_engine: The :class:`TrainingEngine` to use to train the model. :param dataset_dict: The train and optional validation data. :param train_dataloader: The data loader that generates training batches. The batches should be :class:`dict` objects that will be used as ``kwargs`` for the model's ``forward()`` method. :param train_split: The name of the data split used for training in the ``dataset_dict``. Default is "train". :param validation_split: Optional name of the validation split in the ``dataset_dict``. Default is ``None``, which means no validation. 
:param validation_dataloader: An optional data loader for generating validation batches. The batches should be :class:`dict` objects. If not specified, but ``validation_split`` is given, the validation ``DataLoader`` will be constructed from the same parameters as the train ``DataLoader``. :param source_field: The string name of the field containing the source sequence. :param target_field: The string name of the field containing the target sequence. :param max_source_length: The maximum number of tokens in the source sequence. :param max_target_length: The maximum number of tokens in the target sequence. :param seed: Used to set the RNG states at the beginning of training. :param train_steps: The number of steps to train for. If not specified training will stop after a complete iteration through the ``train_dataloader``. :param train_epochs: The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs`` at the same time. :param validation_steps: The number of steps to validate for. If not specified validation will stop after a complete iteration through the ``validation_dataloader``. :param grad_accum: The number of gradient accumulation steps. Defaults to 1. .. note:: This parameter - in conjuction with the settings of your data loader and the number distributed workers - determines the *effective batch size* of your training run. :param log_every: Log every this many steps. :param checkpoint_every: Save a checkpoint every this many steps. :param validate_every: Run the validation loop every this many steps. :param device_count: The number of devices to train on, i.e. the number of distributed data parallel workers. :param distributed_port: The port of the distributed process group. Default = "54761". :param val_metric_name: The name of the validation metric, i.e. the key of the metric in the dictionary returned by the forward pass of the model. Default is "loss". 
:param minimize_val_metric: Whether the validation metric is meant to be minimized (such as the loss). Default is ``True``. When using a metric such as accuracy, you should set this to ``False``. :param auto_aggregate_val_metric: If ``True`` (the default), the validation metric will be averaged across validation batches and distributed processes. This may not be the correct behavior for some metrics (such as F1), in which you should set this to ``False`` and handle the aggregation internally in your model or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). :param callbacks: A list of :class:`TrainCallback`. :param remove_stale_checkpoints: If ``True`` (the default), stale checkpoints will be removed throughout training so that only the latest and best checkpoints are kept. :returns: The trained model on CPU with the weights from the best checkpoint loaded. """ devices = self._get_devices(device_count) is_distributed = False if devices and len(devices) > 1: is_distributed = True # Setup the tokenizer _add_special_tokens(tokenizer) # Hacky way to deal with resizing the model embeddings. model_params_dict = model._params.as_dict() if "fairscale" in model_params_dict["type"]: model_params_dict["model"]["num_tokens"] = len(tokenizer) # type: ignore else: model_params_dict["num_tokens"] = len(tokenizer) # type: ignore model = Lazy( model._constructor, Params(model_params_dict), constructor_extras=model._constructor_extras, ) # Get the config to check in order to check if the model is seq2seq or causal. 
config = AutoConfig.from_pretrained(tokenizer.name_or_path) seq2seq: bool = type(config) in SEQ2SEQ dataset_dict = tokenize_data( dataset_dict, tokenizer=tokenizer, source_field=source_field, target_field=target_field, max_source_length=max_source_length, max_target_length=max_target_length, concat_source_target=not seq2seq, ) if is_distributed: from torch.utils.data.distributed import DistributedSampler sampler = Lazy(DistributedSampler, drop_last=True, shuffle=True) train_dataloader = Lazy( train_dataloader._constructor, train_dataloader._params, constructor_extras=train_dataloader._constructor_extras, sampler=sampler, ) collate_fn: DataCollator collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer)) train_dataloader = Lazy( train_dataloader._constructor, train_dataloader._params, constructor_extras=train_dataloader._constructor_extras, collate_fn=collate_fn, ) return self._train( model=model, training_engine=training_engine, dataset_dict=convert_to_tango_dataset_dict(dataset_dict), train_dataloader=train_dataloader, train_split=train_split, validation_split=validation_split, validation_dataloader=validation_dataloader, seed=seed, train_steps=train_steps, train_epochs=train_epochs, validation_steps=validation_steps, grad_accum=grad_accum, log_every=log_every, checkpoint_every=checkpoint_every, validate_every=validate_every, devices=devices, distributed_port=distributed_port, val_metric_name=val_metric_name, minimize_val_metric=minimize_val_metric, auto_aggregate_val_metric=auto_aggregate_val_metric, callbacks=callbacks, remove_stale_checkpoints=remove_stale_checkpoints, ) ================================================ FILE: tango/integrations/transformers/ia3.py ================================================ import re from dataclasses import dataclass from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel from transformers.modeling_utils import Conv1D 
@dataclass
class WithIA3Config:
    """
    A class for configuring which layers to modify with IA3 adaptors.

    All of the regex fields are applied with :func:`re.fullmatch` by
    :func:`modify_with_ia3`: the ``*_modules`` patterns are matched against the
    fully qualified module names of the transformer, the ``*_layers`` patterns
    against the names of the direct children of a matching module.

    :param ia3_param_names:
        A string used as the name for all ia3 parameters
    :param attention_modules:
        A regex that matches all attention modules which are parents to the keys and value layers to modify.
    :param mlp_modules:
        A regex that matches all modules that are parents to the feed forward layer to modify.
    :param mlp_layers:
        A regex that matches the feed forward layer in the modules specified by `mlp_modules`.
    :param fused_qkv_layers:
        A regex that matches the combined query, key, and value layer in the modules specified
        by `attention_modules`.
    :param k_layers:
        A regex that matches the key layer in the modules specified by `attention_modules`.
    :param v_layers:
        A regex that matches the value layer in the modules specified by `attention_modules`.
    """

    # Required settings.
    ia3_param_names: str
    attention_modules: str
    mlp_modules: str
    mlp_layers: str
    # Optional attention-layer patterns; the ones left as ``None`` are skipped
    # when the attention layer pattern is assembled.
    fused_qkv_layers: Optional[str] = None
    k_layers: Optional[str] = None
    v_layers: Optional[str] = None
class WithIA3(nn.Module):
    """
    Base class for layers that rescale their output with a learned IA3 vector.

    The scaling vector is stored as an :class:`~torch.nn.Parameter` of ones whose
    attribute name is ``ia3_param_names``. Its length is ``unfuse_size * 2`` for a
    fused query/key/value projection (IA3 leaves the query untouched), and
    ``self.out_features`` otherwise. In the latter case the subclass must have set
    ``self.out_features`` *before* calling this initializer.

    :param ia3_param_names:
        Attribute name under which the IA3 parameter is registered.
    :param unfuse_size:
        Size of each of the query, key, and value slices when the wrapped layer is a
        fused qkv projection; ``None`` for a regular layer.
    """

    def __init__(self, ia3_param_names: str, unfuse_size: Optional[int] = None):
        super().__init__()
        self.ia3_param_names = ia3_param_names
        # ``scale_by_ia3`` depends on this attribute, so record it here instead of
        # relying on every subclass to have assigned it beforehand.
        self.unfuse_size = unfuse_size

        if unfuse_size is not None:
            # (q, k, v) are stacked into one layer.
            # IA3 only operates on k and v (not q), thus the "* 2"
            num_scales = unfuse_size * 2
        else:
            num_scales = self.out_features  # type: ignore
        setattr(self, ia3_param_names, nn.Parameter(torch.ones(num_scales, 1)))

    def scale_by_ia3(self, x):
        """
        Multiply the last dimension of ``x`` by the IA3 vector.

        For fused qkv layers only the part after the first ``unfuse_size`` features
        (i.e. keys and values) is scaled. The input is returned unchanged when the
        IA3 parameter does not require gradients.
        """
        ia3_params = getattr(self, self.ia3_param_names)
        # NOTE(review): the scaling is skipped for frozen IA3 parameters. This is
        # kept as-is; confirm it is intended before freezing a trained adaptor.
        if not ia3_params.requires_grad:
            return x

        scales = ia3_params.flatten()
        if self.unfuse_size is None:
            return x * scales

        # non_q means k and v
        q, non_q = x[..., : self.unfuse_size], x[..., self.unfuse_size :]
        return torch.cat([q, non_q * scales], dim=-1)
def modify_with_ia3(
    transformer: PreTrainedModel,
    *,
    config: Optional[WithIA3Config] = None,
    only_ia3_requires_grad: bool = True,
) -> PreTrainedModel:
    """
    A function to add ia3 adaptors to the given transformer. Code modified from
    `t-few <https://github.com/r-three/t-few>`_ and Qinyuan Ye

    The transformer is modified in place: every matching ``Conv1D`` or
    :class:`~torch.nn.Linear` child is replaced by its IA3 counterpart, which reuses
    the original weight and bias.

    :param transformer:
        A :class:`~transformers.PreTrainedModel` to modify.
    :param config:
        A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers to modify.
        If omitted, a pre-made configuration is looked up by the model's ``_name_or_path``.
    :param only_ia3_requires_grad:
        A `bool`, `True` if `requires_grad` should only be set on ia3 parameters in the output model.

    :returns:
        The same ``transformer`` object that was passed in.

    :raises AssertionError:
        If no ``config`` is given and the model has no pre-made configuration, or if a
        matched layer is neither a ``Conv1D`` nor a :class:`~torch.nn.Linear`.

    Examples
    --------

    You can use the provided configurations:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)

    Or you can write your own configuration with regex matching the layers to modify and their parents:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, WithIA3Config

        my_config = WithIA3Config(
            attention_modules=".*attn",
            fused_qkv_layers="c_attn",
            mlp_modules=".*mlp",
            mlp_layers="c_fc",
            ia3_param_names="ia3",
        )

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=my_config)
    """
    if config is None:
        model_name = transformer.config._name_or_path  # type: ignore
        assert (
            model_name in MODEL_NAME_TO_CONFIG
        ), f"{model_name} does not have a pre made configuration; please make your own."
        config = MODEL_NAME_TO_CONFIG[model_name]

    # None of these patterns depend on the module being visited, so build them once.
    attention_modules_re = re.compile(config.attention_modules)
    mlp_modules_re = re.compile(config.mlp_modules)
    attention_layers_re = re.compile(
        "|".join(
            regex
            for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers)
            if regex is not None
        )
    )
    mlp_layers_re = re.compile(config.mlp_layers)
    fused_qkv_re = re.compile(config.fused_qkv_layers) if config.fused_qkv_layers else None

    # Iterate over snapshots, because children are swapped out while we walk the model.
    for m_name, module in dict(transformer.named_modules()).items():  # type: ignore
        # A module that matches both patterns is treated as an attention module.
        if attention_modules_re.fullmatch(m_name):
            layers_re = attention_layers_re
        elif mlp_modules_re.fullmatch(m_name):
            layers_re = mlp_layers_re
        else:
            continue

        for c_name, layer in dict(module.named_children()).items():
            if not layers_re.fullmatch(c_name):
                continue
            assert isinstance(
                layer, (Conv1D, nn.Linear)
            ), "This code only supports Conv1D and nn.Linear"
            adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3
            is_fused = fused_qkv_re is not None and fused_qkv_re.fullmatch(c_name) is not None
            new_module = adaptor_class(
                layer,
                config.ia3_param_names,
                unfuse_size=transformer.config.hidden_size if is_fused else None,  # type: ignore
            )
            setattr(module, c_name, new_module)

    if only_ia3_requires_grad:
        transformer.requires_grad_(False)  # type: ignore
        for p_name, param in transformer.named_parameters():  # type: ignore
            # The parameter name is a literal attribute name, not a pattern, so a
            # plain substring test is used rather than splicing it into a regex.
            if config.ia3_param_names in p_name:
                param.requires_grad_(True)

    return transformer
try:
    from transformers.models.auto import modeling_flax_auto

    from tango.integrations.flax.model import Model as FlaxModel

    def flax_auto_model_wrapper_factory(cls: type) -> Type[FlaxModel]:
        # Builds a subclass of the given Flax auto-model class that is also a
        # tango ``FlaxModel``, so that it can be registered below.
        class AutoModelWrapper(cls, FlaxModel):  # type: ignore
            @classmethod
            def from_pretrained(
                cls, pretrained_model_name_or_path: str, config: Optional[Config] = None, **kwargs
            ) -> FlaxModel:
                return super().from_pretrained(
                    pretrained_model_name_or_path, config=config, **kwargs
                )

            @classmethod
            def from_config(cls, config: Config, **kwargs) -> FlaxModel:
                return super().from_config(config, **kwargs)

        return AutoModelWrapper

    # Register every ``FlaxAutoModel*`` class twice, once per constructor.
    for name, cls in modeling_flax_auto.__dict__.items():
        if not isinstance(cls, type) or not name.startswith("FlaxAutoModel"):
            continue
        wrapped_cls_ = flax_auto_model_wrapper_factory(cls)
        FlaxModel.register(f"transformers::{name}::from_pretrained", constructor="from_pretrained")(
            wrapped_cls_
        )
        FlaxModel.register(f"transformers::{name}::from_config", constructor="from_config")(
            wrapped_cls_
        )
except (ModuleNotFoundError, IntegrationMissingError):
    # The Flax integration is optional; without it there is nothing to register.
    pass
# Every concrete torch optimizer that transformers ships becomes available as
# "transformers::<ClassName>". The torch base class itself is not registered.
for name, cls in transformers_optim.__dict__.items():
    if not isinstance(cls, type):
        continue
    if issubclass(cls, torch.optim.Optimizer) and cls != torch.optim.Optimizer:
        Optimizer.register(f"transformers::{name}")(cls)

# The scheduler factory functions are registered under the value of their
# scheduler-type key.
for scheduler_type, scheduler_func in transformers_optim.TYPE_TO_SCHEDULER_FUNCTION.items():
    name = scheduler_type.value
    LRScheduler.register(f"transformers::{name}")(scheduler_func)  # type: ignore
def adjust_length_to_model(length, model):
    """
    Clamp a requested generation length to what ``model`` supports.

    The limit is ``model.config.max_position_embeddings`` when the config has that
    attribute, and ``MAX_LENGTH`` otherwise. A negative ``length`` means "as long as
    possible" and resolves to the limit, or to ``MAX_LENGTH`` if the limit is not
    positive. A non-negative ``length`` is only reduced when it exceeds a positive limit.
    """
    if hasattr(model.config, "max_position_embeddings"):
        limit = model.config.max_position_embeddings
    else:
        limit = MAX_LENGTH

    if length < 0:
        # Without a usable model limit, fall back to the hard cap to avoid an infinite loop.
        return limit if limit > 0 else MAX_LENGTH
    if 0 < limit < length:
        # No generation bigger than model size
        return limit
    return length
def _generate(
    model: Model,
    # TODO: Change type to `Tokenizer` once HF includes `convert_tokens_to_ids` in `PretrainedTokenizerBase` class.
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    prompts: Iterable[str],
    *,
    batch_size: int = 4,
    max_length: int = 20,
    temperature: float = 1.0,
    repetition_penalty: float = 1.0,
    k: int = 0,
    p: float = 0.9,
    prefix: str = "",
    xlm_language: str = "",
    seed: int = 42,
    num_return_sequences: int = 1,
    fp16: bool = False,
) -> Iterable[List[str]]:
    """
    Lazily sample generations for ``prompts`` from a seq2seq or causal HF model.

    This is a generator: nothing happens (including the model type check) until the
    result is iterated. It yields one list per prompt, holding the
    ``num_return_sequences`` decoded generations for that prompt.

    Side effects: ``tokenizer`` is switched to left padding and may get pad/eos tokens
    added; ``model`` is put into eval mode, moved to the device returned by
    ``resolve_device()``, and converted to half precision if ``fp16`` is set.

    :param model: The model to sample from.
    :param tokenizer: The tokenizer matching ``model``.
    :param prompts: The prompts to run through the model.
    :param batch_size: Number of prompts per call to ``model.generate()``.
    :param max_length: Maximum number of tokens to generate. For non-seq2seq models the
        length of the encoded batch is added on top of this.
    :param temperature: Passed to ``model.generate()``.
    :param repetition_penalty: Passed to ``model.generate()``.
    :param k: Passed to ``model.generate()`` as ``top_k``.
    :param p: Passed to ``model.generate()`` as ``top_p``.
    :param prefix: Text that is prepended to every prompt.
    :param xlm_language: Assigned to ``model.config.lang_id`` for XLM models that use
        language embeddings.
    :param seed: Random seed.
    :param num_return_sequences: Number of generations per prompt.
    :param fp16: Whether to run the model in 16-bit floats.

    :raises NotImplementedError: If ``model.config`` is not a seq2seq or causal LM config.
    """
    if not isinstance(model.config, tuple(SEQ2SEQ + CAUSAL)):
        raise NotImplementedError(
            "This function is only defined for huggingface models seq2seq/causal models."
        )

    device = resolve_device()
    set_seed_all(seed)

    tokenizer_kwargs: Dict[str, Any] = {}
    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.eos_token is None:
        tokenizer.add_special_tokens({"eos_token": "[EOS]"})

    eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
    # Token ids that get trimmed from both ends of a generated sequence.
    special_token_ids = {0, eos_token_id, pad_token_id}

    def strip_special_tokens(t: torch.Tensor) -> torch.Tensor:
        # amazing that torch has no capability for this
        start = 0
        while start < len(t) and int(t[start]) in special_token_ids:
            start += 1
        end = len(t)
        # Check the bounds before indexing so that an empty sequence is handled too.
        while end > start and int(t[end - 1]) in special_token_ids:
            end -= 1
        return t[start:end]

    # Seq2Seq models don't return their own prefix.
    seq2seq_model = model.config_class in SEQ2SEQ

    # HF does not do this? WTF?
    model.eval()

    model.to(device)
    if fp16:
        model.half()

    def prepare_batch_without_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        result = tokenizer.batch_encode_plus(
            prompts,
            add_special_tokens=False,
            return_tensors="pt",
            padding=True,
            **tokenizer_kwargs,
        )
        result = {key: tensor.to(device) for key, tensor in result.items()}
        return result

    def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:
        if len(prefix) > 0:
            prompts = [f"{prefix} {t}" for t in prompts]
        return prepare_batch_without_prefix(prompts)

    prepare_batch_fn = prepare_batch_with_prefix
    num_prefix_tokens: Optional[int] = None

    # transformer model-specific exceptions
    if isinstance(model, PreTrainedModel) and model.config_class:
        if model.config_class.model_type == "xlm":
            use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
            if hasattr(model.config, "lang2id") and use_lang_emb:
                model.config.lang_id = xlm_language
            # Original HF code ignores the prefix, but it looks like a bug?
            prepare_batch_fn = prepare_batch_without_prefix
            num_prefix_tokens = 0
        elif model.config_class.model_type in {"xlnet", "transfo-xl"}:
            prefix = prefix if prefix else PREFIX
        if model.__class__.__name__ in ["TransfoXLLMHeadModel"]:
            # This actually doesn't work in the current version of transformers, which is probably a bug in the
            # transformers library.
            tokenizer_kwargs = {"add_space_before_punct_symbol": True}

    if num_prefix_tokens is None:
        num_prefix_tokens = len(tokenizer.tokenize(prefix))

    batches = more_itertools.chunked(Tqdm.tqdm(prompts, desc="Pre-processing prompts"), batch_size)
    encoded_batches = map(prepare_batch_fn, batches)

    for encoded_batch in Tqdm.tqdm(encoded_batches, desc="Processing batches"):
        if seq2seq_model:
            length = max_length
        else:
            length = adjust_length_to_model(max_length + encoded_batch["input_ids"].size(1), model)
        with torch.inference_mode():
            generated_sequences: torch.Tensor = model.generate(  # type: ignore
                **encoded_batch,
                max_length=length,
                temperature=temperature,
                top_k=k,
                top_p=p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                num_return_sequences=num_return_sequences,
                synced_gpus=False,  # Needs to be True if we have more than one GPU running.
            )

        # Regroup the flat output so that the first dimension indexes prompts.
        generated_sequences = generated_sequences.view(
            -1, num_return_sequences, *generated_sequences.shape[1:]
        ).to("cpu")

        # strip padding
        generated_sequences_list = [
            [strip_special_tokens(sequence) for sequence in per_prompt_sequences]
            for per_prompt_sequences in generated_sequences
        ]

        # strip prefix
        if not seq2seq_model:
            generated_sequences_list = [
                [sequence[num_prefix_tokens:] for sequence in per_prompt_sequences]
                for per_prompt_sequences in generated_sequences_list
            ]

        texts = [
            tokenizer.batch_decode(per_prompt_sequences, clean_up_tokenization_spaces=True)
            for per_prompt_sequences in generated_sequences_list
        ]

        yield from texts
@Step.register("transformers::run_generation")
class RunGeneration(Step[Iterable[List[str]]]):
    """
    A step that runs seq2seq or causal Huggingface language models in inference mode.

    .. tip::
        Registered as a :class:`~tango.step.Step` under the name "transformers::run_generation".
    """

    FORMAT: Format = JsonFormat("gz")
    VERSION = "001"
    SKIP_ID_ARGUMENTS = {"batch_size"}

    # TODO: multiple GPUs

    def run(  # type: ignore
        self,
        model: Union[str, Model],
        prompts: Iterable[str],
        *,
        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,
        batch_size: int = 4,
        max_length: int = 20,
        temperature: float = 1.0,
        repetition_penalty: float = 1.0,
        k: int = 0,
        p: float = 0.9,
        prefix: str = "",
        xlm_language: str = "",
        seed: int = 42,
        num_return_sequences: int = 1,
        fp16: bool = False,
    ) -> Iterable[List[str]]:
        """
        Run a Huggingface seq2seq or causal language model in inference mode.

        :param model:
            The name of the model to run. Any name that works in the transformers library works here.
            Or, you can directly provide the model to run. A name is first loaded as a seq2seq
            model; if that raises a ``ValueError`` it is loaded as a causal language model.
        :param prompts:
            The prompts to run through the model. You can specify prompts directly in the config, but
            more commonly the prompts are produced by another step that reads a dataset, for example.
        :param tokenizer:
            The tokenizer to run. If not given, it is loaded from ``model.name_or_path``.
        :param batch_size:
            The number of sequences to process at one time. This has no bearing on the output, so
            you can change this number without invalidating cached results.
        :param max_length:
            The maximum number of tokens/word pieces that the model will generate. For models that
            extend the prompt, the prefix does not count towards this limit.
        :param temperature:
            Passed directly to transformer's ``generate()`` method.
            The value used to model the next token probabilities.
        :param repetition_penalty:
            Passed directly to transformer's ``generate()`` method.
            The parameter for repetition penalty. 1.0 means no penalty.
        :param k:
            Passed directly to transformer's ``generate()`` method.
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        :param p:
            Passed directly to transformer's ``generate()`` method.
            If set to float < 1, only the most probable tokens with probabilities that add up to
            top_p or higher are kept for generation.
        :param prefix:
            A prefix that gets pre-pended to all prompts.
        :param xlm_language:
            For the XLM model, this is a way to specify the language you want to use.
        :param seed:
            Random seed
        :param num_return_sequences:
            The number of generations to return for each prompt.
        :param fp16:
            Whether to use 16-bit floats.

        :returns:
            Returns an iterator of lists of string. Each list contains the predictions for one prompt.
        """
        if isinstance(model, str):
            try:
                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))
            except ValueError:
                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))

        # Compare against None explicitly: truth-testing a tokenizer goes through its
        # ``__len__`` instead of telling us whether one was supplied.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)

        return _generate(
            model,
            tokenizer,
            prompts,
            batch_size=batch_size,
            max_length=max_length,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            k=k,
            p=p,
            prefix=prefix,
            xlm_language=xlm_language,
            seed=seed,
            num_return_sequences=num_return_sequences,
            fp16=fp16,
        )
""" FORMAT: Format = SqliteDictFormat() VERSION = "002" SKIP_ID_ARGUMENTS = {"batch_size"} def run( # type: ignore self, model: Union[str, Model], input: Union[DatasetDict, HfDatasetDict], prompt_field: str, *, tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None, output_field: Optional[str] = None, splits: Optional[Union[str, Set[str]]] = None, batch_size: int = 4, max_length: int = 20, temperature: float = 1.0, repetition_penalty: float = 1.0, k: int = 0, p: float = 0.9, prefix: str = "", xlm_language: str = "", seed: int = 42, num_return_sequences: int = 1, fp16: bool = False, ) -> DatasetDict: """ Augment an input dataset with generations from a Huggingface seq2seq model. :param model: The name of the model to run. Any name that works in the transformers library works here. Or, you can directly provide the model to run. :param input: The input dataset. :param prompt_field: The field in the dataset that contains the text of the prompts. :param tokenizer: The tokenizer to run. :param output_field: The field in the dataset that we will write the predictions into. In the result, this field will contain ``List[str]``. :param splits: A split, or set of splits, to process. If this is not specified, we will process all splits. :param batch_size: The number of sequences to process at one time. This has no bearing on the output, so you can change this number without invalidating cached results. :param max_length: The maximum number of tokens/word pieces that the model will generate. For models that extend the prompt, the prefix does not count towards this limit. :param temperature: Passed directly to transformer's `generate()` method. The value used to model the next token probabilities. :param repetition_penalty: Passed directly to transformer's `generate()` method. The parameter for repetition penalty. 1.0 means no penalty. :param k: Passed directly to transformer's `generate()` method. 
The number of highest probability vocabulary tokens to keep for top-k-filtering. :param p: Passed directly to transformer's `generate()` method. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. :param prefix: A prefix that gets pre-pended to all prompts. :param xlm_language: For the XLM model, this is a way to specify the language you want to use. :param seed: Random seed :param num_return_sequences: The number of generations to return for each prompt. :param fp16: Whether to use 16-bit floats. :returns: Returns a dataset with an extra field containing the predictions. """ if isinstance(model, str): try: model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model)) except ValueError: model = cast(Model, AutoModelForCausalLM.from_pretrained(model)) tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path) if isinstance(input, HfDatasetDict): input = DatasetDict(input, {}) if splits is None: splits = input.keys() elif isinstance(splits, str): splits = {splits} result: Dict[str, Sequence] = {} for split_name, input_split in input.items(): if split_name in splits: output_split = SqliteSparseSequence(self.work_dir / f"{split_name}.sqlite") if len(output_split) > 0: logger.info( "Found %d items already generated. 
def add_soft_prompt(
    model: Model,
    prompt_length: int,
    *,
    only_prompt_is_trainable: bool = True,
    initialize_from_top_embeddings: Optional[int] = 5000,
    random_seed: int = 1940,
) -> None:
    """
    Takes a regular huggingface transformer, and equips it with a soft prompt.

    The prompt is registered on the model as a new parameter, and ``model.forward``
    (plus ``model.generate`` for encoder/decoder models) is replaced by a wrapper that
    prepends the prompt to the input embeddings and removes it again from the outputs.

    Example:

    .. testcode::

        import transformers

        model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
        tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
        generated = model.generate(tokenizer.encode("It was the best of times.", return_tensors="pt"))
        original_output = tokenizer.decode(generated[0])

        add_soft_prompt(model, prompt_length=3)
        generated = model.generate(tokenizer.encode("It was the best of times.", return_tensors="pt"))
        prompted_output = tokenizer.decode(generated[0])

    :param model: the original huggingface transformer. This model is augmented in-place!
    :param prompt_length: the length of the soft prompt, in tokens
    :param only_prompt_is_trainable: freezes the original model's weights, leaving only the prompt trainable
    :param initialize_from_top_embeddings: Prompt embeddings are initialized from a random selection of the
        top n word piece embeddings from the original model. This is how you set n. ``None`` means the
        whole vocabulary; values larger than the vocabulary are capped at its size.
    :param random_seed: random seed used to initialize the prompt embeddings

    :raises AssertionError: if ``model`` is not a :class:`~transformers.PreTrainedModel`.
    :raises ValueError: if ``prompt_length`` is larger than the number of embeddings to sample from.
    """
    assert isinstance(model, PreTrainedModel)

    original_embedding: nn.Embedding = model.get_input_embeddings()  # type: ignore
    embedding_device = original_embedding.weight.device
    prompt_embedding = nn.Parameter(
        torch.empty(
            1,
            prompt_length,
            original_embedding.embedding_dim,
            dtype=original_embedding.weight.dtype,
            device=embedding_device,
        )
    )
    r = random.Random(random_seed)
    if initialize_from_top_embeddings is None:
        initialize_from_top_embeddings = original_embedding.num_embeddings
    else:
        # Never sample indices that the embedding matrix does not have.
        initialize_from_top_embeddings = min(
            initialize_from_top_embeddings, original_embedding.num_embeddings
        )
    # The indices have to live on the same device as the embedding weights.
    indices = torch.tensor(
        r.sample(range(initialize_from_top_embeddings), prompt_length),
        dtype=torch.long,
        device=embedding_device,
    )
    with torch.no_grad():
        prompt_embedding.copy_(original_embedding(indices).unsqueeze(0))

    # Freeze before the prompt is registered, so that the prompt itself stays trainable.
    if only_prompt_is_trainable:
        for param in model.parameters():
            param.requires_grad = False

    # find unique parameter name
    parameter_name = "prompt_embedding"
    parameter_name_index = 0
    while True:
        try:
            model.get_parameter(parameter_name)
        except AttributeError:
            break
        parameter_name_index += 1
        parameter_name = f"prompt_embedding_{parameter_name_index}"
    model.register_parameter(parameter_name, prompt_embedding)

    def patch_tensor(kwargs: Dict[str, torch.Tensor], key: str, value: Any = 0) -> None:
        # Prepends `prompt_length` positions filled with `value` along dimension 1.
        t = kwargs.get(key)
        if t is None:
            return
        prefix = t.new_full((t.size(0), prompt_length) + t.shape[2:], value)
        kwargs[key] = torch.cat([prefix, t], dim=1)

    def patch_tensor_with_indices(
        kwargs: Dict[str, torch.Tensor], key: str, offset: int = 0
    ) -> None:
        # Prepends the indices 0..prompt_length-1 and shifts the existing ones by `offset`.
        t = kwargs.get(key)
        if t is None:
            return
        kwargs[key] = torch.cat(
            [
                # Created on the device of `t`, since `torch.cat` cannot mix devices.
                torch.arange(0, prompt_length, dtype=t.dtype, device=t.device)
                .unsqueeze(0)
                .expand(t.size(0), prompt_length),
                t + offset,
            ],
            dim=1,
        )

    old_forward = model.forward

    def new_forward(*args, **kwargs):
        # Massage the input to include the prompt
        if kwargs.get("past_key_values") is not None:
            # If we have already been running this model, we don't need to do anything with the prefix now.
            return old_forward(*args, **kwargs)
        if kwargs.get("encoder_outputs") is not None:
            # For encoder/decoder models, this runs only on the encoder. If we already have encoder outputs,
            # we don't have to do anything.
            return old_forward(*args, **kwargs)

        inputs_embeds: Optional[torch.Tensor] = None
        input_ids = kwargs.pop("input_ids", None)
        if input_ids is not None:
            inputs_embeds = original_embedding(input_ids)
        # Explicitly passed embeddings win, but an explicit `inputs_embeds=None` must not
        # throw away the embeddings we just computed from `input_ids`.
        explicit_inputs_embeds = kwargs.get("inputs_embeds")
        if explicit_inputs_embeds is not None:
            inputs_embeds = explicit_inputs_embeds
        if inputs_embeds is not None:
            kwargs["inputs_embeds"] = torch.cat(
                [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1
            )

        # NOTE(review): the prompt positions of `labels` are filled with 0, which is a
        # regular token id rather than an ignore index -- confirm this is intended.
        patch_tensor(kwargs, "labels")
        patch_tensor(kwargs, "attention_mask", 1)
        patch_tensor(kwargs, "token_type_ids")
        patch_tensor_with_indices(kwargs, "position_ids", prompt_length)

        # Run the model
        result = old_forward(*args, **kwargs)

        # Massage the output to look like the prompt was never there
        unpatch_tensor = lambda t: t[:, prompt_length:]  # noqa: E731
        unpatch_attention_tensor = lambda t: t[:, :, prompt_length:]  # noqa: E731
        unpatch_kv_tensor = unpatch_attention_tensor
        if isinstance(result, CausalLMOutputWithCrossAttentions):
            if result.logits is not None:
                result.logits = unpatch_tensor(result.logits)
            if result.hidden_states is not None:
                result.hidden_states = tuple(map(unpatch_tensor, result.hidden_states))
            if result.attentions is not None:
                result.attentions = tuple(map(unpatch_attention_tensor, result.attentions))
            if result.cross_attentions is not None:
                result.cross_attentions = tuple(
                    map(unpatch_attention_tensor, result.cross_attentions)
                )
            return result
        elif isinstance(result, Seq2SeqModelOutput):
            if result.last_hidden_state is not None:
                result.last_hidden_state = unpatch_tensor(result.last_hidden_state)
            # NOTE(review): this applies the slice to each top-level entry of
            # `past_key_values`; verify that those entries are tensors here.
            if result.past_key_values is not None:
                result.past_key_values = tuple(map(unpatch_kv_tensor, result.past_key_values))
            # `Seq2SeqModelOutput` has no `hidden_states` / `attentions` fields; the
            # encoder-side outputs are the ones that contain the prompt positions.
            if result.encoder_hidden_states is not None:
                result.encoder_hidden_states = tuple(
                    map(unpatch_tensor, result.encoder_hidden_states)
                )
            if result.encoder_attentions is not None:
                result.encoder_attentions = tuple(
                    map(unpatch_attention_tensor, result.encoder_attentions)
                )
            if result.cross_attentions is not None:
                result.cross_attentions = tuple(
                    map(unpatch_attention_tensor, result.cross_attentions)
                )
            return result
        else:
            logger.warning(
                "Unexpected result type from the transformer in soft_prompt_transformer: `%s`",
                result.__class__,
            )
            return result

    model.forward = new_forward  # type: ignore

    # For encoder/decoder models, HF doesn't call `forward()` like it should when you use `generate()`. Instead, it
    # calls the encoder separately, and then passes the results into `forward()`. So in that case, we have to patch
    # this too.
    if model.config.is_encoder_decoder:
        old_generate = model.generate

        def new_generate(*args, **kwargs):
            args = (model,) + args
            ba = _get_bound_args_with_decorators(old_generate, *args, **kwargs)
            del ba.arguments["self"]
            if "encoder_outputs" in ba.arguments:
                # For encoder/decoder models, this runs only on the encoder. If we already have encoder outputs,
                # we don't have to do anything.
                return old_generate(*ba.args, **ba.kwargs)

            inputs_embeds: Optional[torch.Tensor] = None
            inputs = ba.arguments.pop("inputs", None)
            if inputs is not None:
                inputs_embeds = original_embedding(inputs)
            inputs_embeds = ba.arguments.pop("inputs_embeds", inputs_embeds)
            if inputs_embeds is not None:
                inputs_embeds = torch.cat(
                    [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1
                )
            assert callable(model.get_encoder)
            encoder = model.get_encoder()
            kwargs = ba.kwargs
            kwargs["encoder_outputs"] = encoder(inputs_embeds=inputs_embeds, return_dict=True)
            return old_generate(*ba.args, **kwargs)

        model.generate = new_generate  # type: ignore
This is that variant."""
    add_soft_prompt(
        model,
        prompt_length,
        only_prompt_is_trainable=only_prompt_is_trainable,
        initialize_from_top_embeddings=initialize_from_top_embeddings,
        random_seed=random_seed,
    )
    # add_soft_prompt() is called for its effect on ``model``; the same object is handed back.
    return model


# Make the function available to config files as the "transformers::with_soft_prompt" model type.
Model.register("transformers::with_soft_prompt")(_with_soft_prompt)  # type: ignore


================================================
FILE: tango/integrations/transformers/tokenizer.py
================================================
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from tango.common import Registrable


class Tokenizer(PreTrainedTokenizerBase, Registrable):
    """
    A :class:`~tango.common.Registrable` version of transformers'
    :class:`~transformers.PreTrainedTokenizerBase`.
    """

    default_implementation = "auto"
    """
    The default registered implementation just calls
    :meth:`transformers.AutoTokenizer.from_pretrained()`.
    """


# "auto" maps to AutoTokenizer, constructed through its ``from_pretrained`` classmethod.
Tokenizer.register("auto", constructor="from_pretrained")(AutoTokenizer)


================================================
FILE: tango/integrations/wandb/__init__.py
================================================
"""
.. important::
    To use this integration you should install ``tango`` with the "wandb" extra
    (e.g. ``pip install tango[wandb]``) or just install the ``wandb`` library after the fact
    (e.g. ``pip install wandb``).

Components for Tango integration with `Weights & Biases <https://wandb.ai>`_.

Overview
--------

The main components provided by this integration are the :class:`WandbWorkspace` and
the :class:`WandbTrainCallback`.

The :class:`WandbWorkspace` is a :class:`~tango.workspace.Workspace` implementation that is
great for collaboration. It tracks Tango runs and steps in the W&B project of your choosing
and uses W&B Artifacts to cache step results in the cloud so that they're accessible anywhere.

And if you're training PyTorch models via the :class:`~tango.integrations.torch.TorchTrainStep`,
you can use the :class:`WandbTrainCallback` to track metrics throughout the run.
"""

from tango.common.exceptions import IntegrationMissingError

# Nothing in this package works without ``wandb``, so a missing install is reported
# with Tango's own IntegrationMissingError instead of a bare ModuleNotFoundError.
try:
    import wandb
except ModuleNotFoundError:
    raise IntegrationMissingError("wandb")

__all__ = ["WandbWorkspace", "WandbStepCache"]

from .step_cache import WandbStepCache
from .workspace import WandbWorkspace

# The torch callback is only imported and exported when ``torch`` itself can be imported.
try:
    import torch
except ModuleNotFoundError:
    pass
else:
    from .torch_train_callback import WandbTrainCallback

    __all__.append("WandbTrainCallback")

# The flax callback is only imported and exported when flax, jax and tensorflow all import.
try:
    import flax
    import jax
    import tensorflow  # flax has a tensorflow dependency
except ModuleNotFoundError:
    pass
else:
    from .flax_train_callback import WandbFlaxTrainCallback

    __all__.append("WandbFlaxTrainCallback")


================================================
FILE: tango/integrations/wandb/flax_train_callback.py
================================================
from typing import Any, Dict, List, Optional

import jax
import wandb
from flax import jax_utils

from tango.common.exceptions import ConfigurationError
from tango.integrations.flax.train_callback import TrainCallback

from .workspace import WandbWorkspace


@TrainCallback.register("wandb::log_flax")
class WandbFlaxTrainCallback(TrainCallback):
    """
    A flax :class:`~tango.integrations.flax.TrainCallback` for use with
    the :class:`~tango.integrations.flax.FlaxTrainStep` that logs training and
    validation metrics to W&B.

    This can be used with any :class:`~tango.workspace.Workspace` implementation,
    including :class:`WandbWorkspace`.

    .. tip::

        Registered as a :class:`~tango.integrations.flax.TrainCallback`
        under the name "wandb::log_flax".

    .. important::

        When this callback is used with the :class:`WandbWorkspace` it will log metrics
        to the same W&B project that the workspace uses. The ``group`` and ``name``
        parameters will also automatically be set, so a
        :class:`~tango.common.exceptions.ConfigurationError` will be raised if any of
        ``project``, ``entity``, ``group``, or ``name`` are set in this callback.

    :param project:
        W&B project to associate this run with.
:param entity: W&B entity (user or organization) to associated this run with. :param group: W&B group to associated this run with. :param name: Set the name of the run in W&B. If not set, the default will be the name of the step. :param notes: Arbitrary notes to add in W&B to this run. :param tags: Arbitrary tags to add in W&B to this run. :param watch_model: If ``True``, ``wandb.watch()`` is called to collect gradients and other information about the model throughout training. See `docs.wandb.ai/ref/python/watch `_. :param wandb_config: Arbitrary configuration fields to set in W&B for this run. See `docs.wandb.ai/guides/track/config `_. """ def __init__( self, *args, project: Optional[str] = None, entity: Optional[str] = None, group: Optional[str] = None, name: Optional[str] = None, notes: Optional[str] = None, tags: Optional[List[str]] = None, watch_model: bool = False, wandb_config: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: super().__init__(*args, **kwargs) if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None: err_msg_template = "Cannot set '{var_name}' in WandbTrainCallback " if isinstance(self.workspace, WandbWorkspace): err_msg_template += "since it has already been set from the WandbWorkspace." else: err_msg_template += "since a W&B run has already been initialized." 
for var, var_name in [ (project, "project"), (entity, "entity"), (group, "group"), (name, "name"), ]: if var is not None: raise ConfigurationError(err_msg_template.format(var_name=var_name)) self.project = ( project if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project ) self.entity = ( entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity ) self.group = group or self.step_id self.notes = notes self.tags = tags self.watch_model = watch_model self.wandb_config = self.train_config.as_dict() if wandb_config is not None: self.wandb_config.update(wandb_config) if wandb.run is None: self.wandb_config["job_type"] = "train_metrics" self.run_name: str = name or self.step_name or "train" self.run_id: str = ( wandb.run.id if wandb.run is not None else self.step_id # type: ignore[attr-defined] ) self.resume: Optional[str] = None self.should_finalize_run: bool = ( wandb.run is None ) # if we have to start out own W&B run, we need to finish it def state_dict(self) -> Dict[str, Any]: return {} def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.resume = "allow" def pre_train_loop(self) -> None: if wandb.run is None: if self.run_id is None: self.run_id = self.step_id wandb.init( id=self.run_id, dir=str(self.work_dir), project=self.project, entity=self.entity, group=self.group, name=self.run_name, notes=self.notes, config=self.wandb_config, tags=self.tags, job_type="train_metrics", ) else: # We are already running inside of a W&B run, possibly because # we're using the WandbWorkspace. 
wandb.config.update(self.wandb_config) if self.tags: wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags) if self.notes: wandb.run.notes = self.notes if self.watch_model: wandb.watch(self.model) def post_train_loop(self, step: int, epoch: int) -> None: if self.should_finalize_run: wandb.finish() def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None: if len(jax.devices()) > 1: train_metrics = jax_utils.unreplicate(train_metrics) metrics = {"train/loss": train_metrics["loss"], "epoch": epoch} wandb.log(metrics, step=step + 1) def post_val_loop( self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float] ) -> None: wandb.log( { f"val/{self.train_config.val_metric_name}": val_metric, f"val/best_{self.train_config.val_metric_name}": best_val_metric, "epoch": epoch, }, step=step + 1, ) ================================================ FILE: tango/integrations/wandb/step_cache.py ================================================ import logging from typing import Any, Optional, Union import wandb from retry import retry from wandb.errors import Error as WandbError from tango.common.aliases import PathOrStr from tango.common.util import make_safe_filename, tango_cache_dir from tango.step import Step from tango.step_cache import StepCache from tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache from tango.step_info import StepInfo from .util import ArtifactKind, check_environment, is_missing_artifact_error logger = logging.getLogger(__name__) @StepCache.register("wandb") class WandbStepCache(RemoteStepCache): """ This is a :class:`~tango.step_cache.StepCache` that's used by :class:`WandbWorkspace`. It stores the results of steps on W&B as Artifacts. It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a step's resulting subsequent times should be fast. :param project: The W&B project to use. 
:param entity: The W&B entity (user or organization account) to use. .. tip:: Registered as :class:`~tango.step_cache.StepCache` under the name "wandb". """ def __init__(self, project: str, entity: str): check_environment() super().__init__( tango_cache_dir() / "wandb_cache" / make_safe_filename(entity) / make_safe_filename(project) ) self.project = project self.entity = entity @property def wandb_client(self) -> wandb.Api: return wandb.Api(overrides={"entity": self.entity, "project": self.project}) @property def client(self): """ To maintain compatibility """ return self.wandb_client @property def wandb_project_url(self) -> str: """ The URL of the W&B project this workspace uses. """ app_url = self.wandb_client.client.app_url app_url = app_url.rstrip("/") return f"{app_url}/{self.entity}/{self.project}" def _step_artifact_name(self, step: Union[Step, StepInfo]) -> str: if isinstance(step, Step): return step.class_name else: return step.step_class_name def _step_result_remote( # type: ignore self, step: Union[Step, StepInfo] ) -> Optional[wandb.Artifact]: artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value) try: return self.wandb_client.artifact( f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}", type=artifact_kind, ) except WandbError as exc: if is_missing_artifact_error(exc): return None else: raise def create_step_result_artifact(self, step: Step, objects_dir: Optional[PathOrStr] = None): self._upload_step_remote(step, objects_dir) def get_step_result_artifact(self, step: Union[Step, StepInfo]) -> Optional[wandb.Artifact]: artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value) try: return self.wandb_client.artifact( f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}", type=artifact_kind, ) except WandbError as exc: if is_missing_artifact_error(exc): return None else: raise def _upload_step_remote(self, step: Step, objects_dir: 
Optional[PathOrStr] = None) -> Any: """ Create an artifact for the result of a step. """ artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value) artifact = wandb.Artifact(self._step_artifact_name(step), type=artifact_kind) # Add files if objects_dir is not None: artifact.add_dir(str(objects_dir)) # Log/persist the artifact to W&B. artifact.save() artifact.wait() # Add an alias for the step's unique ID. # Only after we've logged the artifact can we add an alias. artifact.aliases.append(step.unique_id) artifact.save() artifact.wait() def get_step_result_artifact_url(self, step: Union[Step, StepInfo]) -> str: artifact_kind = (step.metadata or {}).get("artifact_kind", ArtifactKind.STEP_RESULT.value) return ( f"{self.wandb_project_url}/artifacts/{artifact_kind}" f"/{self._step_artifact_name(step)}/{step.unique_id}" ) @retry(exceptions=(wandb.errors.CommError,), delay=10, backoff=2, max_delay=120) def use_step_result_artifact(self, step: Union[Step, StepInfo]) -> None: """ "Use" the artifact corresponding to the result of a step. 
""" if wandb.run is None: raise RuntimeError("This can only be called from within a W&B run") wandb.run.use_artifact( f"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}" ) def _download_step_remote(self, step_result, target_dir: PathOrStr): try: step_result.download(root=target_dir) except (WandbError, ValueError): raise RemoteNotFoundError() def __len__(self) -> int: completed_cacheable_step_runs = self.wandb_client.runs( f"{self.entity}/{self.project}", filters={ # type: ignore "config.job_type": "step", "config.cacheable": True, "state": "finished", }, ) return len(list(completed_cacheable_step_runs)) ================================================ FILE: tango/integrations/wandb/torch_train_callback.py ================================================ from typing import Any, Dict, List, Optional import torch import wandb from tango.common.exceptions import ConfigurationError from tango.integrations.torch.train_callback import TrainCallback from tango.integrations.torch.util import peak_gpu_memory from .util import check_environment from .workspace import WandbWorkspace @TrainCallback.register("wandb::log") class WandbTrainCallback(TrainCallback): """ A torch :class:`~tango.integrations.torch.TrainCallback` for use with the :class:`~tango.integrations.torch.TorchTrainStep` that logs training and validation metrics to W&B. This can be used with any :class:`~tango.workspace.Workspace` implementation, including :class:`WandbWorkspace`. .. tip:: Registered as a :class:`~tango.integrations.torch.TrainCallback` under the name "wandb::log". .. important:: When this callback is used with the :class:`WandbWorkspace` it will log metrics to the same W&B project that the workspace uses. The ``group`` and ``name`` parameters will also automatically be set, so a :class:`~tango.common.exceptions.ConfigurationError` will be raised if any of ``project``, ``entity``, ``group``, or ``name`` are set in this callback. 
:param project:
        W&B project to associate this run with.
    :param entity:
        W&B entity (user or organization) to associate this run with.
    :param group:
        W&B group to associate this run with.
    :param name:
        Set the name of the run in W&B. If not set, the default will be the name of the step.
    :param notes:
        Arbitrary notes to add in W&B to this run.
    :param tags:
        Arbitrary tags to add in W&B to this run.
    :param watch_model:
        If ``True``, ``wandb.watch()`` is called to collect gradients and other information
        about the model throughout training.
        See `docs.wandb.ai/ref/python/watch <https://docs.wandb.ai/ref/python/watch>`_.
    :param wandb_config:
        Arbitrary configuration fields to set in W&B for this run.
        See `docs.wandb.ai/guides/track/config <https://docs.wandb.ai/guides/track/config>`_.
    """

    def __init__(
        self,
        *args,
        project: Optional[str] = None,
        entity: Optional[str] = None,
        group: Optional[str] = None,
        name: Optional[str] = None,
        notes: Optional[str] = None,
        tags: Optional[List[str]] = None,
        watch_model: bool = False,
        wandb_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        # Only the local main process emits the environment warnings.
        if self.is_local_main_process:
            check_environment()

        # project/entity/group/name may not be overridden when they are already
        # determined by a WandbWorkspace or by an active W&B run.
        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:
            err_msg_template = "Cannot set '{var_name}' in WandbTrainCallback "
            if isinstance(self.workspace, WandbWorkspace):
                err_msg_template += "since it has already been set from the WandbWorkspace."
            else:
                err_msg_template += "since a W&B run has already been initialized."
            for var, var_name in [
                (project, "project"),
                (entity, "entity"),
                (group, "group"),
                (name, "name"),
            ]:
                if var is not None:
                    raise ConfigurationError(err_msg_template.format(var_name=var_name))

        self.project = (
            project if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project
        )
        self.entity = (
            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity
        )
        self.group = group or self.step_id
        self.notes = notes or self._get_default_notes()
        self.tags = tags
        self.watch_model = watch_model
        # Start from the training config (minus the per-worker ID) and let the
        # user-supplied fields override it.
        self.wandb_config = self.train_config.as_dict()
        del self.wandb_config["worker_id"]
        if wandb_config is not None:
            self.wandb_config.update(wandb_config)
        if wandb.run is None:
            self.wandb_config["job_type"] = "train_metrics"
        self.run_name: str = name or self.step_name or "train"
        if self.train_config.is_distributed:
            self.run_name += f" (rank {self.train_config.worker_id})"
        # Without an active run, the run ID is derived from the step ID and the worker rank.
        self.run_id: str = (
            wandb.run.id  # type: ignore[attr-defined]
            if wandb.run is not None
            else self.step_id + f"-rank{self.train_config.worker_id}"
        )
        self.resume: Optional[str] = None
        self.should_finalize_run: bool = (
            wandb.run is None
        )  # if we have to start our own W&B run, we need to finish it

    def state_dict(self) -> Dict[str, Any]:
        """Return the state to checkpoint; this callback keeps none."""
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Record that training was restored from a checkpoint by setting ``resume`` to "allow"."""
        self.resume = "allow"

    def pre_train_loop(self) -> None:
        """
        Start a W&B run if none is active (otherwise update the active one), optionally
        watch the model, and log initial peak GPU memory from the local main process.
        """
        if wandb.run is None:
            if self.run_id is None:
                self.run_id = self.step_id + f"-rank{self.train_config.worker_id}"

            # Initialize a new W&B run.
            wandb.init(
                id=self.run_id,
                dir=str(self.work_dir),
                project=self.project,
                entity=self.entity,
                group=self.group,
                name=self.run_name,
                notes=self.notes,
                config=self.wandb_config,
                tags=self.tags,
                job_type="train_metrics",
            )
        else:
            # We are already running inside of a W&B run, possibly because
            # we're using the WandbWorkspace.
            wandb.config.update(self.wandb_config)
            if self.tags:
                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)
            if self.notes:
                wandb.run.notes = self.notes

        if self.watch_model:
            wandb.watch(self.training_engine.model)

        # Log GPU memory statistics.
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        # peak_gpu_memory() is called on every process; only the local main process logs it.
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            metrics["epoch"] = 0
            wandb.log(metrics, step=0)

    def post_train_loop(self, step: int, epoch: int) -> None:
        """Finish the W&B run, but only if this callback was the one that would start it."""
        if self.should_finalize_run:
            wandb.finish()

    def log_batch(
        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]
    ) -> None:
        """
        Log loss, learning rate (of the first param group), epoch and per-worker peak GPU
        memory at W&B step ``step + 1``. Only the local main process logs.
        """
        # Called on every process, like in pre_train_loop(); logging is gated below.
        peak_gpu_mbs = peak_gpu_memory()
        if self.is_local_main_process:
            metrics = {
                "train/loss": batch_loss,
                "train/lr": self.training_engine.optimizer.param_groups[0]["lr"],
                "epoch": epoch,
            }
            metrics.update(
                {f"sys/worker{rank}_peak_gpu_mem": mbs for rank, mbs in peak_gpu_mbs.items()}
            )
            wandb.log(
                metrics,
                step=step + 1,
            )

    def post_val_loop(
        self, step: int, epoch: int, val_metric: float, best_val_metric: float
    ) -> None:
        """Log the current and best validation metric from the local main process."""
        if self.is_local_main_process:
            wandb.log(
                {
                    f"val/{self.train_config.val_metric_name}": val_metric,
                    f"val/best_{self.train_config.val_metric_name}": best_val_metric,
                    "epoch": epoch,
                },
                step=step + 1,
            )

    def _get_default_notes(self) -> str:
        """
        Build the default run notes: the step name and worker ID, plus a link to the
        step's main run when the workspace is a :class:`WandbWorkspace`.
        """
        notes = (
            f'Metrics for Tango step "{self.step_name}" from worker {self.train_config.worker_id}.'
        )
        if isinstance(self.workspace, WandbWorkspace):
            notes += f"\nMain run for step: {self.workspace.wandb_project_url}/runs/{self.step_id}/overview"
        return notes


================================================
FILE: tango/integrations/wandb/util.py
================================================
import os
import re
import warnings
from enum import Enum

from wandb.errors import Error as WandbError

# Module-level flags so that each warning in check_environment() is issued at most once.
_API_KEY_WARNING_ISSUED = False
_SILENCE_WARNING_ISSUED = False


def is_missing_artifact_error(err: WandbError):
    """
    Check if a specific W&B error is caused by a 404 on the artifact we're looking for.

    The decision is made purely by matching ``err.message`` against known message texts.
    """
    # This is brittle, but at least we have a test for it.

    # This is a workaround for a bug in the wandb API
    if err.message == "'NoneType' object has no attribute 'get'":
        return True

    if re.search(r"^artifact '.*' not found in '.*'$", err.message):
        return True

    return ("does not contain artifact" in err.message) or (
        "Unable to fetch artifact with name" in err.message
    )


def check_environment():
    """
    Warn when ``WANDB_API_KEY`` or ``WANDB_SILENT`` is missing from the environment.
    Each of the two warnings is issued at most once, tracked by the module-level flags.
    """
    global _API_KEY_WARNING_ISSUED, _SILENCE_WARNING_ISSUED
    if "WANDB_API_KEY" not in os.environ and not _API_KEY_WARNING_ISSUED:
        warnings.warn(
            "Missing environment variable 'WANDB_API_KEY' required to authenticate to Weights & Biases.",
            UserWarning,
        )
        _API_KEY_WARNING_ISSUED = True
    if "WANDB_SILENT" not in os.environ and not _SILENCE_WARNING_ISSUED:
        warnings.warn(
            "The Weights & Biases client may produce a lot of log messages. "
            "You can silence these by setting the environment variable 'WANDB_SILENT=true'",
            UserWarning,
        )
        _SILENCE_WARNING_ISSUED = True


class RunKind(Enum):
    # Values used as the W&B ``job_type`` of the runs created by the workspace.
    STEP = "step"
    TANGO_RUN = "tango_run"


class ArtifactKind(Enum):
    # Default W&B artifact type for step results.
    STEP_RESULT = "step_result"


================================================
FILE: tango/integrations/wandb/workspace.py
================================================
import logging
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable, Iterator, Optional, TypeVar, Union
from urllib.parse import ParseResult

import pytz
import wandb

from tango.common.exceptions import StepStateError
from tango.common.file_lock import FileLock
from tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime
from tango.step import Step
from tango.step_cache import StepCache
from tango.step_info import StepInfo, StepState
from tango.workspace import Run, Workspace

from .step_cache import WandbStepCache
from .util import RunKind, check_environment

T = TypeVar("T")

logger = logging.getLogger(__name__)


@Workspace.register("wandb")
class WandbWorkspace(Workspace):
    """
    This is a :class:`~tango.workspace.Workspace` that tracks Tango runs in a W&B project.
    It also stores step results as W&B Artifacts via :class:`WandbStepCache`.

    Each Tango run with this workspace will generate multiple runs in your W&B project.
    There will always be a W&B run corresponding to each Tango run with the same name,
    which will contain some metadata about the Tango run. Then there will be one W&B run
    for each cacheable step that runs with a name corresponding to the name of the step.
    So if your Tango run includes 3 cacheable steps, that will result in a total of 4 new runs in W&B.

    :param project: The W&B project to use for the workspace.
    :param entity: The W&B entity (user or organization account) to use for the workspace.

    .. tip::
        Registered as a :class:`~tango.workspace.Workspace` under the name "wandb".

    ..
tip:: If you want to change the artifact kind for step result artifacts uploaded to W&B, add a field called ``artifact_kind`` to the ``metadata`` of the :class:`~tango.step.Step` class. This can be useful if you want model objects to be added to the model zoo. In that you would set ``artifact_kind = "model"``. For example, your config for the step would look like this: .. code-block:: { type: "trainer", step_metadata: { artifact_kind: "model" }, ... } Or just add this to the ``METADATA`` class attribute: .. code-block:: @Step.register("trainer") class TrainerStep(Step): METADATA = {"artifact_kind": "model"} """ def __init__(self, project: str, entity: Optional[str] = None): check_environment() super().__init__() self.project = project self._entity = entity self.cache = WandbStepCache(project=self.project, entity=self.entity) self.steps_dir = tango_cache_dir() / "wandb_workspace" self.locks: Dict[Step, FileLock] = {} self._running_step_info: Dict[str, StepInfo] = {} def __getstate__(self): """ We override `__getstate__()` to customize how instances of this class are pickled since we don't want to persist certain attributes. """ out = super().__getstate__() out["locks"] = {} return out @property def wandb_client(self) -> wandb.Api: overrides = {"project": self.project} if self._entity is not None: overrides["entity"] = self._entity return wandb.Api(overrides=overrides) @property def entity(self) -> str: return self._entity or self.wandb_client.default_entity @property def url(self) -> str: return f"wandb://{self.entity}/{self.project}" @classmethod def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace: entity = parsed_url.netloc project = parsed_url.path if project: project = project.strip("/") return cls(project=project, entity=entity) @property def step_cache(self) -> StepCache: return self.cache @property def wandb_project_url(self) -> str: """ The URL of the W&B project this workspace uses. 
""" app_url = self.wandb_client.client.app_url app_url = app_url.rstrip("/") return f"{app_url}/{self.entity}/{self.project}" def _get_unique_id(self, step_or_unique_id: Union[Step, str]) -> str: if isinstance(step_or_unique_id, Step): unique_id = step_or_unique_id.unique_id else: unique_id = step_or_unique_id return unique_id def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path: unique_id = self._get_unique_id(step_or_unique_id) path = self.steps_dir / unique_id path.mkdir(parents=True, exist_ok=True) return path def work_dir(self, step: Step) -> Path: path = self.step_dir(step) / "work" path.mkdir(parents=True, exist_ok=True) return path def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo: unique_id = self._get_unique_id(step_or_unique_id) if unique_id in self._running_step_info: return self._running_step_info[unique_id] step_info = self._get_updated_step_info( unique_id, step_name=step_or_unique_id.name if isinstance(step_or_unique_id, Step) else None, ) if step_info is None: raise KeyError(step_or_unique_id) else: return step_info def step_starting(self, step: Step) -> None: if wandb.run is not None: raise RuntimeError( "There is already a W&B run initialized, cannot initialize another one." ) work_dir = self.work_dir(step) lock_path = self.step_dir(step) / "lock" lock = FileLock(lock_path, read_only_ok=True) lock.acquire_with_updates(desc=f"acquiring lock for '{step.name}'") self.locks[step] = lock step_info = self._get_updated_step_info(step.unique_id) or StepInfo.new_from_step(step) if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}: raise StepStateError( step, step_info.state, context="If you are certain the step is not running somewhere else, delete the lock " f"file at {lock_path}.", ) try: # Initialize W&B run for the step. 
wandb.init( name=step_info.step_name, job_type=RunKind.STEP.value, group=step.unique_id, dir=str(work_dir), entity=self.entity, project=self.project, # For cacheable steps we can just use the step's unique ID as the W&B run ID, # but not for uncacheable steps since those might be ran more than once, and # and will need a unique W&B run ID each time. id=step.unique_id if step.cache_results else None, resume="allow" if step.cache_results else None, notes="\n".join( [ f'Tango step "{step.name}"', f"\N{bullet} type: {step_info.step_class_name}", f"\N{bullet} ID: {step.unique_id}", ] ), config={ "job_type": RunKind.STEP.value, "_run_suite_id": self._generate_run_suite_id(), # used for testing only }, ) assert wandb.run is not None logger.info( "Tracking '%s' step on Weights and Biases: %s/runs/%s/overview", step.name, self.wandb_project_url, wandb.run.id, ) # "Use" all of the result artifacts for this step's dependencies in order to declare # those dependencies to W&B. for dependency in step.dependencies: self.cache.use_step_result_artifact(dependency) # Update StepInfo to mark as running. step_info.start_time = utc_now_datetime() step_info.end_time = None step_info.error = None step_info.result_location = None wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True) self._running_step_info[step.unique_id] = step_info except: # noqa: E722 lock.release() del self.locks[step] raise def step_finished(self, step: Step, result: T) -> T: if wandb.run is None: raise RuntimeError( f"{self.__class__.__name__}.step_finished() called outside of a W&B run. " f"Did you forget to call {self.__class__.__name__}.step_starting() first?" 
) step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info( step.unique_id ) if step_info is None: raise KeyError(step.unique_id) try: if step.cache_results: self.step_cache[step] = result if hasattr(result, "__next__"): assert isinstance(result, Iterator) # Caching the iterator will consume it, so we write it to the # cache and then read from the cache for the return value. result = self.step_cache[step] step_info.result_location = self.cache.get_step_result_artifact_url(step) else: # Create an empty artifact in order to build the DAG in W&B. self.cache.create_step_result_artifact(step) step_info.end_time = utc_now_datetime() wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True) # Finalize the step's W&B run. wandb.finish() finally: self.locks[step].release() del self.locks[step] if step.unique_id in self._running_step_info: del self._running_step_info[step.unique_id] return result def step_failed(self, step: Step, e: BaseException) -> None: if wandb.run is None: raise RuntimeError( f"{self.__class__.__name__}.step_failed() called outside of a W&B run. " f"Did you forget to call {self.__class__.__name__}.step_starting() first?" ) step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info( step.unique_id ) if step_info is None: raise KeyError(step.unique_id) try: # Update StepInfo, marking the step as failed. if step_info.state != StepState.RUNNING: raise StepStateError(step, step_info.state) step_info.end_time = utc_now_datetime() step_info.error = exception_to_string(e) wandb.run.config.update({"step_info": step_info.to_json_dict()}, allow_val_change=True) # Finalize the step's W&B run. 
def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
    """
    Register a new Tango run with Weights & Biases and return it.

    A W&B run of kind ``RunKind.TANGO_RUN`` is created whose config records the
    step info of every target step and of all their recursive dependencies.

    :param targets: The steps to register. Any iterable is accepted, including
        one-shot iterators such as generators.
    :param name: Optional display name for the W&B run.
    :return: The registered run, looked up again by its W&B display name.
    """
    # ``targets`` may be a one-shot iterator, so it must be consumed exactly once.
    # Iterating it a second time (as the code used to do) would silently skip the
    # dependency collection for generators.
    all_steps = set(targets)
    for target in tuple(all_steps):
        all_steps |= target.recursive_dependencies

    wandb_run_id: str
    wandb_run_name: str
    with tempfile.TemporaryDirectory() as temp_dir_name:
        with wandb.init(  # type: ignore[union-attr]
            job_type=RunKind.TANGO_RUN.value,
            entity=self.entity,
            project=self.project,
            name=name,
            dir=temp_dir_name,
            config={
                "job_type": RunKind.TANGO_RUN.value,  # need this in the config so we can filter runs by this
                "_run_suite_id": self._generate_run_suite_id(),  # used for testing only
            },
        ) as wandb_run:
            wandb_run_id = wandb_run.id
            wandb_run_name = wandb_run.name  # type: ignore[assignment]
            logger.info("Registering run %s with Weights and Biases", wandb_run.name)
            logger.info(
                "View run at: %s/runs/%s/overview", self.wandb_project_url, wandb_run_id
            )

            # Collect step info for all steps.
            step_ids: Dict[str, bool] = {}
            step_name_to_info: Dict[str, Dict[str, Any]] = {}
            for step in all_steps:
                step_info = StepInfo.new_from_step(step)
                # Drop unset fields so the stored config stays compact.
                step_name_to_info[step.name] = {
                    k: v for k, v in step_info.to_json_dict().items() if v is not None
                }
                step_ids[step.unique_id] = True

            # Update config with step info.
            wandb_run.config.update({"steps": step_name_to_info, "_step_ids": step_ids})

            # Update notes.
            notes = "Tango run\n--------------"
            cacheable_steps = {step for step in all_steps if step.cache_results}
            if cacheable_steps:
                notes += "\nCacheable steps:\n"
                for step in sorted(cacheable_steps, key=lambda s: s.name):
                    notes += f"\N{bullet} {step.name}"
                    dependencies = step.dependencies
                    if dependencies:
                        notes += ", depends on: " + ", ".join(
                            sorted(f"'{dep.name}'" for dep in dependencies)
                        )
                    notes += "\n \N{rightwards arrow with hook} "
                    notes += f"{self.wandb_project_url}/runs/{step.unique_id}/overview\n"
            wandb_run.notes = notes

    return self.registered_run(wandb_run_name)
def _get_updated_step_info(
    self, step_id: str, step_name: Optional[str] = None
) -> Optional[StepInfo]:
    """
    Look up the most current :class:`StepInfo` for a step in W&B.

    The W&B run belonging to the step itself is consulted first; if there is none,
    the info stored on a registered Tango run that contains the step is used.

    :param step_id: Unique ID of the step to look up.
    :param step_name: Optional step name used to narrow both queries.
    :return: The step info, or ``None`` when no matching W&B run exists.
    """

    def parse_utc(timestamp: str) -> datetime:
        # The timestamp is parsed without a timezone and then tagged as UTC.
        return datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=pytz.utc)

    project_path = f"{self.entity}/{self.project}"

    # First try to find the W&B run corresponding to the step. This will only
    # work if the step execution was started already.
    step_run_filters = {
        "config.job_type": RunKind.STEP.value,
        "config.step_info.unique_id": step_id,
    }
    if step_name is not None:
        step_run_filters["display_name"] = step_name
    for step_run in self.wandb_client.runs(
        project_path,
        filters=step_run_filters,  # type: ignore
    ):
        info = StepInfo.from_json_dict(step_run.config["step_info"])
        # The stored info may be stale if the step failed before the config could
        # be updated, so missing fields are patched from the run's own metadata.
        if info.start_time is None:
            info.start_time = parse_utc(step_run.created_at)
        run_state = step_run.state
        if run_state in {"failed", "finished"} and info.end_time is None:
            info.end_time = parse_utc(step_run.heartbeatAt)
        if run_state == "failed" and info.error is None:
            info.error = "Exception"
        # Only the first matching run is used.
        return info

    # If the step hasn't been started yet, we'll have to pull the step info from the
    # registered run.
    tango_run_filters = {
        "config.job_type": RunKind.TANGO_RUN.value,
        f"config._step_ids.{step_id}": True,
    }
    if step_name is not None:
        tango_run_filters[f"config.steps.{step_name}.unique_id"] = step_id
    for tango_run in self.wandb_client.runs(
        project_path,
        filters=tango_run_filters,  # type: ignore
    ):
        if step_name is None:
            # No name given: find the entry by scanning for the unique ID.
            raw_info = next(
                d for d in tango_run.config["steps"].values() if d["unique_id"] == step_id
            )
        else:
            raw_info = tango_run.config["steps"][step_name]
        return StepInfo.from_json_dict(raw_info)

    return None
@classmethod
def default(cls) -> "TangoGlobalSettings":
    """
    Load the settings from the first settings file found in the default locations.

    The current directory is searched before the directory of ``_DEFAULT_LOCATION``;
    in each directory ``tango.yml`` is tried before ``tango.yaml``. If no file
    exists, a settings object with all defaults is returned.
    """
    search_dirs = (Path("."), cls._DEFAULT_LOCATION.parent)
    candidates = (
        directory / f"tango.{extension}"
        for directory in search_dirs
        for extension in ("yml", "yaml")
    )
    # Lazily pick the first candidate that is an existing file, if any.
    settings_path = next((path for path in candidates if path.is_file()), None)
    if settings_path is None:
        return cls()
    return cls.from_file(settings_path)
@dataclass
class StepResources(FromParams):
    """
    StepResources describes the minimum external hardware requirements which must be
    available for a step to run.

    Every field is optional and defaults to ``None``.
    """

    machine: Optional[str] = None
    """
    This is an executor-dependent option.

    With the Beaker executor, for example, you can set this to "local" to force
    the executor to run the step locally instead of on Beaker.
    """

    cpu_count: Optional[float] = None
    """
    Minimum number of logical CPU cores. It may be fractional.

    Examples: ``4``, ``0.5``.
    """

    gpu_count: Optional[int] = None
    """
    Minimum number of GPUs. It must be non-negative.
    """

    gpu_type: Optional[str] = None
    """
    The type of GPU that the step requires.

    The exact string you should use to define a GPU type depends on the executor.
    With the Beaker executor, for example, you should use the same strings you
    see in the Beaker UI, such as 'NVIDIA A100-SXM-80GB'.
    """

    memory: Optional[str] = None
    """
    Minimum available system memory as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """

    shared_memory: Optional[str] = None
    """
    Size of ``/dev/shm`` as a number with unit suffix.

    Examples: ``2.5GiB``, ``1024m``.
    """
:param step_unique_id_override: overrides the construction of the step's unique id using the hash of inputs. :param step_resources: gives you a way to set the minimum compute resources required to run this step. Certain executors require this information. :param step_metadata: use this to specify additional metadata for your step. This is added to the :attr:`METADATA` class variable to form the ``self.metadata`` attribute. Values in ``step_metadata`` take precedence over ``METADATA``. :param step_extra_dependencies: use this to force a dependency on other steps. Normally dependencies between steps are determined by the inputs and outputs of the steps, but you can use this parameter to force that other steps run before this step even if this step doesn't explicitly depend on the outputs of those steps. .. important:: Overriding the unique id means that the step will always map to this value, regardless of the inputs, and therefore, the step cache will only hold a single copy of the step's output (from the last execution). Thus, in most cases, this should not be used when constructing steps. We include this option for the case when the executor creates subprocesses, which also need to access the *same* ``Step`` object. """ DETERMINISTIC: bool = True """This describes whether this step can be relied upon to produce the same results every time when given the same inputs. If this is ``False``, you can still cache the output of the step, but the results might be unexpected. Tango will print a warning in this case.""" CACHEABLE: Optional[bool] = None """This provides a direct way to turn off caching. For example, a step that reads a HuggingFace dataset doesn't need to be cached, because HuggingFace datasets already have their own caching mechanism. But it's still a deterministic step, and all following steps are allowed to cache. 
If it is ``None``, the step figures out by itself whether it should be cacheable or not.""" VERSION: Optional[str] = None """This is optional, but recommended. Specifying a version gives you a way to tell Tango that a step has changed during development, and should now be recomputed. This doesn't invalidate the old results, so when you revert your code, the old cache entries will stick around and be picked up.""" FORMAT: Format = DillFormat("gz") """This specifies the format the results of this step will be serialized in. See the documentation for :class:`~tango.format.Format` for details.""" SKIP_ID_ARGUMENTS: Set[str] = set() """If your :meth:`run()` method takes some arguments that don't affect the results, list them here. Arguments listed here will not be used to calculate this step's unique ID, and thus changing those arguments does not invalidate the cache. For example, you might use this for the batch size in an inference step, where you only care about the model output, not about how many outputs you can produce at the same time. """ SKIP_DEFAULT_ARGUMENTS: Dict[str, Any] = {} """Sometimes, you want to add another argument to your :meth:`run()` method, but you don't want to invalidate the cache when this new argument is set to its default value. If that is the case, add the argument to this dictionary with the default value that should be ignored.""" METADATA: Dict[str, Any] = {} """ Arbitrary metadata about the step. """ _UNIQUE_ID_SUFFIX: Optional[str] = None """ Used internally for testing. 
def __init__(
    self,
    step_name: Optional[str] = None,
    cache_results: Optional[bool] = None,
    step_format: Optional[Format] = None,
    step_config: Optional[Union[Dict[str, Any], Params]] = None,
    step_unique_id_override: Optional[str] = None,
    step_resources: Optional[StepResources] = None,
    step_metadata: Optional[Dict[str, Any]] = None,
    step_extra_dependencies: Optional[Iterable["Step"]] = None,
    **kwargs,
):
    """
    Set up the step with the arguments that will later be passed to :meth:`run()`.

    :param step_name: Human-readable name; defaults to the step's unique ID.
    :param cache_results: ``True``/``False`` to force caching on or off, ``None``
        to derive it from :attr:`DETERMINISTIC` and :attr:`CACHEABLE`.
    :param step_format: Overrides :attr:`FORMAT` for this instance.
    :param step_config: Raw config of this step, as a dict or ``Params``.
    :param step_unique_id_override: Fixed unique ID used instead of the computed one.
    :param step_resources: Minimum compute resources required by this step.
    :param step_metadata: Extra metadata; takes precedence over :attr:`METADATA`.
    :param step_extra_dependencies: Steps to depend on regardless of the inputs.
    :param kwargs: Arguments for :meth:`run()`; defaults declared on :meth:`run()`
        are filled in for arguments that are not given.

    :raises ConfigurationError: If ``cache_results`` is ``True`` but the step sets
        ``CACHEABLE = False``, or if ``cache_results`` is not ``True``, ``False``
        or ``None``.
    :raises AssertionError: If :attr:`VERSION` contains characters other than
        letters and digits, or ``step_resources`` has the wrong type.
    """
    if self.VERSION is not None:
        assert _version_re.match(
            self.VERSION
        ), f"Invalid characters in version '{self.VERSION}'"

    # Explicit kwargs win over the defaults declared on run().
    run_defaults = {
        k: v.default
        for k, v in inspect.signature(self.run).parameters.items()
        if v.default is not inspect.Parameter.empty
    }
    self.kwargs = self.massage_kwargs({**run_defaults, **kwargs})

    if step_format is None:
        self.format = self.FORMAT
        # FORMAT may be given as a class instead of an instance.
        if isinstance(self.format, type):
            self.format = self.format()
    else:
        self.format = step_format

    self.unique_id_cache = step_unique_id_override
    if step_name is None:
        self.name = self.unique_id
    else:
        self.name = step_name
    # TODO: It is bad design to have the step_name in the Step class. The same step can be part of multiple
    # runs at the same time, and they can have different names in different runs. Step names are
    # a property of the run, not of the step.

    if cache_results is True:
        # Only an explicit ``CACHEABLE = False`` forbids caching. The default of
        # ``None`` means "decide automatically" and must not reject an explicit
        # request to cache.
        if self.CACHEABLE is False:
            raise ConfigurationError(
                f"Step {self.name} is configured to use the cache, but it's not a cacheable step."
            )
        if not self.DETERMINISTIC:
            warnings.warn(
                f"Step {self.name} is going to be cached despite not being deterministic.",
                UserWarning,
            )
        self.cache_results = True
    elif cache_results is False:
        self.cache_results = False
    elif cache_results is None:
        # Decision table over (DETERMINISTIC, CACHEABLE).
        c = (self.DETERMINISTIC, self.CACHEABLE)
        if c == (False, None):
            self.cache_results = False
        elif c == (True, None):
            self.cache_results = True
        elif c == (False, False):
            self.cache_results = False
        elif c == (True, False):
            self.cache_results = False
        elif c == (False, True):
            warnings.warn(
                f"Step {self.name} is set to be cacheable despite not being deterministic.",
                UserWarning,
            )
            self.cache_results = True
        elif c == (True, True):
            self.cache_results = True
        else:
            assert False, "Step.DETERMINISTIC or step.CACHEABLE are set to an invalid value."
    else:
        raise ConfigurationError(
            f"Step {self.name}'s cache_results parameter is set to an invalid value."
        )

    self._workspace: Optional["Workspace"] = None
    self.work_dir_for_run: Optional[
        Path
    ] = None  # This is set only while the run() method runs.

    if isinstance(step_config, Params):
        self._config = step_config.as_dict(quiet=True)
    else:
        self._config = step_config

    assert step_resources is None or isinstance(step_resources, StepResources)
    self.step_resources = step_resources

    # Copy so instance-level metadata never mutates the class-level METADATA.
    self.metadata = deepcopy(self.METADATA)
    if step_metadata:
        self.metadata.update(step_metadata)

    self.extra_dependencies = set(step_extra_dependencies) if step_extra_dependencies else set()
@classmethod
def from_params(  # type: ignore[override]
    cls: Type["Step"],
    params: Union[Params, dict, str],
    constructor_to_call: Optional[Callable[..., "Step"]] = None,
    constructor_to_inspect: Optional[
        Union[Callable[..., "Step"], Callable[["Step"], None]]
    ] = None,
    step_name: Optional[str] = None,
    **extras,
) -> "Step":
    """
    Construct a step from configuration parameters.

    :param params: The step's parameters. A plain string is treated as the step
        type (``{"type": params}``); a dict is wrapped in ``Params``.
    :param constructor_to_call: Must be ``None``.
    :param constructor_to_inspect: Must be ``None``.
    :param step_name: Passed to the step's constructor as ``step_name``.
    :param extras: Forwarded to ``pop_and_construct_arg`` for every parameter.
    :return: An instance of the registered subclass selected by ``params["type"]``.
    :raises ConfigurationError: If ``constructor_to_call`` or
        ``constructor_to_inspect`` is given, if ``params`` is not a string, dict
        or ``Params``, if the resolved type is not a ``Step`` subclass, or if a
        bound functional step has no ``self`` argument.
    """
    # Why do we need a custom from_params? Step classes have a run() method that takes all the
    # parameters necessary to perform the step. The __init__() method of the step takes those
    # same parameters, but each of them could be wrapped in another Step instead of being
    # supplied directly. from_params() doesn't know anything about these shenanigans, so
    # we have to supply the necessary logic here.

    if constructor_to_call is not None:
        raise ConfigurationError(
            f"{cls.__name__}.from_params cannot be called with a constructor_to_call."
        )
    if constructor_to_inspect is not None:
        raise ConfigurationError(
            f"{cls.__name__}.from_params cannot be called with a constructor_to_inspect."
        )

    if isinstance(params, str):
        params = Params({"type": params})

    if not isinstance(params, Params):
        if isinstance(params, dict):
            params = Params(params)
        else:
            raise ConfigurationError(
                "from_params was passed a ``params`` object that was not a ``Params``. This probably "
                "indicates malformed parameters in a configuration file, where something that "
                "should have been a dictionary was actually a list, or something else. "
                f"This happened when constructing an object of type {cls}."
            )

    # Build up a raw step config
    def replace_steps_with_refs(o: Any) -> Any:
        # Recursively copies ``o``, replacing every Step with a ``{"type": "ref"}``
        # entry that points at the step's name.
        if isinstance(o, (list, tuple, set)):
            return o.__class__(replace_steps_with_refs(i) for i in o)
        elif isinstance(o, (dict, Params)):
            result = {key: replace_steps_with_refs(value) for key, value in o.items()}
            if isinstance(o, dict):
                return result
            elif isinstance(o, Params):
                return Params(result, history=o.history)
        elif isinstance(o, Step):
            return {"type": "ref", "ref": o.name}
        else:
            return deepcopy(o)

    # Captured before anything is popped from ``params`` below.
    raw_step_config = replace_steps_with_refs(params.as_dict(quiet=True))

    as_registrable = cast(Type[Registrable], cls)
    # An unregistered type name triggers a module search before the choice is resolved.
    if "type" in params and params["type"] not in as_registrable.list_available():
        as_registrable.search_modules(params["type"])
    choice = params.pop_choice(
        "type", choices=as_registrable.list_available(), default_to_first_choice=False
    )
    # NOTE(review): ``constructor_name`` is not used in the rest of this method.
    subclass, constructor_name = as_registrable.resolve_class_name(choice)
    if not issubclass(subclass, Step):
        # This can happen if `choice` is a fully qualified name.
        raise ConfigurationError(
            f"Tried to make a Step of type {choice}, but ended up with a {subclass}."
        )

    # Functional steps take their parameters from the wrapped function; all other
    # steps take them from run().
    if issubclass(subclass, FunctionalStep):
        parameters = infer_method_params(subclass, subclass.WRAPPED_FUNC, infer_kwargs=False)
        if subclass.BIND:
            if "self" not in parameters:
                raise ConfigurationError(
                    f"Functional step for {subclass.WRAPPED_FUNC} is bound but is missing argument 'self'"
                )
            else:
                del parameters["self"]
    else:
        parameters = infer_method_params(subclass, subclass.run, infer_kwargs=False)
        del parameters["self"]
    init_parameters = infer_constructor_params(subclass)
    del init_parameters["self"]
    del init_parameters["kwargs"]
    # The constructor's own parameter names are reserved and may not be reused by run().
    parameter_overlap = parameters.keys() & init_parameters.keys()
    assert len(parameter_overlap) <= 0, (
        f"If this assert fails it means that you wrote a Step with a run() method that takes one of the "
        f"reserved parameters ({', '.join(init_parameters.keys())})"
    )
    parameters.update(init_parameters)

    kwargs: Dict[str, Any] = {}
    accepts_kwargs = False
    for param_name, param in parameters.items():
        if param.kind == param.VAR_KEYWORD:
            # When a class takes **kwargs we store the fact that the method allows extra keys; if
            # we get extra parameters, instead of crashing, we'll just pass them as-is to the
            # constructor, and hope that you know what you're doing.
            accepts_kwargs = True
            continue

        # Must be checked before the value is popped from ``params``.
        explicitly_set = param_name in params
        constructed_arg = pop_and_construct_arg(
            subclass.__name__, param_name, param.annotation, param.default, params, extras
        )

        # If the param wasn't explicitly set in `params` and we just ended up constructing
        # the default value for the parameter, we can just omit it.
        # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up
        # with multiple values for a single parameter (e.g., the default value gives you lazy=False
        # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes
        # lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).
        if explicitly_set or constructed_arg is not param.default:
            kwargs[param_name] = constructed_arg

    if accepts_kwargs:
        # Whatever is left in ``params`` is passed through unchanged.
        kwargs.update(params)
    else:
        params.assert_empty(subclass.__name__)

    return subclass(step_name=step_name, step_config=raw_step_config, **kwargs)
@property
def unique_id(self) -> str:
    """Returns the unique ID for this step.

    Unique IDs are of the shape ``$class_name-$version-$hash``, where the hash is the hash of the
    inputs for deterministic steps, and a random string of characters for non-deterministic ones.

    The ID is computed on first access and cached in ``unique_id_cache``. A value
    already present there (for example from ``step_unique_id_override``) is
    returned as is.
    """
    if self.unique_id_cache is None:
        # Build the ID in a local variable and publish it only once it is complete.
        # If hashing raises, the cache stays empty instead of holding a truncated
        # ID that later accesses would silently return.
        unique_id = self.class_name
        if self.VERSION is not None:
            unique_id += "-"
            unique_id += self.VERSION

        unique_id += "-"
        if self.DETERMINISTIC:
            # Leave out arguments that are declared irrelevant for the ID, and
            # arguments that still have their ignorable default value.
            hash_kwargs = {
                key: value
                for key, value in self.kwargs.items()
                if (key not in self.SKIP_ID_ARGUMENTS)
                and (
                    key not in self.SKIP_DEFAULT_ARGUMENTS
                    or self.SKIP_DEFAULT_ARGUMENTS[key] != value
                )
            }
            unique_id += det_hash(
                (
                    (self.format.__class__.__module__, self.format.__class__.__qualname__),
                    self.format.VERSION,
                    hash_kwargs,
                )
            )[:32]
        else:
            # Non-deterministic steps get a random ID.
            unique_id += det_hash(
                _random_for_step_names.getrandbits((58**32).bit_length())
            )[:32]
        if self._UNIQUE_ID_SUFFIX is not None:
            unique_id += f"-{self._UNIQUE_ID_SUFFIX}"

        self.unique_id_cache = unique_id

    return self.unique_id_cache
def _replace_steps_with_results(self, o: Any, workspace: "Workspace"):
    """
    Return ``o`` with every step inside it replaced by that step's result.

    Steps and step indexers are resolved through ``result()`` with this step as
    ``needed_by``. ``Lazy`` objects are rebuilt with resolved params and
    constructor extras, ``WithUnresolvedSteps`` objects are constructed, and
    lists, tuples, sets and dicts are rebuilt element by element. Anything else
    is returned unchanged.
    """
    resolve = self._replace_steps_with_results

    if isinstance(o, (Step, StepIndexer)):
        return o.result(workspace=workspace, needed_by=self)

    if isinstance(o, Lazy):
        # Params are resolved before the constructor extras.
        resolved_params = Params(resolve(o._params.as_dict(quiet=True), workspace))
        resolved_extras = resolve(o._constructor_extras, workspace)
        return Lazy(
            o._constructor,
            params=resolved_params,
            constructor_extras=resolved_extras,
        )

    if isinstance(o, WithUnresolvedSteps):
        return o.construct(workspace)

    if isinstance(o, (list, tuple, set)):
        # Rebuild with the same container class as the input.
        container_type = o.__class__
        return container_type(resolve(item, workspace) for item in o)

    if isinstance(o, dict):
        return {name: resolve(item, workspace) for name, item in o.items()}

    return o
def _ordered_dependencies(self) -> Iterable["Step"]:
    """
    Yield the steps this step directly depends on.

    The extra dependencies come first, followed by every step found while
    walking the values of ``self.kwargs``. The same step may be yielded more
    than once.
    """

    def walk(node: Any) -> Iterable[Step]:
        if isinstance(node, Step):
            yield node
            return
        if isinstance(node, Lazy):
            yield from walk(node._params.as_dict(quiet=True))
            return
        if isinstance(node, WithUnresolvedSteps):
            yield from walk(node.args)
            yield from walk(node.kwargs)
            return
        if isinstance(node, StepIndexer):
            yield node.step
            return
        if isinstance(node, str):
            # Confusingly, str is an Iterable of itself, resulting in infinite recursion.
            return
        if isinstance(node, (dict, Params)):
            yield from walk(node.values())
            return
        if isinstance(node, Iterable):
            # One walker is created per element up front, then drained in order.
            walkers = [walk(element) for element in node]
            for walker in walkers:
                yield from walker

    yield from self.extra_dependencies
    yield from walk(self.kwargs.values())
""" seen = set() steps = list(self.dependencies) while len(steps) > 0: step = steps.pop() if step in seen: continue seen.add(step) steps.extend(step.dependencies) return seen def log_cache_hit(self, needed_by: Optional["Step"] = None) -> None: if needed_by is not None: cli_logger.info( '[green]\N{check mark} Found output for step [bold]"%s"[/bold] in cache ' '(needed by "%s")...[/green]', self.name, needed_by.name, ) else: cli_logger.info( '[green]\N{check mark} Found output for step [bold]"%s"[/] in cache...[/]', self.name, ) def log_starting(self, needed_by: Optional["Step"] = None) -> None: if needed_by is not None: cli_logger.info( '[blue]\N{black circle} Starting step [bold]"%s"[/] (needed by "%s")...[/]', self.name, needed_by.name, ) else: cli_logger.info( '[blue]\N{black circle} Starting step [bold]"%s"[/]...[/]', self.name, ) def log_finished(self, run_name: Optional[str] = None) -> None: if run_name is not None: cli_logger.info( '[green]\N{check mark} Finished run for step [bold]"%s"[/] (%s)[/]', self.name, run_name, ) else: cli_logger.info( '[green]\N{check mark} Finished step [bold]"%s"[/][/]', self.name, ) def log_failure(self, exception: Optional[BaseException] = None) -> None: if exception is not None: log_exception(exception, logger=self.logger) cli_logger.error('[red]\N{ballot x} Step [bold]"%s"[/] failed[/]', self.name) class FunctionalStep(Step): WRAPPED_FUNC: ClassVar[Callable] BIND: ClassVar[bool] = False @property def class_name(self) -> str: return self.WRAPPED_FUNC.__name__ def run(self, *args, **kwargs): if self.BIND: return self.WRAPPED_FUNC(*args, **kwargs) else: return self.__class__.WRAPPED_FUNC(*args, **kwargs) def step( name: Optional[str] = None, *, exist_ok: bool = False, bind: bool = False, deterministic: bool = True, cacheable: Optional[bool] = None, version: Optional[str] = None, format: Format = DillFormat("gz"), skip_id_arguments: Optional[Set[str]] = None, metadata: Optional[Dict[str, Any]] = None, ): """ A decorator to create 
def step(
    name: Optional[str] = None,
    *,
    exist_ok: bool = False,
    bind: bool = False,
    deterministic: bool = True,
    cacheable: Optional[bool] = None,
    version: Optional[str] = None,
    format: Format = DillFormat("gz"),
    skip_id_arguments: Optional[Set[str]] = None,
    metadata: Optional[Dict[str, Any]] = None,
):
    """
    Decorator that turns a plain function into a registered :class:`Step`.

    :param name:
        The name to register the step under. Defaults to the name of the function.
    :param exist_ok:
        If ``True``, an existing step registered under the same ``name`` is overwritten.
        Otherwise registering a duplicate name is an error.
    :param bind:
        If ``True``, the function is called like an instance method, i.e. it receives the
        underlying :class:`Step` instance as its first argument. In that case the first
        argument must be named 'self', or you will get a
        :class:`~tango.common.exceptions.ConfigurationError` when instantiating the class.

    See the :class:`Step` class for an explanation of the other parameters.

    Example
    -------

    .. testcode::

        from tango import step

        @step(version="001")
        def add(a: int, b: int) -> int:
            return a + b

        @step(bind=True)
        def bound_step(self) -> None:
            assert self.work_dir.is_dir()

    """

    def step_wrapper(step_func):
        # Build the registration decorator first, then the class, then apply one to the other.
        register = Step.register(name or step_func.__name__, exist_ok=exist_ok)

        class WrapperStep(FunctionalStep):
            WRAPPED_FUNC = step_func
            BIND = bind
            DETERMINISTIC = deterministic
            CACHEABLE = cacheable
            VERSION = version
            FORMAT = format
            SKIP_ID_ARGUMENTS = skip_id_arguments or set()
            METADATA = metadata or {}

        return register(WrapperStep)

    return step_wrapper


class StepIndexer(CustomDetHash):
    """Refers to a single item, selected by ``key``, of the result of ``step``."""

    def __init__(self, step: Step, key: Union[str, int]):
        self.step = step
        self.key = key

    def result(
        self, workspace: Optional["Workspace"] = None, needed_by: Optional["Step"] = None
    ) -> Any:
        """Returns the item stored under ``self.key`` in the result of the wrapped step."""
        whole_result = self.step.result(workspace=workspace, needed_by=needed_by)
        return whole_result[self.key]

    def det_hash_object(self) -> Any:
        """The hash is derived from the wrapped step's unique ID together with the key."""
        return self.step.unique_id, self.key
class WithUnresolvedSteps(CustomDetHash):
    """
    This is a helper class for some scenarios where steps depend on other steps.

    Let's say we have two steps, :class:`ConsumeDataStep` and :class:`ProduceDataStep`.
    The easiest way to make :class:`ConsumeDataStep` depend on :class:`ProduceDataStep` is to
    specify ``Produce`` as one of the arguments to the step. This works when ``Consume`` takes
    the output of ``Produce`` directly, or if it takes it inside a standard Python container,
    like a list, set, or dictionary.

    But what if the output of :class:`ProduceDataStep` needs to be added to a complex, custom
    data structure? :class:`WithUnresolvedSteps` takes care of this scenario.

    For example, this works without any help:

    .. code-block:: Python

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass:
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: MyDataClass):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(input_data = produce)

    This scenario needs help:

    .. code-block:: Python

        @dataclass
        class DataWithTimestamp:
            data: MyDataClass
            timestamp: float

        class ProduceDataStep(Step[MyDataClass]):
            def run(self, ...) -> MyDataClass:
                ...
                return MyDataClass(...)

        class ConsumeDataStep(Step):
            def run(self, input_data: DataWithTimestamp):
                ...

        produce = ProduceDataStep()
        consume = ConsumeDataStep(
            input_data = DataWithTimestamp(produce, time.time())
        )

    That does not work, because :class:`DataWithTimestamp` needs an object of type
    :class:`MyDataClass`, but we're giving it an object of type :class:`Step[MyDataClass]`.
    Instead, we change the last line to this:

    .. code-block:: Python

        consume = ConsumeDataStep(
            input_data = WithUnresolvedSteps(
                DataWithTimestamp, produce, time.time()
            )
        )

    :class:`WithUnresolvedSteps` will delay calling the constructor of ``DataWithTimestamp``
    until the :meth:`run()` method runs. Tango will make sure that the results from the
    ``produce`` step are available at that time, and replaces the step in the arguments with
    the step's results.

    :param function: The function to call after resolving steps to their results.
    :param args: The args to pass to the function. These may contain steps, which will be
                 resolved before the function is called.
    :param kwargs: The kwargs to pass to the function. These may contain steps, which will be
                   resolved before the function is called.
    """

    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs

    @classmethod
    def with_resolved_steps(
        cls,
        o: Any,
        workspace: "Workspace",
    ):
        """
        Recursively goes through a Python object and replaces all instances of :class:`.Step`
        with the results of that step.

        Steps and step indexers are replaced by their results, nested
        :class:`WithUnresolvedSteps` objects are constructed, and ``Lazy`` objects, dicts,
        ``Params``, lists, tuples (including namedtuples) and sets are rebuilt with their
        contents resolved. Anything else is returned unchanged.

        :param o: The Python object to go through
        :param workspace: The workspace in which to resolve all steps
        :return: A new object that's a copy of the original object, with all instances of
                 :class:`.Step` replaced with the results of the step.
        """
        if isinstance(o, (Step, StepIndexer)):
            return o.result(workspace=workspace)

        if isinstance(o, Lazy):
            return Lazy(
                o._constructor,
                params=Params(cls.with_resolved_steps(o._params.as_dict(quiet=True), workspace)),
                constructor_extras=cls.with_resolved_steps(o._constructor_extras, workspace),
            )

        if isinstance(o, cls):
            return o.construct(workspace)

        if isinstance(o, (dict, Params)):
            return o.__class__(
                {key: cls.with_resolved_steps(value, workspace) for key, value in o.items()}
            )

        if isinstance(o, tuple) and hasattr(o, "_fields"):
            # namedtuple constructors take one positional argument per field, so they cannot
            # be rebuilt from a single iterable the way plain tuples can.
            return o.__class__(*(cls.with_resolved_steps(item, workspace) for item in o))

        if isinstance(o, (list, tuple, set)):
            return o.__class__(cls.with_resolved_steps(item, workspace) for item in o)

        return o

    def construct(self, workspace: "Workspace"):
        """
        Replaces all steps in the args that are stored in this object, and calls the function
        with those args.

        :param workspace: The :class:`.Workspace` in which to resolve all the steps.
        :return: The result of calling the function.
        """
        resolved_args = self.with_resolved_steps(self.args, workspace)
        resolved_kwargs = self.with_resolved_steps(self.kwargs, workspace)
        return self.function(*resolved_args, **resolved_kwargs)

    def det_hash_object(self) -> Any:
        """The hash covers the function's qualified name and the unresolved arguments."""
        return self.function.__qualname__, self.args, self.kwargs
""" resolved_args = self.with_resolved_steps(self.args, workspace) resolved_kwargs = self.with_resolved_steps(self.kwargs, workspace) return self.function(*resolved_args, **resolved_kwargs) def det_hash_object(self) -> Any: return self.function.__qualname__, self.args, self.kwargs ================================================ FILE: tango/step_cache.py ================================================ import logging from abc import abstractmethod from dataclasses import dataclass from typing import Any, TypeVar, Union from .common.from_params import FromParams from .common.registrable import Registrable from .format import Format from .step import Step from .step_info import StepInfo logger = logging.getLogger(__name__) T = TypeVar("T") class StepCache(Registrable): """ This is a mapping from instances of :class:`~tango.step.Step` to the results of that step. Generally :class:`StepCache` implementations are used internally by :class:`~tango.workspace.Workspace` implementations. """ default_implementation = "memory" """ The default implementation is :class:`.MemoryStepCache`. """ def __contains__(self, step: Any) -> bool: """This is a generic implementation of ``__contains__``. If you are writing your own ``StepCache``, you might want to write a faster one yourself.""" if not isinstance(step, (Step, StepInfo)): return False try: self.__getitem__(step) return True except KeyError: return False @abstractmethod def __getitem__(self, step: Union[Step, StepInfo]) -> Any: """Returns the results for the given step.""" raise NotImplementedError() @abstractmethod def __setitem__(self, step: Step, value: Any) -> None: """Writes the results for the given step. 
class StepCache(Registrable):
    """
    A mapping from instances of :class:`~tango.step.Step` to the results of that step.
    :class:`StepCache` implementations are generally used internally by
    :class:`~tango.workspace.Workspace` implementations.
    """

    default_implementation = "memory"
    """
    The default implementation is :class:`.MemoryStepCache`.
    """

    def __contains__(self, step: Any) -> bool:
        """
        Generic membership test that simply tries to fetch the result.
        If you are writing your own ``StepCache``, you might want to write a faster one yourself.
        """
        if not isinstance(step, (Step, StepInfo)):
            return False
        try:
            self.__getitem__(step)
        except KeyError:
            return False
        else:
            return True

    @abstractmethod
    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """Returns the results for the given step."""
        raise NotImplementedError()

    @abstractmethod
    def __setitem__(self, step: Step, value: Any) -> None:
        """Writes the results for the given step. Throws an exception if the step is already cached."""
        raise NotImplementedError()

    @abstractmethod
    def __delitem__(self, step_unique_id: Union[Step, StepInfo]) -> None:
        """Removes a step from step cache"""
        raise NotImplementedError()

    @abstractmethod
    def __len__(self) -> int:
        """Returns the number of results saved in this cache."""
        raise NotImplementedError()


@dataclass
class CacheMetadata(FromParams):
    """The information that is stored next to a cached result."""

    step: str
    """
    The step name.
    """

    format: Format
    """
    The format used to serialize the step's result.
    """
@StepCache.register("local")
class LocalStepCache(StepCache):
    """
    This is a :class:`.StepCache` that stores its results on disk, in the location given in ``dir``.

    Every cached step gets a directory under ``dir`` with that step's :attr:`~tango.step.Step.unique_id`.
    In that directory we store the results themselves in some format according to the step's
    :attr:`~tango.step.Step.FORMAT`, and we also write a ``cache-metadata.json`` file that
    stores the :class:`.CacheMetadata`.

    The presence of ``cache-metadata.json`` signifies that the cache entry is complete and
    has been written successfully.

    .. tip::
        Registered as :class:`.StepCache` under the name "local".
    """

    LRU_CACHE_MAX_SIZE = 8
    METADATA_FILE_NAME = "cache-metadata.json"

    def __init__(self, dir: PathOrStr):
        self.dir = Path(dir)
        self.dir.mkdir(parents=True, exist_ok=True)

        # We keep an in-memory cache as well so we don't have to de-serialize stuff
        # we happen to have in memory already.
        self.weak_cache: MutableMapping[str, Any]

        # Not all Python objects can be referenced weakly, and even if they can they
        # might get removed too quickly, so we also keep an LRU cache.
        self.strong_cache: OrderedDict[str, Any]

        self._init_mem_caches()

    def _init_mem_caches(self):
        """(Re-)creates both in-memory caches, empty."""
        self.weak_cache = weakref.WeakValueDictionary()
        self.strong_cache = collections.OrderedDict()

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist values in the weak and strong in-memory caches
        during pickling. And `WeakValueDictionary` can't be pickled anyway.
        """
        return {"dir": self.dir}

    def __setstate__(self, state):
        """Restores the pickled attributes and starts with empty in-memory caches."""
        for k, v in state.items():
            setattr(self, k, v)
        self._init_mem_caches()

    def _add_to_cache(self, key: str, o: Any) -> None:
        """Remembers ``o`` under ``key`` in the LRU cache and, where possible, the weak cache."""
        if hasattr(o, "__next__"):
            # We never cache iterators, because they are mutable, storing their current position.
            return

        self.strong_cache[key] = o
        self.strong_cache.move_to_end(key)
        while len(self.strong_cache) > self.LRU_CACHE_MAX_SIZE:
            # Evict the least recently used entry, which sits at the front.
            self.strong_cache.popitem(last=False)

        try:
            self.weak_cache[key] = o
        except TypeError:
            pass  # Many native Python objects cannot be referenced weakly, and they throw TypeError when you try

    def _get_from_cache(self, key: str) -> Optional[Any]:
        """Returns the in-memory value for ``key``, or ``None`` if there is none."""
        result = self.strong_cache.get(key)
        if result is not None:
            self.strong_cache.move_to_end(key)
            return result
        try:
            return self.weak_cache[key]
        except KeyError:
            return None

    def _remove_from_cache(self, key: str) -> None:
        """Drops ``key`` from both in-memory caches. Missing keys are ignored."""
        # `pop` with a default avoids the check-then-delete race on the weak cache, whose
        # entries can disappear at any time when the referenced object is garbage collected.
        self.strong_cache.pop(key, None)
        self.weak_cache.pop(key, None)

    def _metadata_path(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:
        """Returns the path of the metadata file inside the step's directory."""
        return self.step_dir(step_or_unique_id) / self.METADATA_FILE_NAME

    def __contains__(self, step: object) -> bool:
        """Checks the in-memory caches first, then the metadata file on disk."""
        if (isinstance(step, Step) and step.cache_results) or (
            isinstance(step, StepInfo) and step.cacheable
        ):
            key = step.unique_id
            if key in self.strong_cache:
                return True
            if key in self.weak_cache:
                return True
            return self._metadata_path(
                cast(Union[Step, StepInfo], step)  # cast is for mypy :/
            ).exists()
        else:
            return False

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """
        Returns the cached result for ``step``, reading it from disk if it is not in memory.

        :raises KeyError: If there is no cache entry for the step.
        """
        key = step.unique_id
        result = self._get_from_cache(key)
        if result is None:
            if step not in self:
                raise KeyError(step)
            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))
            result = metadata.format.read(self.step_dir(step))
            self._add_to_cache(key, result)
        return result

    def __setitem__(self, step: Step, value: Any) -> None:
        """
        Writes ``value`` to the step's directory and marks the entry complete.

        For uncacheable steps this only emits a ``UserWarning``.

        :raises ValueError: If a complete cache entry for the step already exists.
        """
        if not step.cache_results:
            warnings.warn(
                f"Tried to cache step '{step.name}' despite being marked as uncacheable",
                UserWarning,
            )
            return

        location = self.step_dir(step)
        location.mkdir(parents=True, exist_ok=True)

        metadata_location = self._metadata_path(step)
        if metadata_location.exists():
            raise ValueError(f"{metadata_location} already exists! Will not overwrite.")
        temp_metadata_location = metadata_location.with_suffix(".temp")

        try:
            step.format.write(value, location)
            metadata = CacheMetadata(step=step.unique_id, format=step.format)
            metadata.to_params().to_file(temp_metadata_location)
            # The rename is what marks the entry as complete.
            temp_metadata_location.rename(metadata_location)
        except BaseException:
            # Clean up the half-written metadata whatever went wrong (including
            # KeyboardInterrupt), then let the original error propagate.
            try:
                temp_metadata_location.unlink()
            except FileNotFoundError:
                pass
            raise

        # Only remember the value in memory once the on-disk entry is complete. Otherwise a
        # failed rename would leave the step "in" the cache without a usable entry on disk.
        self._add_to_cache(step.unique_id, value)

    def __delitem__(self, step: Union[Step, StepInfo]) -> None:
        """
        Deletes the step's directory and forgets the in-memory copies.

        :raises OSError: If the step's directory could not be removed.
        """
        location = self.dir / step.unique_id
        try:
            shutil.rmtree(location)
        except OSError as exc:
            raise OSError(
                f"Step cache folder for '{step.unique_id}' not found. Cannot be deleted."
            ) from exc
        self._remove_from_cache(step.unique_id)

    def __len__(self) -> int:
        """Counts the complete cache entries, i.e. the metadata files one level below ``dir``."""
        return sum(1 for _ in self.dir.glob(f"*/{self.METADATA_FILE_NAME}"))

    def step_dir(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:
        """Returns the directory that contains the results of the step.

        You can use this even for a step that's not cached yet. In that case it will return the directory where
        the results will be written.

        :raises RuntimeError: If given an uncacheable :class:`Step` or :class:`StepInfo`.
        """
        if isinstance(step_or_unique_id, (Step, StepInfo)):
            cacheable = (
                step_or_unique_id.cache_results
                if isinstance(step_or_unique_id, Step)
                else step_or_unique_id.cacheable
            )
            if not cacheable:
                class_name = (
                    step_or_unique_id.class_name
                    if isinstance(step_or_unique_id, Step)
                    else step_or_unique_id.step_class_name
                )
                raise RuntimeError(
                    f"Uncacheable steps (like '{class_name}') don't have step directories."
                )
            unique_id = step_or_unique_id.unique_id
        else:
            unique_id = step_or_unique_id
        return self.dir / unique_id
@StepCache.register("memory")
class MemoryStepCache(StepCache):
    """
    This is a :class:`.StepCache` that stores results in memory. It is little more than a Python dictionary.

    .. tip::
        Registered as :class:`.StepCache` under the name "memory".
    """

    def __init__(self):
        # Maps a step's unique ID to its result.
        self.cache: Dict[str, Any] = {}

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """Returns the stored result, raising ``KeyError`` if there is none."""
        return self.cache[step.unique_id]

    def __setitem__(self, step: Step, value: Any) -> None:
        """Stores ``value`` for a cacheable step. Refuses to overwrite an existing entry."""
        if step in self:
            raise ValueError(f"{step.unique_id} is already cached! Will not overwrite.")

        if not step.cache_results:
            warnings.warn(
                f"Tried to cache step '{step.name}' despite being marked as uncacheable.",
                UserWarning,
            )
            return

        self.cache[step.unique_id] = value

    def __delitem__(self, step: Union[Step, StepInfo]) -> None:
        """Forgets the result of ``step``, raising ``KeyError`` if it is not cached."""
        if step.unique_id not in self.cache:
            raise KeyError(f"{step.unique_id} not present in the memory cache. Cannot be deleted.")
        del self.cache[step.unique_id]

    def __contains__(self, step: object) -> bool:
        """Only steps and step infos can be in the cache. They are looked up by unique ID."""
        if not isinstance(step, (Step, StepInfo)):
            return False
        return step.unique_id in self.cache

    def __len__(self) -> int:
        """The number of stored results."""
        return len(self.cache)


default_step_cache = MemoryStepCache()
class RemoteNotFoundError(TangoError):
    """
    Classes inheriting from the RemoteStepCache should raise this if a step result object is not found.
    """


# This class inherits from `LocalStepCache` to benefit from its in-memory "weak cache" and "strong cache",
# but it handles saving artifacts to disk a little differently.
class RemoteStepCache(LocalStepCache):
    """
    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`RemoteWorkspace`.
    It stores the results of steps on some RemoteWorkspace.

    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a
    step's result subsequent times should be fast.

    Subclasses provide the remote side by implementing :meth:`_step_result_remote`,
    :meth:`_upload_step_remote`, :meth:`_download_step_remote` and :meth:`__len__`.

    .. tip::
        All remote step caches inherit from this.
    """

    Constants = RemoteConstants

    def __init__(self, local_dir: Path):
        # `local_dir` becomes `self.dir`, the directory of the local on-disk copies.
        super().__init__(local_dir)

    @abstractmethod
    def _step_result_remote(self, step: Union[Step, StepInfo]):
        """
        Looks up the remote result object for ``step``.

        Callers in this class treat a return value of ``None`` as "not present remotely" and
        pass any other value on to :meth:`_download_step_remote`.
        """
        raise NotImplementedError()

    @abstractmethod
    def _upload_step_remote(self, step: Step, objects_dir: Path):
        """
        Uploads the contents of ``objects_dir`` as the result of ``step``.

        :meth:`__setitem__` calls this with a temporary directory that holds the serialized
        result (under ``Constants.STEP_RESULT_DIR``) and the cache metadata file.
        """
        raise NotImplementedError()

    @abstractmethod
    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:
        """
        Downloads ``step_result`` (as returned by :meth:`_step_result_remote`) into ``target_dir``.

        :meth:`__getitem__` translates a :class:`RemoteNotFoundError` raised here into a ``KeyError``.
        """
        raise NotImplementedError()

    @abstractmethod
    def __len__(self):
        """Must be implemented by subclasses. The local-directory count of the parent class is not used."""
        raise NotImplementedError()

    def _acquire_step_lock_file(self, step: Union[Step, StepInfo], read_only_ok: bool = False):
        """
        Acquires the file lock that guards the local copy of ``step``.

        The lock file is the step directory's path with a ``.lock`` suffix. The return value is
        used as a context manager by the callers in this class.
        """
        return FileLock(
            self.step_dir(step).with_suffix(".lock"), read_only_ok=read_only_ok
        ).acquire_with_updates(desc=f"acquiring step cache lock for '{step.unique_id}'")

    def __contains__(self, step: Any) -> bool:
        """
        Checks, in this order: the in-memory caches, the local directory, the remote location.
        Uncacheable steps and objects that are not steps are never in the cache.
        """
        if isinstance(step, (Step, StepInfo)):
            cacheable = step.cache_results if isinstance(step, Step) else step.cacheable
            if not cacheable:
                return False

            key = step.unique_id

            # First check if we have a copy in memory.
            if key in self.strong_cache:
                return True
            if key in self.weak_cache:
                return True

            # Then check if we have a copy on disk in our cache directory.
            with self._acquire_step_lock_file(step, read_only_ok=True):
                if self.step_dir(step).is_dir():
                    return True

            # If not, check the remote location.
            return self._step_result_remote(step) is not None
        else:
            return False

    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:
        """
        Returns the result for ``step`` from memory, the local directory, or the remote
        location, downloading it if necessary.

        :raises KeyError: If the remote location does not have a result for the step.
        """
        key = step.unique_id
        # NOTE(review): the remote lookup happens before the in-memory and on-disk checks, so
        # a step that is missing remotely raises KeyError even if a local copy exists.
        # Presumably intentional, but worth confirming.
        step_result = self._step_result_remote(step)
        if step_result is None:
            raise KeyError(step)

        # Try getting the result from our in-memory caches first.
        result = self._get_from_cache(key)
        if result is not None:
            return result

        def load_and_return():
            # Reads the local copy from disk and remembers it in the in-memory caches.
            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))
            result = metadata.format.read(self.step_dir(step) / self.Constants.STEP_RESULT_DIR)
            self._add_to_cache(key, result)
            return result

        # Next check our local on-disk cache.
        with self._acquire_step_lock_file(step, read_only_ok=True):
            if self.step_dir(step).is_dir():
                return load_and_return()

        # Finally, check the remote location for the corresponding dataset.
        with self._acquire_step_lock_file(step):
            # Make sure the step wasn't cached since the last time we checked (above).
            if self.step_dir(step).is_dir():
                return load_and_return()

            # We'll download the dataset to a temporary directory first, in case something goes wrong.
            temp_dir = tempfile.mkdtemp(dir=self.dir, prefix=key)
            try:
                self._download_step_remote(step_result, target_dir=temp_dir)
                # Download and extraction was successful, rename temp directory to final step result directory.
                os.replace(temp_dir, self.step_dir(step))
            except RemoteNotFoundError:
                raise KeyError(step)
            finally:
                # After a successful `os.replace` the temp directory no longer exists, which
                # `ignore_errors=True` tolerates. After a failure this removes the partial download.
                shutil.rmtree(temp_dir, ignore_errors=True)

            return load_and_return()

    def __setitem__(self, step: Step, value: Any) -> None:
        """
        Serializes ``value``, uploads it to the remote location, and keeps a local copy on
        disk and in memory. For uncacheable steps this only logs a warning.
        """
        if not step.cache_results:
            logger.warning("Tried to cache step %s despite being marked as uncacheable.", step.name)
            return

        with self._acquire_step_lock_file(step):
            # We'll write the step's results to temporary directory first, and try to upload to
            # remote workspace from there in case anything goes wrong.
            temp_dir = Path(tempfile.mkdtemp(dir=self.dir, prefix=step.unique_id))
            (temp_dir / self.Constants.STEP_RESULT_DIR).mkdir()
            try:
                step.format.write(value, temp_dir / self.Constants.STEP_RESULT_DIR)
                metadata = CacheMetadata(step=step.unique_id, format=step.format)
                metadata.to_params().to_file(temp_dir / self.METADATA_FILE_NAME)
                # Create the dataset and upload serialized result to it.
                self._upload_step_remote(step, temp_dir)
                # Upload successful, rename temp directory to the final step result directory.
                if self.step_dir(step).is_dir():
                    shutil.rmtree(self.step_dir(step), ignore_errors=True)
                os.replace(temp_dir, self.step_dir(step))
            finally:
                # Removes the temp directory if anything above failed. It is already gone on success.
                shutil.rmtree(temp_dir, ignore_errors=True)

        # Finally, add to in-memory caches.
        self._add_to_cache(step.unique_id, value)
class StepGraph(Mapping[str, Step]):
    """
    Represents an experiment as a directed graph.

    It can be treated as a :class:`~collections.abc.Mapping` of step names (``str``)
    to :class:`Step`.
    """

    def __init__(self, step_dict: Dict[str, Step]):
        # TODO: What happens with anonymous steps in here?
        is_ordered = self._is_ordered(step_dict)
        if not is_ordered:
            self.parsed_steps = {step.name: step for step in self.ordered_steps(step_dict)}
        else:
            self.parsed_steps = {}
            for step_name, step in step_dict.items():
                # The key in the dictionary wins over whatever name the step had before.
                step.name = step_name
                self.parsed_steps[step_name] = step

        # Sanity-check the graph
        self._sanity_check()

    @classmethod
    def _is_ordered(cls, step_dict: Dict[str, Step]):
        """Returns ``True`` if every step appears after all of its direct dependencies."""
        present = set()
        for step in step_dict.values():
            for dep in step.dependencies:
                if dep.name not in present:
                    return False
            present.add(step.name)
        return True

    @classmethod
    def _check_unsatisfiable_dependencies(cls, dependencies: Dict[str, Set[str]]) -> None:
        """
        Checks whether some of the dependencies can never be satisfied.

        :param dependencies: Maps each step name to the names of the steps it depends on.
        :raises ConfigurationError: If a dependency is not itself a key of ``dependencies``.
        """
        # Sorted so that the error message is the same from run to run.
        unsatisfiable_dependencies = sorted(
            {
                dep
                for step_deps in dependencies.values()
                for dep in step_deps
                if dep not in dependencies
            }
        )
        if not unsatisfiable_dependencies:
            return
        if len(unsatisfiable_dependencies) == 1:
            dep = unsatisfiable_dependencies[0]
            raise ConfigurationError(f"Specified dependency '{dep}' can't be found in the config.")
        raise ConfigurationError(
            f"Some dependencies can't be found in the config: {', '.join(unsatisfiable_dependencies)}"
        )

    @classmethod
    def _get_ordered_steps(cls, dependencies: Dict[str, Set[str]]) -> List[str]:
        """
        Returns the step names in an order in which every step comes after its dependencies.

        :param dependencies: Maps each step name to the names of the steps it depends on.
        :raises ConfigurationError: If no such order exists, e.g. because of a circular
            reference or a dependency that is not a key of ``dependencies``.
        """
        done: Set[str] = set()
        todo = list(dependencies)
        ordered_steps: List[str] = []
        while todo:
            new_todo = []
            for step_name in todo:
                # A step is ready once all of its dependencies have been placed.
                if dependencies[step_name] <= done:
                    done.add(step_name)
                    ordered_steps.append(step_name)
                else:
                    new_todo.append(step_name)
            if len(todo) == len(new_todo):
                raise ConfigurationError(
                    "Could not make progress parsing the steps. "
                    "You probably have a circular reference between the steps, "
                    "Or a missing dependency."
                )
            todo = new_todo
        return ordered_steps

    def _sanity_check(self) -> None:
        """Warns about cached steps that depend on a non-deterministic step."""
        for step in self.parsed_steps.values():
            if step.cache_results:
                nondeterministic_dependencies = [
                    s for s in step.recursive_dependencies if not s.DETERMINISTIC
                ]
                if nondeterministic_dependencies:
                    nd_step = nondeterministic_dependencies[0]
                    logger.warning(
                        "Task %s is set to cache results, but depends on non-deterministic "
                        "step %s. This will produce confusing results.",
                        step.name,
                        nd_step.name,
                    )

    @classmethod
    def from_params(cls: Type["StepGraph"], params: Dict[str, Params]) -> "StepGraph":  # type: ignore[override]
        """
        Builds the graph from a mapping of step names to step configurations.

        Note that the entries are popped from ``params`` as the steps are created.
        """
        # Determine the order in which to create steps so that all dependent steps are available when we need them.
        # This algorithm for resolving step dependencies is O(n^2). Since we're
        # anticipating the number of steps in a single config to be in the dozens at most (#famouslastwords),
        # we choose simplicity over cleverness.
        dependencies = {
            step_name: cls._find_step_dependencies(step_params)
            for step_name, step_params in params.items()
        }
        cls._check_unsatisfiable_dependencies(dependencies)

        # We need ordered dependencies to construct the steps with refs.
        ordered_steps = cls._get_ordered_steps(dependencies)

        # Parse the steps
        step_dict: Dict[str, Step] = {}
        for step_name in ordered_steps:
            step_params = params.pop(step_name)
            if step_name in step_dict:
                raise ConfigurationError(f"Duplicate step name {step_name}")

            step_params = cls._replace_step_dependencies(step_params, step_dict)
            step_dict[step_name] = Step.from_params(step_params, step_name=step_name)

        return cls(step_dict)

    def sub_graph(self, *step_names: str) -> "StepGraph":
        """
        Returns the graph that consists of the named steps and everything they depend on.

        :raises KeyError: If one of the names is not a step in this graph.
        """
        step_dict: Dict[str, Step] = {}
        for step_name in step_names:
            if step_name not in self.parsed_steps:
                raise KeyError(
                    f"{step_name} is not a part of this StepGraph. "
                    f"Available steps are: {list(self.parsed_steps.keys())}"
                )
            step_dict.update(
                {dep.name: dep for dep in self.parsed_steps[step_name].recursive_dependencies}
            )
            step_dict[step_name] = self.parsed_steps[step_name]

        return StepGraph(step_dict)

    @staticmethod
    def _dict_is_ref(d: Union[dict, Params]) -> bool:
        """A reference is either ``{"ref": ...}`` or a dict with ``"type": "ref"`` and a ``"ref"`` key."""
        keys = set(d.keys())
        if keys == {"ref"}:
            return True
        if keys >= {"type", "ref"} and d["type"] == "ref":
            return True
        return False

    @classmethod
    def _find_step_dependencies(cls, o: Any) -> Set[str]:
        """
        Collects the names of all steps that are referenced anywhere inside ``o``.

        :raises ValueError: If ``o`` contains something other than containers, ``None``,
            and ``bool``/``str``/``int``/``float`` values.
        """
        dependencies: Set[str] = set()
        if isinstance(o, (list, tuple, set)):
            for item in o:
                dependencies |= cls._find_step_dependencies(item)
        elif isinstance(o, (dict, Params)):
            if cls._dict_is_ref(o):
                dependencies.add(o["ref"])
            else:
                for value in o.values():
                    dependencies |= cls._find_step_dependencies(value)
        elif o is not None and not isinstance(o, (bool, str, int, float)):
            raise ValueError(o)
        return dependencies

    @classmethod
    def _replace_step_dependencies(cls, o: Any, existing_steps: Mapping[str, Step]) -> Any:
        """
        Returns a copy of ``o`` in which every reference is replaced by the step it names, or
        by a :class:`StepIndexer` if the reference also has a ``"key"``.
        """
        if isinstance(o, (list, tuple, set)):
            return o.__class__(cls._replace_step_dependencies(i, existing_steps) for i in o)
        elif isinstance(o, (dict, Params)):
            if cls._dict_is_ref(o):
                if "key" in o:
                    return StepIndexer(existing_steps[o["ref"]], o["key"])
                else:
                    return existing_steps[o["ref"]]
            else:
                result = {
                    key: cls._replace_step_dependencies(value, existing_steps)
                    for key, value in o.items()
                }
                if isinstance(o, dict):
                    return result
                elif isinstance(o, Params):
                    return Params(result, history=o.history)
                else:
                    raise RuntimeError(f"Object {o} is of unexpected type {o.__class__}.")
        elif o is not None and not isinstance(o, (bool, str, int, float)):
            raise ValueError(o)
        return o

    def __getitem__(self, name: str) -> Step:
        """
        Get the step with the given name.
        """
        return self.parsed_steps[name]

    def __len__(self) -> int:
        """
        The number of steps in the experiment.
        """
        return len(self.parsed_steps)

    def __iter__(self) -> Iterator[str]:
        """
        The names of the steps in the experiment.
        """
        return iter(self.parsed_steps)

    @classmethod
    def ordered_steps(cls, step_dict: Dict[str, Step]) -> List[Step]:
        """
        Returns the steps in this step graph in an order that can be executed one at a time.

        This does not take into account which steps may be cached. It simply returns an executable
        order of steps. As a side effect, every step is renamed to its key in ``step_dict``.
        """
        dependencies = {
            step_name: {dep.name for dep in step.dependencies}
            for step_name, step in step_dict.items()
        }
        result: List[Step] = []
        for step_name in cls._get_ordered_steps(dependencies):
            step_dict[step_name].name = step_name
            result.append(step_dict[step_name])
        return result

    def uncacheable_leaf_steps(self) -> Set[Step]:
        """Returns the uncacheable steps that no other step in the graph directly depends on."""
        interior_steps: Set[Step] = {
            dependency for step in self.parsed_steps.values() for dependency in step.dependencies
        }
        return {step for step in set(self.values()) - interior_steps if not step.cache_results}

    @classmethod
    def from_file(cls, filename: PathOrStr) -> "StepGraph":
        """Reads the ``steps`` section of the given config file and builds the graph from it."""
        params = Params.from_file(filename)
        # Go through `cls` so that subclasses get an instance of their own type.
        return cls.from_params(params.pop("steps", keep_as_dict=True))

    def to_config(self, include_unique_id: bool = False) -> Dict[str, Dict]:
        """
        Returns a config dictionary for the steps in this graph, with steps that appear as
        arguments turned back into references.

        :param include_unique_id: If ``True``, each step's unique ID is written to the config
            as ``step_unique_id_override``.
        """
        step_dict = {}

        def _to_config(o: Any):
            if isinstance(o, (list, tuple, set)):
                return o.__class__(_to_config(i) for i in o)
            elif isinstance(o, dict):
                return {key: _to_config(value) for key, value in o.items()}
            elif isinstance(o, Step):
                return {"type": "ref", "ref": o.name}
            elif isinstance(o, StepIndexer):
                return {"type": "ref", "ref": o.step.name, "key": o.key}
            elif o is not None and not isinstance(o, (bool, str, int, float)):
                raise ValueError(o)
            return o

        for step_name, step in self.parsed_steps.items():
            try:
                step_dict[step_name] = {
                    key: _to_config(value) for key, value in step.config.items()
                }
            except ValueError:  # step.config throws an error.
                # If the step_graph was not constructed using a config, we attempt to create
                # the config using the step object.
                step_dict[step_name] = {
                    key: _to_config(val) for key, val in step._to_params()["kwargs"].items()
                }
                step_dict[step_name]["type"] = step.__module__ + "." + step.class_name

            # We only add cache_results and format to the config if the values are different from default.
            if step.cache_results != step.CACHEABLE:
                step_dict[step_name]["cache_results"] = step.cache_results
            if step.format != step.FORMAT:
                step_dict[step_name]["step_format"] = _to_config(step.format._to_params())

            if include_unique_id:
                step_dict[step_name]["step_unique_id_override"] = step.unique_id

        return step_dict

    def to_file(self, filename: PathOrStr, include_unique_id: bool = False) -> None:
        """
        Writes the config for this graph, under the key ``steps``, to ``filename``.

        Note: In normal use cases, `include_unique_id` should always be False.
        We do not want to save the unique id in the config, because we want the
        output to change if we modify other kwargs in the config file. We include
        this flag for `MulticoreExecutor`, to ensure that steps have the same
        unique id in the main process and the created subprocesses.
        """
        step_dict = self.to_config(include_unique_id=include_unique_id)
        params = Params({"steps": step_dict})
        params.to_file(filename)

    def __repr__(self) -> str:
        rendered_steps = ", ".join(f'"{name}": {step}' for name, step in self.items())
        return f"{self.__class__.__name__}({rendered_steps})"
""" step_dict = self.to_config(include_unique_id=include_unique_id) params = Params({"steps": step_dict}) params.to_file(filename) def __repr__(self) -> str: result = [f'"{name}": {step}' for name, step in self.items()] result = ", ".join(result) return f"{self.__class__.__name__}({result})" ================================================ FILE: tango/step_info.py ================================================ import getpass import logging import os import platform import socket import sys from dataclasses import dataclass, field from datetime import datetime, timedelta from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple import pytz from .common.from_params import FromParams from .common.logging import log_exception from .common.util import StrEnum, jsonify, local_timezone, replace_steps_with_unique_id from .step import Step from .version import VERSION logger = logging.getLogger(__name__) def get_pip_packages() -> Optional[List[Tuple[str, str]]]: """ Get the current working set of pip packages. Equivalent to running ``pip freeze``. """ # Adapted from the Weights & Biases client library: # github.com/wandb/client/blob/a04722575eee72eece7eef0419d0cea20940f9fe/wandb/sdk/internal/meta.py#L56-L72 try: import pkg_resources return sorted([(d.key, d.version) for d in iter(pkg_resources.working_set)]) except Exception as exc: logger.error("Error saving pip packages") log_exception(exc) return None class StepState(StrEnum): """Describes the possible state a step can be in.""" INCOMPLETE = "incomplete" """The step has not run yet.""" RUNNING = "running" """The step is running right now.""" COMPLETED = "completed" """The step finished running successfully.""" FAILED = "failed" """The step ran, but failed.""" UNCACHEABLE = "uncacheable" """The step is uncacheable. 
class StepState(StrEnum):
    """Describes the possible state a step can be in."""

    INCOMPLETE = "incomplete"
    """The step has not run yet."""

    RUNNING = "running"
    """The step is running right now."""

    COMPLETED = "completed"
    """The step finished running successfully."""

    FAILED = "failed"
    """The step ran, but failed."""

    UNCACHEABLE = "uncacheable"
    """The step is uncacheable. It will be executed as many times as the results are needed,
    so we don't keep track of the state."""


@dataclass
class GitMetadata(FromParams):
    commit: Optional[str] = None
    """
    The commit SHA of the current repo.
    """

    remote: Optional[str] = None
    """
    The URL of the primary remote.
    """

    @classmethod
    def check_for_repo(cls) -> Optional["GitMetadata"]:
        """
        Collects the git metadata of the repository in the current directory.

        :return: ``None`` if the current directory is not a git repository. Otherwise a
            :class:`GitMetadata` in which ``commit`` and/or ``remote`` are ``None`` when the
            repository has no commit or no default remote to report.
        """
        from git import InvalidGitRepositoryError, Repo

        try:
            repo = Repo(".")
        except InvalidGitRepositoryError:
            return None

        # This method is used as a dataclass default factory, so it must not blow up just
        # because the repository is in an unusual (but valid) state. GitPython signals both
        # of the conditions below with a ValueError.
        commit: Optional[str]
        try:
            commit = str(repo.commit())
        except ValueError:
            commit = None  # e.g. a freshly initialized repository without any commits

        remote: Optional[str]
        try:
            remote = repo.remote().url
        except ValueError:
            remote = None  # there is no remote with the default name

        return cls(commit=commit, remote=remote)


@dataclass
class TangoMetadata(FromParams):
    version: str = VERSION
    """
    The tango release version.
    """


@dataclass
class PlatformMetadata(FromParams):
    operating_system: str = field(default_factory=platform.platform)
    """
    Full operating system name.
    """

    cpu_count: Optional[int] = field(default_factory=os.cpu_count)
    """
    Numbers of CPUs on the machine.
    """

    user: str = field(default_factory=getpass.getuser)
    """
    The user that ran this step.
    """

    host: str = field(default_factory=socket.gethostname)
    """
    Name of the host machine.
    """


@dataclass
class EnvironmentMetadata(FromParams):
    python: str = field(default_factory=platform.python_version)
    """
    The Python version.
    """

    executable: Path = field(default_factory=lambda: Path(sys.executable))
    """
    Path to the Python executable.
    """

    command: str = field(default_factory=lambda: " ".join(sys.argv))
    """
    The exact command used.
    """

    root: Path = field(default_factory=lambda: Path(os.getcwd()))
    """
    The root directory from where the Python executable was ran.
    """

    packages: Optional[List[Tuple[str, str]]] = field(default_factory=get_pip_packages)
    """
    The current set of Python packages in the Python environment. Each entry is a tuple of strings.
    The first element is the name of the package, the second element is the version.
    """

    git: Optional[GitMetadata] = field(default_factory=GitMetadata.check_for_repo)
    """
    The :class:`GitMetadata`.
    """

    tango: Optional[TangoMetadata] = field(default_factory=TangoMetadata)
    """
    The :class:`TangoMetadata`.
    """
""" tango: Optional[TangoMetadata] = field(default_factory=TangoMetadata) """ The :class:`TangoMetadata`. """ @dataclass class StepInfo(FromParams): """Stores step information without being the :class:`.Step` itself. It's not always possible to get a :class:`.Step` object, because :class:`.Step` objects can't be serialized. But you can always serialize a :class:`.StepInfo` object. """ unique_id: str """ The unique ID of the step """ step_class_name: str """ The name of the :class:`.Step` class """ dependencies: Set[str] """ The unique ids of all the steps that this step depends on """ cacheable: bool """ Whether or not the step is cacheable. """ step_name: Optional[str] = None """ The name of the step, if it has one. Anonymous steps are identified only by their unique ID. The same step can have different names in different runs. The last run wins, so don't rely on this property in your code. It is just here to aid readability. """ version: Optional[str] = None """ The version string of the :class:`.Step`, if it has one. """ start_time: Optional[datetime] = None """ The time (in UTC) that this step started running. .. seealso:: :meth:`start_time_local()`. """ end_time: Optional[datetime] = None """ The time (in UTC) that this step stopped running. This will be set whether the step succeeded or failed. .. seealso:: :meth:`end_time_local()`. """ error: Optional[str] = None """ If the step failed, this is where the error goes. .. note:: Some ``Workspace`` implementations need to serialize ``StepInfo`` (using pickle or dill, for example), but some exceptions can't be pickled. In those cases ``error`` will just be a string representation of the exception. """ result_location: Optional[str] = None """ Location of the result. This could be a path or a URL. """ config: Optional[Dict[str, Any]] = None """ The raw config of the step. """ metadata: Optional[Dict[str, Any]] = None """ Metadata from the step. 
This comes from the ``step_metadata`` argument to the :class:`~tango.step.Step` class. """ platform: PlatformMetadata = field(default_factory=PlatformMetadata) """ The :class:`PlatformMetadata`. """ environment: EnvironmentMetadata = field(default_factory=EnvironmentMetadata) """ The :class:`EnvironmentMetadata`. """ @property def start_time_local(self) -> Optional[datetime]: """ The time the step started running with respect to the local timezone, if the timezone can be determined. """ return None if self.start_time is None else self.start_time.astimezone(local_timezone()) @property def end_time_local(self) -> Optional[datetime]: """ The time the step stopped running with respect to the local timezone, if the timezone can be determined. """ return None if self.end_time is None else self.end_time.astimezone(local_timezone()) @property def duration(self) -> Optional[timedelta]: """ The time it took to run this step. """ if self.start_time is not None and self.end_time is not None: return self.end_time - self.start_time else: return None @property def state(self) -> StepState: """ Returns the state of the step """ if self.cacheable: if self.start_time is None and self.end_time is None and self.error is None: return StepState.INCOMPLETE if self.start_time is not None and self.end_time is None and self.error is None: return StepState.RUNNING if self.start_time is not None and self.end_time is not None and self.error is None: return StepState.COMPLETED if self.start_time is not None and self.end_time is not None and self.error is not None: return StepState.FAILED else: return StepState.UNCACHEABLE raise RuntimeError(f"{self.__class__.__name__} is in an invalid state.") def to_json_dict(self) -> Dict[str, Any]: """ Generates a JSON-safe, human-readable, dictionary representation of this dataclass. """ return jsonify(self) @classmethod def from_json_dict(cls, json_dict: Dict[str, Any]) -> "StepInfo": """ The inverse of :meth:`to_json_dict()`. 
:param json_dict: A dictionary representation, such as the one produced by :meth:`to_json_dict()`. """ step_info = cls.from_params( { k: ( datetime.strptime(v, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=pytz.utc) if k in {"start_time", "end_time"} and v is not None else v ) for k, v in json_dict.items() if k != "config" } ) step_info.config = json_dict.get("config") return step_info @classmethod def new_from_step(cls, step: Step, **kwargs) -> "StepInfo": try: config = step.config except ValueError: config = None return cls( unique_id=step.unique_id, step_name=step.name, step_class_name=step.class_name, version=step.VERSION, dependencies={dep.unique_id for dep in step.dependencies}, cacheable=step.cache_results, config=replace_steps_with_unique_id(config), metadata=step.metadata, **kwargs, ) def refresh(self): """ Refresh environment and platform metadata. """ self.platform = PlatformMetadata() self.environment = EnvironmentMetadata() ================================================ FILE: tango/steps/__init__.py ================================================ """ Built-in :class:`~tango.step.Step` implementations that are not tied to any particular integration. """ __all__ = ["DatasetCombineStep", "DatasetRemixStep", "PrintStep", "ShellStep"] from .dataset_remix import DatasetCombineStep, DatasetRemixStep from .print import PrintStep from .shell_step import ShellStep ================================================ FILE: tango/steps/dataset_remix.py ================================================ import collections import random import re from typing import Any, Dict, List, Mapping, Sequence from tango.common.dataset_dict import DatasetDict from tango.common.sequences import ( ConcatenatedSequence, ShuffledSequence, SlicedSequence, ) from tango.step import Step @Step.register("dataset_remix") class DatasetRemixStep(Step[DatasetDict]): """ This step can remix splits in a :class:`~tango.common.dataset_dict.DatasetDict` into new splits. .. 
def run(  # type: ignore
    self,
    input: DatasetDict,
    new_splits: Dict[str, str],
    keep_old_splits: bool = True,
    shuffle_before: bool = False,
    shuffle_after: bool = False,
    random_seed: int = 1532637578,
) -> DatasetDict:
    """
    Remixes and shuffles a dataset. This is done lazily, so all operations are fast.

    :param input:
        The input dataset that will be remixed.
    :param new_splits:
        Specifies the new splits that the output dataset should have. Keys are the name of the new
        splits. Values refer to the original splits. You can refer to original splits in the following ways:

        * Mention the original split name to copy it to a new name.
        * Mention the original split name with Python's slicing syntax to select part of the original
          split's instances. For example, ``"train[:1000]"`` selects the first 1000 instances from the
          ``"train"`` split.
        * ``"instances + instances"`` concatenates the instances into one split.

        You can combine these possibilities.
    :param keep_old_splits:
        Whether to keep the splits from the input dataset in addition to the new ones given by
        ``new_splits``.
    :param shuffle_before:
        Whether to shuffle the input splits before creating the new ones.

        If you need shuffled instances and you're not sure the input is properly shuffled, use this.
    :param shuffle_after:
        Whether to shuffle the resulting splits after creating the new ones.

        If you need shuffled instances and you're slicing or concatenating splits, use this.

        If you want to be on the safe side, shuffle both before and after. Shuffling is a cheap operation.
    :param random_seed:
        Random seed, affects shuffling

    :returns:
        Returns a new dataset that is appropriately remixed.
    """
    random.seed(random_seed)

    if shuffle_before:
        input_splits: Mapping[str, Sequence[Any]] = {
            split_name: ShuffledSequence(split_instances)
            for split_name, split_instances in input.splits.items()
        }
    else:
        input_splits = input.splits

    def get_slice(split_name: str) -> Sequence[Any]:
        # New splits must be derived from ``input_splits`` (not ``input``), otherwise
        # ``shuffle_before`` would silently have no effect on them.
        slice_match = re.match(r"(.*)\[(-?[0-9]*:-?[0-9]*)\]", split_name)
        if slice_match is None:
            return input_splits[split_name]
        split_name = slice_match[1]
        slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(":")]
        return SlicedSequence(input_splits[split_name], slice(*slice_args))

    def parse_split_spec(split_spec: str):
        parts = [get_slice(name.strip()) for name in split_spec.split("+")]
        if len(parts) == 1:
            return parts[0]
        return ConcatenatedSequence(*parts)

    if keep_old_splits:
        result = dict(input_splits.items())
    else:
        result = {}
    result.update(
        {
            new_split_name: parse_split_spec(new_split_spec)
            for new_split_name, new_split_spec in new_splits.items()
        }
    )

    if shuffle_after:
        result = {
            split_name: ShuffledSequence(split_instances)
            for split_name, split_instances in result.items()
        }

    return DatasetDict(splits=result, metadata=input.metadata)
def run(  # type: ignore
    self,
    inputs: List[DatasetDict],
    shuffle: bool = False,
    random_seed: int = 1532637578,
) -> DatasetDict:
    """
    Merge several datasets into a single one. Everything is lazy, so this is cheap.

    A split that occurs in more than one input dataset becomes a single output split
    that is the concatenation of those input splits.

    :param inputs:
        The list of input datasets that will be combined.
    :param shuffle:
        Whether to shuffle the combined datasets. Without shuffling, each new split contains
        all the instances from one dataset first, followed by all the instances from the next.
    :param random_seed:
        Random seed, affects shuffling

    :returns:
        Returns a new dataset that is the combination of the input datasets.
    """
    # Group the sequences of all inputs by split name, preserving input order.
    sequences_by_split: Dict[str, List[Sequence]] = collections.defaultdict(list)
    for dataset in inputs:
        for split_name, sequence in dataset.items():
            sequences_by_split[split_name].append(sequence)

    combined: Dict[str, Sequence] = {}
    for split_name, sequences in sequences_by_split.items():
        combined[split_name] = ConcatenatedSequence(*sequences)

    if shuffle:
        random.seed(random_seed)
        shuffled: Dict[str, Sequence] = {}
        for split_name, split_instances in combined.items():
            shuffled[split_name] = ShuffledSequence(split_instances)
        combined = shuffled

    return DatasetDict(combined, {})
def run(  # type: ignore[override]
    self,
    shell_command: Union[str, List[str]],
    output_path: Optional[PathOrStr] = None,
    validate_output: RegistrableFunction = check_path_existence,
    **kwargs,
):
    """
    Run ``shell_command`` and return its standard output as a string.

    :param shell_command: The command to run, forwarded to :meth:`run_command`.
    :param output_path: Optional location the command is expected to produce.
    :param validate_output: Called with ``output_path`` when one is given.
    :param kwargs: Forwarded to :meth:`run_command`.

    :returns: The command's standard output as ``str``.
    """
    output = self.run_command(shell_command, **kwargs)
    self.logger.info(output)
    if output_path is not None:
        validate_output(output_path)
        self.logger.info(f"Output found at: {output_path}")
    # ``subprocess.run`` already yields ``str`` when the caller forwards ``text=True``,
    # ``encoding=...`` or ``universal_newlines=True``; only raw bytes need decoding.
    if isinstance(output, bytes):
        return output.decode("utf-8")
    return str(output)
@dataclass
class Run(FromParams):
    """
    A record of one Tango run.
    """

    name: str
    """
    The name of the run
    """

    steps: Dict[str, StepInfo]
    """
    A mapping from step names to :class:`~tango.step_info.StepInfo`, for all the target steps in the run.

    This only contains the targets of a run. Usually, that means it contains all named steps.
    Un-named dependencies (or dependencies that are not targets) are not contained in ``steps``.
    """

    start_date: datetime
    """
    The time at which the run was registered in the workspace.
    """

    def to_json_dict(self) -> Dict[str, Any]:
        """Return the JSON-safe dictionary that ``jsonify`` produces for this run."""
        return jsonify(self)

    @classmethod
    def from_json_dict(cls, json_dict: Dict[str, Any]) -> "Run":
        """
        Rebuild a :class:`Run` from a dictionary representation.

        :param json_dict: Must contain ``"start_date"`` formatted as ``%Y-%m-%dT%H:%M:%S``
            and ``"steps"`` mapping step names to step info dictionaries.
        """
        # Work on a shallow copy so the caller's dictionary is left untouched.
        params = dict(json_dict)

        registered_at = datetime.strptime(params["start_date"], "%Y-%m-%dT%H:%M:%S")
        params["start_date"] = registered_at.replace(tzinfo=pytz.utc)

        step_infos = {}
        for step_name, step_json in params["steps"].items():
            step_infos[step_name] = StepInfo.from_json_dict(step_json)
        params["steps"] = step_infos

        return cls.from_params(params)
class Workspace(Registrable):
    """
    A workspace is a place for Tango to put the results of steps, intermediate results, and various other pieces
    of metadata. If you don't want to worry about all that, do nothing and Tango will use the default
    :class:`.LocalWorkspace` that puts everything into a directory of your choosing.

    If you want to do fancy things like store results in the cloud, share state across machines, etc., this is your
    integration point.

    If you got here solely because you want to share results between machines, consider that
    :class:`.LocalWorkspace` works fine on an NFS drive.
    """

    default_implementation = "local"

    #
    # As a general rule, workspaces can never return `Step`, only `StepInfo`, because we can't reliably serialize
    # objects of type `Step`. Doing that would require serializing the code that runs the step, and we can't
    # do that.
    #

    def __init__(self):
        # Temporary work directories handed out by `work_dir()`; see the comment there.
        self._delayed_cleanup_temp_dirs: List[TemporaryDirectory] = []

    def __getstate__(self):
        """
        We override `__getstate__()` to customize how instances of this class are pickled
        since we don't want to persist certain attributes.
        """
        out = {k: v for k, v in self.__dict__.items() if k not in {"_delayed_cleanup_temp_dirs"}}
        out["_delayed_cleanup_temp_dirs"] = []
        return out

    @property
    @abstractmethod
    def url(self) -> str:
        """
        Get a URL for the workspace that can be used to instantiate the same workspace
        using :meth:`.from_url()`.
        """
        raise NotImplementedError

    @classmethod
    def from_url(cls, url: str) -> "Workspace":
        """
        Initialize a :class:`Workspace` from a workspace URL or path, e.g. ``local:///tmp/workspace``
        would give you a :class:`~tango.workspaces.LocalWorkspace` in the directory ``/tmp/workspace``.

        For :class:`~tango.workspaces.LocalWorkspace`, you can also just pass in a plain
        path, e.g. ``/tmp/workspace``.

        .. tip::
            Registered as a workspace constructor under the name "from_url".
        """
        parsed = urlparse(url)
        # A plain path has no scheme, in which case we fall back to the local workspace.
        workspace_type = parsed.scheme or "local"
        workspace_cls = cast(Workspace, cls.by_name(workspace_type))
        return workspace_cls.from_parsed_url(parsed)

    @classmethod
    @abstractmethod
    def from_parsed_url(cls, parsed_url: ParseResult) -> "Workspace":
        """
        Subclasses should override this so that they can be initialized from a URL.

        :param parsed_url: The parsed URL object.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def step_cache(self) -> StepCache:
        """
        A :class:`.StepCache` to store step results in
        """
        raise NotImplementedError()

    def work_dir(self, step: Step) -> Path:
        """Steps that can be restarted (like a training job that gets interrupted half-way through)
        must save their state somewhere. A :class:`.StepCache` can help by providing a suitable location
        in this method.

        By default, the step dir is a temporary directory that gets cleaned up after every run.
        This effectively disables restartability of steps."""

        # TemporaryDirectory cleans up the directory automatically when the TemporaryDirectory object
        # gets garbage collected, so we hold on to it in the Workspace.
        dir = TemporaryDirectory(prefix=f"{step.unique_id}-", suffix=".step_dir")
        self._delayed_cleanup_temp_dirs.append(dir)
        return Path(dir.name)

    @abstractmethod
    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
        """
        Returns a :class:`~tango.step_info.StepInfo` for a given step.

        :raises KeyError: If the corresponding step info cannot be found or created.
            This should never happen if you pass a :class:`~tango.step.Step` object to this method
            since a :class:`~tango.step_info.StepInfo` can always be created from a
            :class:`~tango.step.Step`.
        """
        raise NotImplementedError()

    def search_step_info(
        self,
        *,
        sort_by: Optional[StepInfoSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        state: Optional[StepState] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[StepInfo]:
        """
        Search through steps in the workspace.

        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily
        need to implement all `sort_by` or filter operations. They should only implement those
        that can be done efficiently.

        :param sort_by: The field to sort the results by.
        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.
        :param match: Only return steps with a unique ID matching this string.
        :param state: Only return steps that are in the given state.
        :param start: Start from a certain index in the results.
        :param stop: Stop at a certain index in the results.

        :raises NotImplementedError: If a workspace doesn't support an efficient implementation
            for the given sorting/filtering criteria.
        """
        steps = [
            step
            for run in self.registered_runs().values()
            for step in run.steps.values()
            if (match is None or match in step.unique_id) and (state is None or step.state == state)
        ]

        if sort_by == StepInfoSort.START_TIME:
            # Steps that have not started yet are sorted as if they started right now.
            now = utc_now_datetime()
            steps = sorted(
                steps,
                key=lambda step: step.start_time or now,
                reverse=sort_descending,
            )
        elif sort_by == StepInfoSort.UNIQUE_ID:
            steps = sorted(steps, key=lambda step: step.unique_id, reverse=sort_descending)
        elif sort_by is not None:
            raise NotImplementedError

        return steps[slice(start, stop)]

    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:
        """
        Get the total number of registered steps.

        :param match: Only count steps with a unique ID matching this string.
        :param state: Only count steps that are in the given state.
        """
        return len(self.search_step_info(match=match, state=state))

    @abstractmethod
    def step_starting(self, step: Step) -> None:
        """
        The :class:`.Step` class calls this when a step is about to start running.

        :param step: The step that is about to start.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).
        """
        raise NotImplementedError()

    @abstractmethod
    def step_finished(self, step: Step, result: T) -> T:
        """
        The :class:`.Step` class calls this when a step finished running.

        :param step: The step that finished.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).

        This method is given the result of the step's :meth:`.Step.run` method. It is expected to return that
        result. This gives it the opportunity to make changes to the result if necessary. For example, if the
        :meth:`.Step.run` method returns an iterator, that iterator would be consumed when it's written to the
        cache. So this method can handle the situation and return something other than the now-consumed iterator.
        """
        raise NotImplementedError()

    @abstractmethod
    def step_failed(self, step: Step, e: BaseException) -> None:
        """
        The :class:`.Step` class calls this when a step failed.

        :param step: The step that failed.
        :param e: The exception thrown by the step's :meth:`.Step.run` method.

        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).
        """
        raise NotImplementedError()

    @abstractmethod
    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:
        """
        Register a run in the workspace. A run is a set of target steps that a user wants to execute.

        :param targets: The steps that the user wants to execute. This could come from a :class:`.StepGraph`.
        :param name: A name for the run. Runs must have unique names. If not given, this method invents a name and
                     returns it.
        :return: The run object
        """
        raise NotImplementedError()

    def search_registered_runs(
        self,
        *,
        sort_by: Optional[RunSort] = None,
        sort_descending: bool = True,
        match: Optional[str] = None,
        start: int = 0,
        stop: Optional[int] = None,
    ) -> List[RunInfo]:
        """
        Search through registered runs in the workspace.

        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily
        need to implement all `sort_by` or filter operations. They should only implement those
        that can be done efficiently.

        .. note::
            The data type returned in the list here is :class:`RunInfo`, which
            contains a subset of the data in the :class:`Run` type.

        :param sort_by: The field to sort the results by.
        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.
        :param match: Only return results with a name matching this string.
        :param start: Start from a certain index in the results.
        :param stop: Stop at a certain index in the results.

        :raises NotImplementedError: If a workspace doesn't support an efficient implementation
            for the given sorting/filtering criteria.
        """
        runs = [
            run for run in self.registered_runs().values() if match is None or match in run.name
        ]

        if sort_by == RunSort.START_DATE:
            runs = sorted(runs, key=lambda run: run.start_date, reverse=sort_descending)
        elif sort_by == RunSort.NAME:
            runs = sorted(runs, key=lambda run: run.name, reverse=sort_descending)
        elif sort_by is not None:
            raise NotImplementedError

        return [
            RunInfo(
                name=run.name,
                start_date=run.start_date,
                steps={k: s.unique_id for k, s in run.steps.items()},
            )
            for run in runs[slice(start, stop)]
        ]

    def num_registered_runs(self, *, match: Optional[str] = None) -> int:
        """
        Get the number of registered runs.

        :param match: Only count runs with a name matching this string.
        """
        return len(self.search_registered_runs(match=match))

    @abstractmethod
    def registered_runs(self) -> Dict[str, Run]:
        """
        Returns all runs in the workspace

        :return: A dictionary mapping run names to :class:`Run` objects
        """
        raise NotImplementedError

    @abstractmethod
    def registered_run(self, name: str) -> Run:
        """
        Returns the run with the given name

        :return: A :class:`Run` object representing the named run

        :raises KeyError: If there is no run with the given name.
        """
        raise NotImplementedError()

    def step_result_for_run(self, run_name: str, step_name: str) -> Any:
        """
        Get the result of a step from a run.

        :raises KeyError: If there is no run or step with the given name, or if the
            step's result is not in the step cache.
        """
        run = self.registered_run(run_name)
        step_info = run.steps[step_name]
        try:
            return self.step_cache[step_info]
        except KeyError as err:
            raise KeyError(f"Step result for '{step_name}' not found in workspace") from err

    def step_result(self, step_name: str) -> Any:
        """
        Get the result of a step from the latest run with a step by that name.

        :raises KeyError: If there is no run with the given step, or if the
            step's result is not in the step cache.
        """
        runs = sorted(self.registered_runs().values(), key=lambda run: run.start_date, reverse=True)
        for run in runs:
            if step_name in run.steps:
                # Report a missing cache entry with the same message as `step_result_for_run()`
                # instead of leaking a KeyError keyed by the `StepInfo` object.
                try:
                    return self.step_cache[run.steps[step_name]]
                except KeyError as err:
                    raise KeyError(
                        f"Step result for '{step_name}' not found in workspace"
                    ) from err
        raise KeyError(f"No step named '{step_name}' found in previous runs")

    @abstractmethod
    def remove_step(self, step_unique_id: str):
        """
        Removes cached step using the given unique step id

        :raises KeyError: If there is no step with the given name.
        """
        raise NotImplementedError()

    def capture_logs_for_run(self, name: str) -> ContextManager[None]:
        """
        Should return a context manager that can be used to capture the logs for a run.

        By default, this doesn't do anything.

        Examples
        --------

        The :class:`.LocalWorkspace` implementation uses this method to capture logs
        to a file in the workspace's directory using the :func:`~tango.common.logging.file_handler()`
        context manager, similar to this:

        .. testcode::

            from tango.common.logging import file_handler
            from tango.workspace import Workspace

            class MyLocalWorkspace(Workspace):
                def capture_logs_for_run(self, name: str):
                    return file_handler("/path/to/workspace/" + name + ".log")

        """

        @contextmanager
        def do_nothing() -> Generator[None, None, None]:
            yield None

        return do_nothing()
""" from .local_workspace import LocalWorkspace from .memory_workspace import MemoryWorkspace, default_workspace ================================================ FILE: tango/workspaces/local_workspace.py ================================================ import json import logging import os from datetime import datetime from pathlib import Path from typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar, Union from urllib.parse import ParseResult import dill import petname from sqlitedict import SqliteDict from tango.common import PathOrStr from tango.common.exceptions import StepStateError from tango.common.file_lock import FileLock from tango.common.logging import file_handler from tango.common.util import exception_to_string, utc_now_datetime from tango.step import Step from tango.step_cache import StepCache from tango.step_caches import LocalStepCache from tango.step_info import StepInfo, StepState from tango.workspace import Run, StepInfoSort, Workspace logger = logging.getLogger(__name__) T = TypeVar("T") @Workspace.register("local") class LocalWorkspace(Workspace): """ This is a :class:`.Workspace` that keeps all its data in a local directory. This works great for single-machine jobs, or for multiple machines in a cluster if they can all access the same NFS drive. :param dir: The directory to store all the data in The directory will have three subdirectories, ``cache/`` for the step cache, ``runs/`` for the runs, and ``latest/`` for the results of the latest run. For the format of the ``cache/`` directory, refer to :class:`.LocalStepCache`. The ``runs/`` directory will contain one subdirectory for each registered run. Each one of those contains a symlink from the name of the step to the results directory in the step cache. Note that :class:`.LocalWorkspace` creates these symlinks even for steps that have not finished yet. 
def __init__(self, dir: PathOrStr):
    """
    Set up the workspace directory layout and upgrade an old on-disk format if needed.

    :param dir: The directory to store all the data in. It is created if missing.
    """
    super().__init__()
    self.dir = Path(dir)
    self.dir.mkdir(parents=True, exist_ok=True)
    self.cache = LocalStepCache(self.dir / "cache")
    self.locks: Dict[Step, FileLock] = {}
    self.runs_dir = self.dir / "runs"
    self.runs_dir.mkdir(parents=True, exist_ok=True)
    self.step_info_file = self.dir / "stepinfo.sqlite"
    self.latest_dir = self.dir / "latest"

    # Check the version of the local workspace. A missing settings file means version 1.
    settings_path = self.dir / "settings.json"
    try:
        with open(settings_path, "r") as settings_file:
            settings = json.load(settings_file)
    except FileNotFoundError:
        settings = {"version": 1}

    # Upgrade to version 2
    if settings["version"] == 1:
        with SqliteDict(self.step_info_file) as d:
            for stepinfo_file in self.cache.dir.glob("*/stepinfo.dill"):
                with stepinfo_file.open("rb") as f:
                    stepinfo = dill.load(f)

                # The `StepInfo` class changed from one version to the next. The deserialized version
                # ends up being a `StepInfo` instance that is missing the `cacheable` member. This
                # hack adds it in.
                kwargs = stepinfo.__dict__
                kwargs[
                    "cacheable"
                ] = True  # Only cacheable steps were saved in v1. That's what v2 fixes.
                d[stepinfo.unique_id] = StepInfo(**kwargs)
            d.commit()
        for stepinfo_file in self.cache.dir.glob("*/stepinfo.dill"):
            stepinfo_file.unlink()

        settings["version"] = 2

        # Only rewrite the settings when they actually changed. Opening the file for writing
        # truncates it, so rewriting it on every instantiation would let another process
        # that is reading it at the same moment see an empty, unparseable file.
        with open(settings_path, "w") as settings_file:
            json.dump(settings, settings_file)
@classmethod
def guess_step_dir_state(cls, dir: Path) -> Set[StepState]:
    """
    Works out which states a step could be in by looking only at its step directory.

    :param dir: the step dir to examine
    :return: the set of states that are consistent with what is on disk
    """

    def is_bookkeeping(entry: Path) -> bool:
        # The work dir and the lock file are the only entries that can be present
        # before a step has produced any results.
        if entry.name == "work":
            return entry.is_dir()
        if entry.name == "lock":
            return entry.is_file()
        return False

    # Without a directory the step has either not run yet, or it is never cached.
    if not dir.exists():
        return {StepState.INCOMPLETE, StepState.UNCACHEABLE}

    # A lock file that somebody else is holding means the step is running.
    lock_path = dir / "lock"
    if lock_path.exists():
        probe = FileLock(lock_path)
        try:
            probe.acquire(0.1)
            probe.release()
        except TimeoutError:
            return {StepState.RUNNING}

    # Nothing but the work dir and the lock file: the step could be running, incomplete,
    # or failed. Running was ruled out above because the lock was free, which leaves
    # incomplete or failed.
    if all(is_bookkeeping(entry) for entry in dir.iterdir()):
        return {StepState.INCOMPLETE, StepState.FAILED}

    # Anything else on disk tells us nothing, so every state stays possible.
    return set(StepState)
def step_starting(self, step: Step) -> None:
    """
    Record that ``step`` is about to run.

    Checks that the step is currently incomplete or failed, takes the step's file lock,
    and writes a step info with a fresh start time and cleared end time, error and
    result location to the sqlite database. Uncacheable steps are ignored.

    :param step: the step that is starting
    :raises StepStateError: if the step is in any state other than incomplete or failed
    """
    # We don't do anything with uncacheable steps.
    if not step.cache_results:
        return

    # Gather the existing step info first. Step info automatically fixes itself if steps are
    # marked as "running" but are not locked. This happens, for example, when a process
    # gets killed. To make sure this works, we have to get the step info before we start
    # messing with locks.
    step_info = self.step_info(step)
    if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED}:
        raise StepStateError(
            step,
            step_info.state,
            context="If you are certain the step is not running somewhere else, delete the lock "
            f"file at {self._step_lock_file(step)}.",
        )

    if step_info.state == StepState.FAILED:
        # Refresh environment metadata since it might be out-of-date now.
        step_info.refresh()

    lock = FileLock(self._step_lock_file(step), read_only_ok=True)
    lock.acquire_with_updates(desc=f"acquiring lock for '{step.name}'")
    # Remember the lock so that `step_finished()` / `step_failed()` can release it.
    self.locks[step] = lock

    try:
        step_info.start_time = utc_now_datetime()
        step_info.end_time = None
        step_info.error = None
        step_info.result_location = None
        with SqliteDict(self.step_info_file) as d:
            d[step.unique_id] = step_info
            d.commit()
    except:  # noqa: E722
        # Deliberately bare: whatever went wrong, we must not keep holding the lock.
        lock.release()
        del self.locks[step]
        raise
def remove_step(self, step_unique_id: str) -> None:
    """
    Deletes a step from this workspace, given its unique id. Both the step info entry in
    the database and the step's entry in the step cache are removed.

    :param step_unique_id: the unique id of the step to remove
    :raises KeyError: If no step with the unique name found in the cache dir
    """
    with SqliteDict(self.step_info_file) as db:
        try:
            # Look the step up first; its info is needed to address the cache entry.
            info_to_remove = self.step_info(step_unique_id)
            del db[step_unique_id]
            db.commit()
            del self.cache[info_to_remove]
        except KeyError:
            raise KeyError(f"No step named '{step_unique_id}' found")
def search_step_info(
    self,
    *,
    sort_by: Optional[StepInfoSort] = None,
    sort_descending: bool = True,
    match: Optional[str] = None,
    state: Optional[StepState] = None,
    start: int = 0,
    stop: Optional[int] = None,
) -> List[StepInfo]:
    """
    Returns the step infos stored in this workspace, optionally filtered, sorted and sliced.

    :param sort_by: the field to sort by; ``None`` leaves the database order untouched
    :param sort_descending: whether a requested sort is in descending order
    :param match: if given, only steps whose unique id contains this string are returned
    :param state: if given, only steps in this state are returned
    :param start: index of the first result to return
    :param stop: index one past the last result to return, or ``None`` for no limit
    :raises NotImplementedError: if ``sort_by`` is a field this workspace cannot sort by
    """
    found: List[StepInfo] = []
    with SqliteDict(self.step_info_file, flag="r") as db:
        for info in db.values():
            if match is not None and match not in info.unique_id:
                continue
            if state is not None and info.state != state:
                continue
            found.append(info)

    if sort_by == StepInfoSort.START_TIME:
        # Steps that never started are ordered as if they had started right now.
        fallback = utc_now_datetime()
        found.sort(key=lambda info: info.start_time or fallback, reverse=sort_descending)
    elif sort_by == StepInfoSort.UNIQUE_ID:
        found.sort(key=lambda info: info.unique_id, reverse=sort_descending)
    elif sort_by is not None:
        raise NotImplementedError

    return found[start:stop]
def _load_registered_run(self, name: str) -> Dict[str, StepInfo]:
    """
    Loads the step infos that belong to the run called ``name``.

    The mapping from step name to unique id is read from the run's step info file. Runs
    that have no such file fall back to the symlinks in the run directory.

    :param name: the name of the run
    :return: a dictionary from step name to the step's :class:`.StepInfo`
    """
    try:
        with open(self._run_step_info_file(name), "r") as run_file:
            name_to_id = json.load(run_file)
    except FileNotFoundError:
        # for backwards compatibility: the symlink is named after the step, and the
        # directory it points to is named after the unique id.
        name_to_id = {
            str(link.name): str(link.resolve().name)
            for link in (self.runs_dir / name).iterdir()
            if link.is_symlink()
        }

    loaded: Dict[str, StepInfo] = {}
    with SqliteDict(self.step_info_file, flag="r") as db:
        for step_name, unique_id in name_to_id.items():
            info = db[unique_id]
            assert isinstance(info, StepInfo)
            self._fix_step_info(info)
            loaded[step_name] = info
    return loaded
def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:
    """
    Returns the :class:`.StepInfo` for a step, or for a step's unique id.

    A step that this workspace has not recorded yet gets a fresh ``StepInfo`` built with
    :meth:`StepInfo.new_from_step`. That object is returned but not stored.

    :param step_or_unique_id: the step, or the unique id of the step
    :raises KeyError: if a unique id is given that this workspace does not know about.
        The exception carries the unknown id.
    """
    if isinstance(step_or_unique_id, str):
        unique_id = step_or_unique_id
    else:
        unique_id = step_or_unique_id.unique_id

    try:
        return self.unique_id_to_info[unique_id]
    except KeyError:
        if isinstance(step_or_unique_id, str):
            # Re-raise the lookup error as-is: it already names the id that is missing,
            # which an argument-less `KeyError()` would hide from the caller.
            raise
        return StepInfo.new_from_step(step_or_unique_id)
def step_finished(self, step: Step, result: T) -> T:
    """
    Record that ``step`` finished successfully and store its result in the step cache.

    :param step: the step that finished
    :param result: the result the step produced
    :return: the result to hand back to the caller. For iterator results this is the copy
        read back from the cache, because caching consumes the original iterator.
    :raises StepStateError: if the step is not currently marked as running
    """
    # We don't do anything with uncacheable steps.
    if not step.cache_results:
        return result

    existing_step_info = self.unique_id_to_info[step.unique_id]
    if existing_step_info.state != StepState.RUNNING:
        raise StepStateError(step, existing_step_info.state)
    existing_step_info.end_time = utc_now_datetime()

    # The guard above already established that the step is cacheable, so the result is
    # always written to the cache from here on.
    self.step_cache[step] = result
    if hasattr(result, "__next__"):
        assert isinstance(result, Iterator)
        # Caching the iterator will consume it, so we write it to the cache and then read from the cache
        # for the return value.
        return self.step_cache[step]
    return result
def registered_run(self, name: str) -> Run:
    """
    Returns an independent copy of the run registered under ``name``, so that callers
    cannot modify the run stored in this workspace.

    :param name: the name of the run
    :raises KeyError: if no run with that name has been registered
    """
    stored_run = self.runs[name]
    return copy.deepcopy(stored_run)
def step_starting(self, step: Step) -> None:
    """
    Record that ``step`` is about to run.

    Acquires the step's lock, validates the step's current state, and updates the step
    info with a fresh start time and cleared end time, error and result location.
    Uncacheable steps are ignored.

    The lock is released again if anything goes wrong after it was acquired, including a
    failure to fetch the step info.

    :param step: the step that is starting
    :raises StepStateError: if the step is in a state from which it cannot be started
    """
    # We don't do anything with uncacheable steps.
    if not step.cache_results:
        return

    # Get local file lock + remote dataset lock.
    lock = self._remote_lock(step)
    lock.acquire()
    self.locks[step] = lock

    try:
        step_info = self.step_info(step)
        if step_info.state == StepState.RUNNING:
            # Since we've acquired the step lock we know this step can't be running
            # elsewhere. But the step state can still say it's running, presumably because
            # an earlier run ended without recording a final state. We warn and carry on.
            warnings.warn(
                f"Step info for step '{step.unique_id}' is invalid - says step is running "
                "although it shouldn't be. Ignoring and overwriting step start time.",
                UserWarning,
            )
        elif step_info.state not in {
            StepState.INCOMPLETE,
            StepState.FAILED,
            StepState.UNCACHEABLE,
        }:
            raise StepStateError(
                step,
                step_info.state,
                context=f"If you are certain the step is not running somewhere else, delete the step "
                f"datasets at {self._step_location(step)}",
            )

        if step_info.state == StepState.FAILED:
            # Refresh the environment metadata since it might be out-of-date now.
            step_info.refresh()

        # Update StepInfo to mark as running.
        step_info.start_time = utc_now_datetime()
        step_info.end_time = None
        step_info.error = None
        step_info.result_location = None
        self._update_step_info(step_info)
    except:  # noqa: E722
        # Deliberately bare: whatever went wrong, we must not keep holding the lock.
        self.locks.pop(step).release()
        raise
if not step.cache_results: return try: step_info = self.step_info(step) if step_info.state != StepState.RUNNING: raise StepStateError(step, step_info.state) step_info.end_time = utc_now_datetime() step_info.error = exception_to_string(e) self._update_step_info(step_info) finally: self.locks.pop(step).release() def remove_step(self, step_unique_id: str) -> None: """ Get Step unique id from the user and remove the step information from cache :raises KeyError: If no step with the unique name found in the cache dir """ try: step_info = self.step_info(step_unique_id) # remove remote objects self._remove_step_info(step_info) # remove cache info del self.cache[step_info] except KeyError: raise KeyError(f"No step named '{step_unique_id}' found.") return None def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]: import concurrent.futures all_steps = set(targets) for step in targets: all_steps |= step.recursive_dependencies steps: Dict[str, StepInfo] = {} run_data: Dict[str, str] = {} # Collect step info. 
with concurrent.futures.ThreadPoolExecutor( thread_name_prefix="RemoteWorkspace._get_run_step_info()-" ) as executor: step_info_futures = [] for step in all_steps: if step.name is None: continue step_info_futures.append(executor.submit(self.step_info, step)) for future in concurrent.futures.as_completed(step_info_futures): step_info = future.result() assert step_info.step_name is not None steps[step_info.step_name] = step_info run_data[step_info.step_name] = step_info.unique_id return steps, run_data @abstractmethod def _save_run( self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None ) -> Run: raise NotImplementedError() def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run: steps, run_data = self._get_run_step_info(targets) run = self._save_run(steps, run_data, name) return run @abstractmethod def _save_run_log(self, name: str, log_file: Path): raise NotImplementedError() @contextmanager def capture_logs_for_run(self, name: str) -> Generator[None, None, None]: with tempfile.TemporaryDirectory() as tmp_dir_name: log_file = Path(tmp_dir_name) / "out.log" try: with file_handler(log_file): yield None finally: self._save_run_log(name, log_file) @abstractmethod def _update_step_info(self, step_info: StepInfo): raise NotImplementedError() @abstractmethod def _remove_step_info(self, step_info: StepInfo): raise NotImplementedError() ================================================ FILE: test_fixtures/__init__.py ================================================ ================================================ FILE: test_fixtures/beaker/nvidia_smi.yml ================================================ # Used to test that GPUs in a cluster are available. 
Submit this to beaker with: # $ beaker experiment create test_fixtures/beaker/nvidia_smi.yml --workspace ai2/tango-testing --name tango-test-1 version: v2-alpha description: NvidiaSMI tasks: - name: nvidia-smi image: docker: nvidia/cuda:11.0-base command: [nvidia-smi] result: path: '/unused' resources: gpuCount: 2 context: cluster: ai2/tango-gpu-tests priority: normal ================================================ FILE: test_fixtures/common/params_example.jsonnet ================================================ { "model": { "type": "classifier", "num_classes": 3, "layers": [ { "type": "ff", "activation": "relu", }, { "type": "ff", "activation": "softmax", }, ] }, "data_path": "data.txt", } ================================================ FILE: test_fixtures/common/params_example.yaml ================================================ model: type: classifier num_classes: 3 layers: - type: ff activation: relu - type: ff activation: softmax data_path: data.txt ================================================ FILE: test_fixtures/experiment/hello_world.jsonnet ================================================ { "steps": { "hello": {"type": "string", "result": "Hello"}, "hello_world": { "type": "concat_strings", "string1": {"type": "ref", "ref": "hello"}, "string2": "World!", "join_with": ", ", }, }, } ================================================ FILE: test_fixtures/experiment/logging_check.jsonnet ================================================ { "steps": { "stringA": {"type": "logging-step", "string": "This is a logging test.", "num_log_lines": 5}, "stringB": { "type": "concat_strings", "string1": {"type": "ref", "ref": "stringA"}, "string2": "This is being logged." 
}, "stringC": {"type": "logging-step", "string": "This is also a logging test.", "num_log_lines": 5}, "final_string": { "type": "logging-step", "string": {"type": "ref", "ref": "stringB"}, "num_log_lines": 3 }, "multiprocessing_result": { "type": "multiprocessing_step", } } } ================================================ FILE: test_fixtures/experiment/multiprocessing.jsonnet ================================================ { "steps": { "result": { "type": "multiprocessing_step", } } } ================================================ FILE: test_fixtures/experiment/noisy.jsonnet ================================================ { steps: { hello_world: { type: "string", result: "Hello, World!" }, noisy_step: { type: "noisy_step" }, } } ================================================ FILE: test_fixtures/experiment/random.jsonnet ================================================ { "steps": { "rand_string1": {"type": "random_string", "length": 5}, "rand_string2": {"type": "random_string", "length": 5}, "string1": { "type": "concat_strings", "string1": {"type": "ref", "ref": "rand_string1"}, "string2": {"type": "ref", "ref": "rand_string2"}, }, "string2": { "type": "string", "result": "foo", }, "final_string": { "type": "concat_strings", "string1": {"type": "ref", "ref": "string1"}, "string2": {"type": "ref", "ref": "string2"}, } } } ================================================ FILE: test_fixtures/integrations/__init__.py ================================================ ================================================ FILE: test_fixtures/integrations/common/__init__.py ================================================ import torch from torch.utils.data import IterableDataset from tango import Step from tango.common import DatasetDict, IterableDatasetDict @Step.register("generate_data") class GenerateData(Step): DETERMINISTIC = True CACHEABLE = False def run(self) -> DatasetDict: # type: ignore[override] torch.manual_seed(1) return DatasetDict( { "train": [{"x": 
class RandomIterableDataset(IterableDataset):
    """
    An iterable-style dataset that simply streams the examples of an in-memory collection.
    """

    def __init__(self, data):
        # `data` is kept as given; every pass over the dataset iterates it afresh.
        self.data = data

    def __iter__(self):
        examples = self.data
        return iter(examples)
@Model.register("simple_regression_model", exist_ok=True)
class SimpleRegressionModel(Model):
    """
    A small regression model for tests: three ``FeedForward`` blocks followed by a linear
    head, trained with mean squared error.
    """

    def __init__(self):
        super().__init__()
        feed_forward_blocks = [FeedForward() for _ in range(3)]
        self.blocks = nn.Sequential(*feed_forward_blocks)
        self.regression_head = nn.Linear(4, 1)
        self.loss_fcn = nn.MSELoss()

    def forward(self, x, y):
        """Runs the model on ``x`` and returns a dictionary with the MSE loss against ``y``."""
        prediction = self.regression_head(self.blocks(x))
        return {"loss": self.loss_fcn(prediction, y)}
###################### # Optimizer settings # ###################### local warmup_steps = 2; local learning_rate = 0.005; local fsdp_config = { reshard_after_forward: true, move_params_to_cpu: cpu_offloading, move_grads_to_cpu: cpu_offloading, mixed_precision: amp, }; local training_engine = { type: "fairscale", optimizer: { type: "torch::AdamW", lr: learning_rate, betas: [0.9, 0.95], eps: 1e-6, }, amp: amp, fsdp_config: fsdp_config, }; local dataloader = { batch_size: batch_size, sampler: { type: "torch::DistributedSampler", shuffle: true, drop_last: true, }, }; { steps: { regression_data: { type: "simple_regression_data", }, trained_model: { type: "torch::train", model: { type: "fairscale::with_wrapped_modules", model: { type: "simple_regression_model", }, modules_to_wrap: ["blocks\\.[0-9]+"], fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, training_engine: training_engine, dataset_dict: { type: "ref", ref: "regression_data" }, train_dataloader: dataloader, validation_split: "dev", grad_accum: grad_accum, train_steps: training_steps, validate_every: training_steps, validation_steps: 2, checkpoint_every: training_steps, log_every: 1, device_count: devices, }, } } ================================================ FILE: test_fixtures/integrations/flax/__init__.py ================================================ ================================================ FILE: test_fixtures/integrations/flax/config.jsonnet ================================================ { "steps": { "data_full": { "type": "datasets::load", "path": "iohadrubin/mini_xsum", }, "data": { "type": "datasets::dataset_remix", "input": {"type": "ref", "ref": "data_full"}, "new_splits": {"train": "train[:20]", "validation": "validation[:20]"}, }, "tokenize": { "type": "tokenize_data", "dataset": { "type": "ref", "ref": "data" } }, "train": { "type": "flax::train", "model": { "type" : "transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained", 
@Step.register("tokenize_data")
class PreProcessing(Step):
    """
    Tokenizes the ``document`` / ``summary`` columns of a dataset with the ``t5-small``
    tokenizer and adds the decoder inputs needed for seq2seq training.
    """

    DETERMINISTIC = False

    def run(self, dataset):
        """
        :param dataset: a dataset with a ``train`` split and ``document`` and ``summary`` columns
        :return: the dataset with its original columns replaced by the tokenized model inputs
        """
        import importlib

        tokenizer = AutoTokenizer.from_pretrained("t5-small")
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained("t5-small")
        # `shift_tokens_right` lives in the module that defines the loaded model class, so
        # that module is imported by name and the function looked up on it.
        model_module = importlib.import_module(model.__module__)
        shift_tokens_right_fn = model_module.shift_tokens_right
        config = AutoConfig.from_pretrained("t5-small")

        MAX_SOURCE_LENGTH = 512
        MAX_TGT_LENGTH = 64

        def preprocess_function(examples):
            inputs = list(examples["document"])
            targets = examples["summary"]
            model_inputs = tokenizer(
                inputs,
                max_length=MAX_SOURCE_LENGTH,
                padding="max_length",
                truncation=True,
                return_tensors="np",
            )

            # Setup the tokenizer for targets
            with tokenizer.as_target_tokenizer():
                labels = tokenizer(
                    targets,
                    max_length=MAX_TGT_LENGTH,
                    padding="max_length",
                    truncation=True,
                    return_tensors="np",
                )

            model_inputs["labels"] = labels["input_ids"]
            decoder_input_ids = shift_tokens_right_fn(
                labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
            )
            model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

            # We need decoder_attention_mask so we can ignore pad tokens from loss
            model_inputs["decoder_attention_mask"] = labels["attention_mask"]

            return model_inputs

        # The raw text columns are dropped once the tokenized columns have been added.
        column_names = dataset["train"].column_names

        dataset = dataset.map(
            preprocess_function,
            batched=True,
            remove_columns=column_names,
            desc="Running tokenizer on dataset",
        )
        return dataset
@Model.register("basic_regression")
class BasicRegression(Model):
    """Minimal regression fixture: ``Linear(10, 1)`` -> sigmoid, scored with MSE."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(in_features=10, out_features=1)
        self.sigmoid = nn.Sigmoid()
        self.mse = nn.MSELoss()

    def forward(self, x, y=None):
        """Return ``{"pred": ...}``, plus ``"loss"`` when targets ``y`` are given."""
        prediction = self.sigmoid(self.linear(x))
        if y is None:
            return {"pred": prediction}
        return {"pred": prediction, "loss": self.mse(prediction, y)}

    def _to_params(self):
        """Return an empty parameter mapping."""
        return dict()
"validation_dataloader": { "batch_size": 8, "sampler": { "type": "torch::DistributedSampler", "shuffle": true, "drop_last": true, } }, "train_steps": 100, "validate_every": 10, "checkpoint_every": 10, "log_every": 1, "device_count": 2, } } } ================================================ FILE: test_fixtures/integrations/torch/train_streaming.jsonnet ================================================ { "steps": { "data": { "type": "generate_streaming_data", }, "train": { "type": "torch::train", "model": { "type": "basic_regression", }, "training_engine": { "optimizer": { "type": "torch::Adam", }, }, "dataset_dict": { "type": "ref", "ref": "data" }, "train_dataloader": { "batch_size": 8, "shuffle": true }, "train_steps": 100, "checkpoint_every": 10, "log_every": 1 } } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/cache-metadata.json ================================================ { "step": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP" } ================================================ FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - 
libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - 
furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - 
sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/executor-metadata.json ================================================ { "config": { "type": "cadd", "a": { "type": "ref", "ref": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk" }, "b": { "type": "ref", "ref": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf" } }, "duration": 0.0007, "finished_at": 1642546363.9658601, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546363.965193, "step": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/lock 
================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 
petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json ================================================ { "step": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk" } 
================================================ FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - 
tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - 
platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json ================================================ { "config": { "type": "ccos", "x": [ 3.1415926535, 0 ] }, "duration": 0.0004, "finished_at": 1642546350.3743181, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": 
"https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546350.373902, "step": "CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock ================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 
idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 
tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/cache-metadata.json ================================================ { "step": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9" } ================================================ FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - 
mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - 
googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - 
toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/executor-metadata.json ================================================ { "config": { "type": "cexp", "x": { "type": "ref", "ref": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae" } }, "duration": 0.0006, "finished_at": 1642546361.347647, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546361.347095, "step": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/lock ================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 
backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 
pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/cache-metadata.json ================================================ { "step": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf" } ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - 
certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - 
attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - 
regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/executor-metadata.json ================================================ { "config": { "type": "cmul", "a": [ 0, 1 ], "b": { "type": "ref", "ref": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk" } }, "duration": 0.0004, "finished_at": 1642546358.776982, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, 
"started_at": 1642546358.776602, "step": "MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/lock ================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 
matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 
wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/cache-metadata.json ================================================ { "step": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae" } ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - 
parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - 
markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: 
/opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/executor-metadata.json ================================================ { "config": { "type": "cmul", "a": [ 0, 1 ], "b": [ 3.1415926535, 0 ] }, "duration": 0.0005, "finished_at": 1642546353.7523232, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546353.751795, "step": "MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/lock ================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 
flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 
sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json ================================================ { "step": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk" } ================================================ FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 
- libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu - torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - 
filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - 
sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json ================================================ { "config": { "type": "csin", "x": [ 3.1415926535, 0 ] }, "duration": 0.0004, "finished_at": 1642546356.265017, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546356.264595, "step": "SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock ================================================ 
================================================ FILE: test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 
pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 tensorboard==2.7.0 tensorflow-estimator==2.7.0 tensorflow-io-gcs-filesystem==0.23.1 tensorflow==2.7.0 termcolor==1.1.0 tokenizers==0.10.3 toml==0.10.2 tomli==1.2.2 torch==1.10.0 torchaudio==0.10.0 torchmetrics==0.6.0 torchvision==0.11.1 tornado==6.1 tqdm==4.62.3 traitlets==5.1.0 transformers==4.12.3 twine==3.5.0 types-pyyaml==6.0.0 types-setuptools==57.4.2 typing-extensions==3.10.0.2 typing-utils==0.1.0 urllib3==1.26.7 wandb==0.12.6 wcwidth==0.2.5 webencodings==0.5.1 werkzeug==2.0.2 wheel==0.37.0 wrapt==1.13.3 xxhash==2.0.2 yarl==1.7.2 yaspin==2.1.0 zipp==3.6.0 ================================================ FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/cache-metadata.json ================================================ { "step": "SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz" } ================================================ FILE: 
test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/conda-environment.yaml ================================================ name: tango channels: - pytorch - defaults dependencies: - appnope=0.1.2=py38hecd8cb5_1001 - backcall=0.2.0=pyhd3eb1b0_0 - blas=1.0=mkl - bzip2=1.0.8=h1de35cc_0 - ca-certificates=2021.10.26=hecd8cb5_2 - certifi=2021.10.8=py38hecd8cb5_0 - decorator=5.1.0=pyhd3eb1b0_0 - ffmpeg=4.3=h0a44026_0 - freetype=2.11.0=hd8bbffd_0 - gettext=0.21.0=h7535e17_0 - giflib=5.2.1=haf1e3a3_0 - gmp=6.2.1=h23ab428_2 - gnutls=3.6.15=hed9c0bf_0 - icu=58.2=h0a44026_3 - intel-openmp=2021.4.0=hecd8cb5_3538 - ipython=7.29.0=py38h01d92e1_0 - jedi=0.18.0=py38hecd8cb5_1 - jpeg=9d=h9ed2024_0 - lame=3.100=h1de35cc_0 - lcms2=2.12=hf1fd2bf_0 - libcxx=12.0.0=h2f01273_0 - libffi=3.3=hb1e8313_2 - libiconv=1.16=h1de35cc_0 - libidn2=2.3.2=h9ed2024_0 - libpng=1.6.37=ha441bb4_0 - libtasn1=4.16.0=h9ed2024_0 - libtiff=4.2.0=h87d7836_0 - libunistring=0.9.10=h9ed2024_0 - libuv=1.40.0=haf1e3a3_0 - libwebp=1.2.0=hacca55c_0 - libwebp-base=1.2.0=h9ed2024_0 - libxml2=2.9.12=hcdb78fc_0 - llvm-openmp=12.0.0=h0dcd299_1 - lz4-c=1.9.3=h23ab428_1 - matplotlib-inline=0.1.2=pyhd3eb1b0_2 - mkl=2021.4.0=hecd8cb5_637 - mkl-service=2.4.0=py38h9ed2024_0 - mkl_fft=1.3.1=py38h4ab4a9b_0 - mkl_random=1.2.2=py38hb2f4e1b_0 - ncurses=6.3=hca72f7f_1 - nettle=3.7.3=h230ac6f_1 - numpy=1.21.2=py38h4b4dc7a_0 - numpy-base=1.21.2=py38he0bd621_0 - olefile=0.46=pyhd3eb1b0_0 - openh264=2.1.0=hd9629dc_0 - openssl=1.1.1l=h9ed2024_0 - parso=0.8.2=pyhd3eb1b0_0 - pexpect=4.8.0=pyhd3eb1b0_3 - pickleshare=0.7.5=pyhd3eb1b0_1003 - pillow=8.4.0=py38h98e4679_0 - pip=21.2.4=py38hecd8cb5_0 - prompt-toolkit=3.0.20=pyhd3eb1b0_0 - ptyprocess=0.7.0=pyhd3eb1b0_2 - pygments=2.10.0=pyhd3eb1b0_0 - python=3.8.12=h88f2d9e_0 - pytorch=1.10.0=py3.8_0 - readline=8.1=h9ed2024_0 - setuptools=58.0.4=py38hecd8cb5_0 - six=1.16.0=pyhd3eb1b0_0 - sqlite=3.36.0=hce871da_0 - tk=8.6.11=h7bc2e8c_0 - torchaudio=0.10.0=py38_cpu 
- torchvision=0.11.1=py38_cpu - traitlets=5.1.0=pyhd3eb1b0_0 - typing_extensions=3.10.0.2=pyh06a4308_0 - wcwidth=0.2.5=pyhd3eb1b0_0 - wheel=0.37.0=pyhd3eb1b0_1 - xz=5.2.5=h1de35cc_0 - zlib=1.2.11=h1de35cc_3 - zstd=1.4.9=h322a384_0 - pip: - absl-py==0.15.0 - aiohttp==3.8.0 - aiosignal==1.2.0 - alabaster==0.7.12 - astunparse==1.6.3 - async-timeout==4.0.0 - attrs==21.2.0 - babel==2.9.1 - base58==2.1.1 - beautifulsoup4==4.10.0 - black==21.12b0 - bleach==4.1.0 - boto3==1.19.12 - botocore==1.22.12 - cached-path==1.0.0 - cachetools==4.2.4 - charset-normalizer==2.0.7 - click==8.0.3 - click-help-colors==0.9.1 - codecov==2.1.12 - colorama==0.4.4 - configparser==5.1.0 - coverage==6.1.1 - datasets==1.15.1 - dill==0.3.4 - docker-pycreds==0.4.0 - docutils==0.17.1 - filelock==3.4.0 - flake8==4.0.1 - flaky==3.7.0 - flatbuffers==2.0 - frozenlist==1.2.0 - fsspec==2021.11.0 - furo==2022.1.2 - future==0.18.2 - gast==0.4.0 - gitdb==4.0.9 - gitpython==3.1.24 - glob2==0.7 - google-api-core==2.2.2 - google-auth==2.3.3 - google-auth-oauthlib==0.4.6 - google-cloud-core==2.1.0 - google-cloud-storage==1.42.3 - google-crc32c==1.3.0 - google-pasta==0.2.0 - google-resumable-media==2.1.0 - googleapis-common-protos==1.53.0 - grpcio==1.41.1 - h5py==3.6.0 - huggingface-hub==0.1.1 - idna==3.3 - imagesize==1.2.0 - importlib-metadata==4.8.1 - iniconfig==1.1.1 - isort==5.10.1 - jinja2==3.0.2 - jmespath==0.10.0 - joblib==1.1.0 - jsonnet==0.17.0 - keras==2.7.0 - keras-preprocessing==1.1.2 - keyring==23.2.1 - libclang==12.0.0 - livereload==2.6.3 - markdown==3.3.4 - markdown-it-py==1.1.0 - markupsafe==2.0.1 - mccabe==0.6.1 - mdit-py-plugins==0.3.0 - more-itertools==8.10.0 - multidict==5.2.0 - multiprocess==0.70.12.2 - mypy==0.931 - mypy-extensions==0.4.3 - myst-parser==0.16.1 - nltk==3.6.7 - oauthlib==3.1.1 - opt-einsum==3.3.0 - overrides==6.1.0 - packaging==21.2 - pandas==1.3.4 - pathspec==0.9.0 - pathtools==0.1.2 - petname==2.6 - pkginfo==1.7.1 - platformdirs==2.4.0 - pluggy==1.0.0 - promise==2.3 - 
protobuf==3.19.1 - psutil==5.8.0 - py==1.11.0 - pyarrow==6.0.0 - pyasn1==0.4.8 - pyasn1-modules==0.2.8 - pycodestyle==2.8.0 - pydeprecate==0.3.1 - pyflakes==2.4.0 - pyparsing==2.4.7 - pytest==6.2.5 - pytest-cov==3.0.0 - pytest-sphinx==0.3.1 - python-dateutil==2.8.2 - pytorch-lightning==1.5.1 - pytz==2021.3 - pyyaml==6.0 - readme-renderer==30.0 - regex==2021.11.2 - requests==2.26.0 - requests-oauthlib==1.3.0 - requests-toolbelt==0.9.1 - rfc3986==1.5.0 - rouge-score==0.0.4 - rsa==4.7.2 - s3transfer==0.5.0 - sacremoses==0.0.46 - sentencepiece==0.1.96 - sentry-sdk==1.4.3 - shortuuid==1.0.1 - smmap==5.0.0 - snowballstemmer==2.1.0 - soupsieve==2.3 - sphinx==4.3.1 - sphinx-autobuild==2021.3.14 - sphinx-copybutton==0.4.0 - sphinxcontrib-applehelp==1.0.2 - sphinxcontrib-devhelp==1.0.2 - sphinxcontrib-htmlhelp==2.0.0 - sphinxcontrib-jsmath==1.0.1 - sphinxcontrib-qthelp==1.0.3 - sphinxcontrib-serializinghtml==1.1.5 - sqlitedict==1.7.0 - subprocess32==3.5.4 - tensorboard==2.7.0 - tensorboard-data-server==0.6.1 - tensorboard-plugin-wit==1.8.0 - tensorflow==2.7.0 - tensorflow-estimator==2.7.0 - tensorflow-io-gcs-filesystem==0.23.1 - termcolor==1.1.0 - tokenizers==0.10.3 - toml==0.10.2 - tomli==1.2.2 - torchmetrics==0.6.0 - tornado==6.1 - tqdm==4.62.3 - transformers==4.12.3 - twine==3.5.0 - types-pyyaml==6.0.0 - types-setuptools==57.4.2 - typing-utils==0.1.0 - urllib3==1.26.7 - wandb==0.12.6 - webencodings==0.5.1 - werkzeug==2.0.2 - wrapt==1.13.3 - xxhash==2.0.2 - yarl==1.7.2 - yaspin==2.1.0 - zipp==3.6.0 prefix: /opt/miniconda3/envs/tango ================================================ FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/executor-metadata.json ================================================ { "config": { "type": "csub", "a": { "type": "ref", "ref": "AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP" }, "b": { "type": "ref", "ref": "ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9" } }, "duration": 0.0005, "finished_at": 
1642546366.57007, "git": { "commit": "8e09b66caffbff20fd0b1504c961932b97417e8d", "remote": "https://github.com/allenai/tango.git" }, "platform": { "cpu_count": 16, "executable": "/opt/miniconda3/envs/tango/bin/python", "host": "ip-192-168-1-194.us-west-2.compute.internal", "operating_system": "macOS-10.16-x86_64-i386-64bit", "python": "3.8.12", "root": "/Users/dirkg/Documents/tango/examples/euler", "user": "dirkg" }, "started_at": 1642546366.569589, "step": "SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz", "tango": { "command": "/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic", "version": "0.4.0rc4" } } ================================================ FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/lock ================================================ ================================================ FILE: test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/requirements.txt ================================================ absl-py==0.15.0 ai2-tango==0.4.0rc1 aiohttp==3.8.0 aiosignal==1.2.0 alabaster==0.7.12 appnope==0.1.2 astunparse==1.6.3 async-timeout==4.0.0 attrs==21.2.0 babel==2.9.1 backcall==0.2.0 base58==2.1.1 beautifulsoup4==4.10.0 black==21.12b0 bleach==4.1.0 boto3==1.19.12 botocore==1.22.12 cached-path==1.0.0 cachetools==4.2.4 certifi==2021.10.8 charset-normalizer==2.0.7 click-help-colors==0.9.1 click==8.0.3 codecov==2.1.12 colorama==0.4.4 configparser==5.1.0 coverage==6.1.1 datasets==1.15.1 decorator==5.1.0 dill==0.3.4 docker-pycreds==0.4.0 docutils==0.17.1 filelock==3.4.0 flake8==4.0.1 flaky==3.7.0 flatbuffers==2.0 frozenlist==1.2.0 fsspec==2021.11.0 furo==2022.1.2 future==0.18.2 gast==0.4.0 gitdb==4.0.9 gitpython==3.1.24 glob2==0.7 google-api-core==2.2.2 google-auth-oauthlib==0.4.6 google-auth==2.3.3 google-cloud-core==2.1.0 google-cloud-storage==1.42.3 google-crc32c==1.3.0 google-pasta==0.2.0 
google-resumable-media==2.1.0 googleapis-common-protos==1.53.0 grpcio==1.41.1 h5py==3.6.0 huggingface-hub==0.1.1 idna==3.3 imagesize==1.2.0 importlib-metadata==4.8.1 iniconfig==1.1.1 ipython==7.29.0 isort==5.10.1 jedi==0.18.0 jinja2==3.0.2 jmespath==0.10.0 joblib==1.1.0 jsonnet==0.17.0 keras-preprocessing==1.1.2 keras==2.7.0 keyring==23.2.1 libclang==12.0.0 livereload==2.6.3 markdown-it-py==1.1.0 markdown==3.3.4 markupsafe==2.0.1 matplotlib-inline==0.1.2 mccabe==0.6.1 mdit-py-plugins==0.3.0 mkl-fft==1.3.1 mkl-random==1.2.2 mkl-service==2.4.0 more-itertools==8.10.0 multidict==5.2.0 multiprocess==0.70.12.2 mypy-extensions==0.4.3 mypy==0.931 myst-parser==0.16.1 nltk==3.6.7 numpy==1.21.2 oauthlib==3.1.1 olefile==0.46 opt-einsum==3.3.0 overrides==6.1.0 packaging==21.2 pandas==1.3.4 parso==0.8.2 pathspec==0.9.0 pathtools==0.1.2 petname==2.6 pexpect==4.8.0 pickleshare==0.7.5 pillow==8.4.0 pip==21.2.4 pkginfo==1.7.1 platformdirs==2.4.0 pluggy==1.0.0 promise==2.3 prompt-toolkit==3.0.20 protobuf==3.19.1 psutil==5.8.0 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.0 pyasn1-modules==0.2.8 pyasn1==0.4.8 pycodestyle==2.8.0 pydeprecate==0.3.1 pyflakes==2.4.0 pygments==2.10.0 pyparsing==2.4.7 pytest-cov==3.0.0 pytest-sphinx==0.3.1 pytest==6.2.5 python-dateutil==2.8.2 pytorch-lightning==1.5.1 pytz==2021.3 pyyaml==6.0 readme-renderer==30.0 regex==2021.11.2 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 requests==2.26.0 rfc3986==1.5.0 rouge-score==0.0.4 rsa==4.7.2 s3transfer==0.5.0 sacremoses==0.0.46 sentencepiece==0.1.96 sentry-sdk==1.4.3 setuptools==58.0.4 shortuuid==1.0.1 six==1.16.0 smmap==5.0.0 snowballstemmer==2.1.0 soupsieve==2.3 sphinx-autobuild==2021.3.14 sphinx-copybutton==0.4.0 sphinx==4.3.1 sphinxcontrib-applehelp==1.0.2 sphinxcontrib-devhelp==1.0.2 sphinxcontrib-htmlhelp==2.0.0 sphinxcontrib-jsmath==1.0.1 sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlitedict==1.7.0 subprocess32==3.5.4 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.0 
def test_dataset_dict():
    """Check ``len``, ``in``, indexing and iteration on a two-split ``DatasetDict``."""
    split_sizes = {"train": 10, "test": 5}
    dataset_dict = DatasetDict(
        splits={split: list(range(size)) for split, size in split_sizes.items()}
    )

    assert len(dataset_dict) == 2

    # Membership first, then the length of each individual split.
    for split in split_sizes:
        assert split in dataset_dict
    for split, size in split_sizes.items():
        assert len(dataset_dict[split]) == size

    # Iterating the object and iterating ``keys()`` must yield the same names.
    assert set(dataset_dict) == set(dataset_dict.keys()) == {"train", "test"}
def test_versioned_det_hash():
    """Objects of a ``DetHashWithVersion`` subclass hash differently once ``VERSION`` changes."""

    class C(DetHashWithVersion):
        VERSION = "1"

        def __init__(self, x: int):
            self.x = x

    old_ten, old_ten_twin, old_twenty = C(10), C(10), C(20)
    assert det_hash(old_ten) == det_hash(old_ten_twin)
    assert det_hash(old_twenty) != det_hash(old_ten_twin)

    # Deliberately rebind the same name: the class is identical apart from the version string.
    class C(DetHashWithVersion):
        VERSION = "2"

        def __init__(self, x: int):
            self.x = x

    new_ten, new_ten_twin, new_twenty = C(10), C(10), C(20)
    assert det_hash(new_ten) == det_hash(new_ten_twin)
    assert det_hash(new_twenty) != det_hash(new_ten_twin)

    # Same payload, different version: the hashes must differ because the version
    # is taken into account.
    for old, new in ((old_ten, new_ten), (old_twenty, new_twenty)):
        assert det_hash(new) != det_hash(old)
def test_takes_arg(self):
    """Exercise ``takes_arg`` on a plain function, a class, a method and a classmethod."""

    def bare_function(some_input: int) -> int:
        return some_input + 1

    class SomeClass:
        total = 0

        def __init__(self, constructor_param: str) -> None:
            self.constructor_param = constructor_param

        def check_param(self, check: str) -> bool:
            return check == self.constructor_param

        @classmethod
        def set_total(cls, new_total: int) -> None:
            cls.total = new_total

    # (callable under inspection, argument name, expected answer)
    cases = [
        (bare_function, "some_input", True),
        (bare_function, "some_other_input", False),
        (SomeClass, "self", True),
        (SomeClass, "constructor_param", True),
        (SomeClass, "check", False),
        (SomeClass.check_param, "check", True),
        (SomeClass.check_param, "other_check", False),
        (SomeClass.set_total, "new_total", True),
        (SomeClass.set_total, "total", False),
    ]
    for target, arg_name, expected in cases:
        assert bool(takes_arg(target, arg_name)) is expected, (target, arg_name)
def test_extras(self):
    """Extras given to ``from_params`` are forwarded when wanted and dropped when not."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int, name: str) -> None:
            self.size, self.name = size, name

    @A.register("c")
    class C(A):
        def __init__(self, size: int, name: str) -> None:
            self.size, self.name = size, name

        # custom from_params: ``size`` arrives as an extra, ``name`` is read from the params
        @classmethod
        def from_params(cls, params: Params, size: int, **extras) -> "C":  # type: ignore
            return cls(size=size, name=params.pop("name"))

    # Default from_params ("b"): ``name`` is supplied as an extra even though A doesn't
    # need it. The second round adds an ``unwanted`` extra that must not get passed on.
    for extras in ({"name": "extra"}, {"name": "extra", "unwanted": True}):
        built_b = A.from_params(Params({"type": "b", "size": 10}), **extras)  # type: ignore
        assert built_b.name == "extra"  # type: ignore[attr-defined]
        assert built_b.size == 10  # type: ignore[attr-defined]

    # Now the same two rounds with the custom from_params ("c").
    for extras in ({"size": 20}, {"size": 20, "unwanted": True}):
        built_c = A.from_params(Params({"type": "c", "name": "extra_c"}), **extras)  # type: ignore
        assert built_c.name == "extra_c"  # type: ignore[attr-defined]
        assert built_c.size == 20  # type: ignore[attr-defined]
def test_union_of_castable_types(self):
    """``Union`` parameters of ``int`` and ``float`` keep the exact type found in the config.

    Both orderings of the union are checked: an integer literal must come out as
    an ``int`` and a float literal as a ``float``, whichever member is listed first.
    """
    import json

    class IntFloat(FromParams):
        def __init__(self, a: Union[int, float]) -> None:
            self.a = a

    class FloatInt(FromParams):
        def __init__(self, a: Union[float, int]) -> None:
            self.a = a

    # The configs are parsed from JSON text, so each value carries the Python type
    # that ``json.loads`` assigns to the literal.
    float_param_str = '{"a": 1.0}'
    int_param_str = '{"a": 1}'

    for expected_type, param_str in [(int, int_param_str), (float, float_param_str)]:
        for cls in [IntFloat, FloatInt]:
            c = cls.from_params(Params(json.loads(param_str)))
            # Identity comparison: the test is about the exact type, and comparing
            # type objects with ``==`` is the non-idiomatic spelling of that check.
            actual_type = type(c.a)  # type: ignore[attr-defined]
            assert actual_type is expected_type, (
                f"{cls.__name__} built from {param_str} produced {actual_type.__name__}, "
                f"expected {expected_type.__name__}"
            )
def test_tuple(self):
    """A ``Tuple[A, C]`` parameter is built element by element from a list of configs."""

    class A(Registrable):
        pass

    @A.register("b")
    class B(A):
        def __init__(self, size: int) -> None:
            self.size = size

    class C(Registrable):
        pass

    @C.register("d")
    class D(C):
        def __init__(self, name: str) -> None:
            self.name = name

    class E(Registrable):
        pass

    @E.register("f")
    class F(E):
        def __init__(self, items: Tuple[A, C]) -> None:
            self.items = items

    item_configs = [{"type": "b", "size": 1}, {"type": "d", "name": "item2"}]
    built: F = E.from_params(  # type: ignore[assignment]
        Params({"type": "f", "items": item_configs})
    )

    # The container itself: a real tuple with one entry per config.
    assert isinstance(built.items, tuple)
    assert len(built.items) == 2

    # Each position was resolved against its own registry.
    first, second = built.items
    assert isinstance(first, B)
    assert isinstance(second, D)
    assert first.size == 1
    assert second.name == "item2"
def test_kwargs_with_multiple_inheritance(self):
    """Kwargs passed up to the super class work for either order of the base classes."""

    class A(Registrable):
        def __init__(self, a: int):
            self.a = a

    # B1 and B2 are identical; they differ only in the order of their multiple inheritance.
    @A.register("b1")  # type: ignore
    class B1(A, Number):
        def __init__(self, b: float, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    @A.register("b2")  # type: ignore
    class B2(Number, A):
        def __init__(self, b: float, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    for subclass in (B1, B2):
        built = subclass.from_params(Params({"a": 4, "b": 5}))
        assert built.b == 5
        assert built.a == 4
def test_kwargs_are_passed_to_deeper_superclasses(self):
    """Parameters pass through two levels of ``**kwargs`` to the class that names them."""

    class BaseClass(Registrable):
        def __init__(self):
            self.a = self.b = self.c = None

    @BaseClass.register("a")
    class A(BaseClass):
        def __init__(self, a: str):
            super().__init__()
            self.a = a

    @BaseClass.register("b")
    class B(A):
        def __init__(self, b: str, **kwargs):
            super().__init__(**kwargs)
            self.b = b

    @BaseClass.register("c")
    class C(B):
        def __init__(self, c, **kwargs):
            super().__init__(**kwargs)
            self.c = c

    config = {"type": "c", "a": "a_value", "b": "b_value", "c": "c_value"}
    instance = BaseClass.from_params(Params(config))

    # "a" is only named two classes above C; "b" one class above; "c" on C itself.
    for attribute in ("a", "b", "c"):
        assert getattr(instance, attribute) == f"{attribute}_value"
def test_optional_vs_required_lazy_objects(self):
    """Cover required, defaulted, ``None``-defaulted and ``Optional`` ``Lazy`` parameters."""

    class ConstructedObject(FromParams):
        def __init__(self, a: int):
            self.a = a

    class Testing(FromParams):
        def __init__(
            self,
            lazy1: Lazy[ConstructedObject],
            lazy2: Lazy[ConstructedObject] = Lazy(ConstructedObject),
            lazy3: Lazy[ConstructedObject] = None,
            lazy4: Optional[Lazy[ConstructedObject]] = Lazy(ConstructedObject),
        ) -> None:
            self.lazy1 = lazy1.construct()
            self.lazy2 = lazy2.construct(a=2)
            self.lazy3 = lazy3.construct() if lazy3 is not None else None
            self.lazy4 = lazy4.construct(a=1) if lazy4 is not None else None

    # Only the required parameter is configured; every default applies.
    only_required = Testing.from_params(Params({"lazy1": {"a": 1}}))
    assert only_required.lazy1.a == 1
    assert only_required.lazy2.a == 2
    assert only_required.lazy3 is None
    assert only_required.lazy4 is not None

    # lazy2 is configured as well, and its configured value is the one that sticks.
    with_lazy2 = Testing.from_params(Params({"lazy1": {"a": 1}, "lazy2": {"a": 3}}))
    assert with_lazy2.lazy1.a == 1
    assert with_lazy2.lazy2.a == 3
    assert with_lazy2.lazy3 is None
    assert with_lazy2.lazy4 is not None

    # lazy3 is configured and lazy4 is explicitly set to None.
    with_lazy3 = Testing.from_params(
        Params({"lazy1": {"a": 1}, "lazy3": {"a": 3}, "lazy4": None})
    )
    assert with_lazy3.lazy1.a == 1
    assert with_lazy3.lazy2.a == 2
    assert with_lazy3.lazy3 is not None
    assert with_lazy3.lazy3.a == 3
    assert with_lazy3.lazy4 is None

    # Leaving out the one parameter without a default is a configuration error.
    with pytest.raises(ConfigurationError, match='Missing key "lazy1" for Testing'):
        Testing.from_params(Params({}))
def test_custom_abc_mapping(self):
    """A parameter annotated with a custom ``abc.Mapping`` subclass is built as that class."""
    from collections.abc import Mapping as AbcMapping

    class CustomMapping(AbcMapping):
        def __init__(self, data: Dict[str, int]):
            self.data = data

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            return iter(self.data)

        def __getitem__(self, key):
            return self.data[key]

    class ClassWithCustomMapping(FromParams):
        def __init__(self, mapping: CustomMapping):
            self.mapping = mapping

    # The nested config names CustomMapping's own constructor argument, ``data``.
    built = ClassWithCustomMapping.from_params({"mapping": {"data": {"a": 1}}})
    assert isinstance(built.mapping, CustomMapping)
    assert built.mapping["a"] == 1
def test_explicit_kwargs_always_passed_to_constructor(self):
    """Configured values reach a ``**kwargs``-only constructor; unconfigured ones do not."""

    class Base(FromParams):
        def __init__(self, lazy: bool = False, x: int = 0) -> None:
            self.lazy, self.x = lazy, x

    class A(Base):
        def __init__(self, **kwargs) -> None:
            # "lazy" was given in the params, so it must arrive here even though
            # the value equals Base's default.
            assert "lazy" in kwargs
            super().__init__(**kwargs)

    class B(Base):
        def __init__(self, **kwargs) -> None:
            # Would raise a duplicate-keyword TypeError if "lazy" were in kwargs.
            super().__init__(lazy=True, **kwargs)

    A.from_params(Params({"lazy": False}))

    built_b = B.from_params(Params({}))
    assert built_b.lazy is True
class B(Registrable): def __init__(self): pass with pytest.raises(ConfigurationError, match="not in acceptable choices for type"): B.from_params("nonexistent_class") with pytest.raises(ConfigurationError, match='key "type" is required'): B.from_params(Params({"some_spurious": "key", "value": "pairs"})) with pytest.raises(ConfigurationError, match='key "type" is required'): B.from_params(Params({})) def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union(self): class NestedClass(FromParams): def __init__(self, varname: Optional[str] = None): self.varname = varname class WrapperClass(FromParams): def __init__(self, nested_class: Optional[Union[str, NestedClass]] = None): if isinstance(nested_class, str): nested_class = NestedClass(varname=nested_class) self.nested_class = nested_class with pytest.raises(ConfigurationError): WrapperClass.from_params(Params({"nested_class": {"wrong_varname": "varstring"}})) def test_from_params_handles_base_class_kwargs(self): class Foo(FromParams): def __init__(self, a: int, b: str = None, **kwargs) -> None: self.a = a self.b = b for key, value in kwargs.items(): setattr(self, key, value) foo = Foo.from_params(Params({"a": 2, "b": "hi"})) assert foo.a == 2 assert foo.b == "hi" foo = Foo.from_params(Params({"a": 2, "b": "hi", "c": {"2": "3"}})) assert foo.a == 2 assert foo.b == "hi" assert foo.c == {"2": "3"} # type: ignore[attr-defined] class Bar(Foo): def __init__(self, a: int, b: str, d: int, **kwargs) -> None: super().__init__(a, b=b, **kwargs) self.d = d bar = Bar.from_params(Params({"a": 2, "b": "hi", "c": {"2": "3"}, "d": 0})) assert bar.a == 2 assert bar.b == "hi" assert bar.c == {"2": "3"} # type: ignore[attr-defined] assert bar.d == 0 class Baz(Foo): def __init__(self, a: int, b: Optional[str] = "a", **kwargs) -> None: super().__init__(a, b=b, **kwargs) baz = Baz.from_params(Params({"a": 2, "b": None})) assert baz.b is None baz = Baz.from_params(Params({"a": 2})) assert baz.b == "a" def 
test_from_params_base_class_kwargs_crashes_if_params_not_handled(self): class Bar(FromParams): def __init__(self, c: str = None) -> None: self.c = c class Foo(Bar): def __init__(self, a: int, b: str = None, **kwargs) -> None: super().__init__(**kwargs) self.a = a self.b = b foo = Foo.from_params(Params({"a": 2, "b": "hi", "c": "some value"})) assert foo.a == 2 assert foo.b == "hi" assert foo.c == "some value" with pytest.raises(TypeError, match="invalid_key"): Foo.from_params(Params({"a": 2, "b": "hi", "invalid_key": "some value"})) def test_from_params_handles_kwargs_in_non_from_params_registered_class(self): class Bar(Registrable): pass class Baz: def __init__(self, a: int) -> None: self.a = a @Bar.register("foo") class Foo(Baz): def __init__(self, a: int, b: str = None, **kwargs) -> None: super().__init__(a) self.b = b for key, value in kwargs.items(): setattr(self, key, value) foo: Foo = Bar.from_params(Params({"type": "foo", "a": 2, "b": "hi"})) # type: ignore[assignment] assert foo.a == 2 assert foo.b == "hi" foo = Bar.from_params( # type: ignore[assignment] Params({"type": "foo", "a": 2, "b": "hi", "c": {"2": "3"}}) ) assert foo.a == 2 # type: ignore[attr-defined] assert foo.b == "hi" # type: ignore[attr-defined] assert foo.c == {"2": "3"} # type: ignore[attr-defined] def test_from_params_passes_extras_to_non_from_params_registered_class(self): class Bar(Registrable): pass class Baz: def __init__(self, a: int, c: Dict[str, str] = None, extra: str = "idk") -> None: self.a = a self.c = c self.extra = extra @Bar.register("foo") class Foo(Baz): def __init__(self, a: int, b: str = None, **kwargs) -> None: super().__init__(a, **kwargs) self.b = b foo: Foo = Bar.from_params(Params({"type": "foo", "a": 2, "b": "hi"})) # type: ignore[assignment] assert foo.a == 2 assert foo.b == "hi" assert foo.c is None foo = Bar.from_params( # type: ignore[assignment] Params({"type": "foo", "a": 2, "b": "hi", "c": {"2": "3"}}), extra="4" ) assert foo.a == 2 # type: 
ignore[attr-defined] assert foo.b == "hi" # type: ignore[attr-defined] assert foo.c == {"2": "3"} # type: ignore[attr-defined] assert foo.extra == "4" # type: ignore[attr-defined] def test_from_params_child_has_kwargs_base_implicit_constructor(self): class Foo(FromParams): pass class Bar(Foo): def __init__(self, a: int, **kwargs) -> None: self.a = a bar = Bar.from_params(Params({"a": 2})) assert bar.a == 2 def test_from_params_has_args(self): class Foo(FromParams): def __init__(self, a: int, *args) -> None: self.a = a foo = Foo.from_params(Params({"a": 2})) assert foo.a == 2 def test_from_params_with_dataclass(self): @dataclass class Foo(FromParams): x: int y: str assert Foo.from_params({"x": 1, "y": "2"}).x == 1 with pytest.raises(TypeError): Foo.from_params({"x": 1, "y": 2}) def test_to_params(self): @dataclass class Bar(FromParams): z: bool @dataclass class Foo(FromParams): x: int bar: Bar params_dict = {"x": 1, "bar": {"z": True}} foo = Foo.from_params(deepcopy(params_dict)) assert foo.bar.z params = foo.to_params() assert params.as_dict() == params_dict def test_to_params_needs_custom_to_params(self): @dataclass class Bar: z: bool @dataclass class Foo(FromParams): x: int bar: Bar foo = Foo.from_params({"x": 1}, bar=Bar(z=True)) with pytest.raises(NotImplementedError): foo.to_params() @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python 3.9 or higher") def test_type_hinting_generics_from_std_collections(self): class Item(FromParams): def __init__(self, a: int) -> None: self.a = a class ClassWithStdGenerics(FromParams): def __init__(self, x: list[Item], y: dict[str, Item]) -> None: # type: ignore[syntax] self.x = x self.y = y o = ClassWithStdGenerics.from_params({"x": [{"a": 1}], "y": {"b": {"a": 1}}}) assert isinstance(o.x, list) assert isinstance(o.x[0], Item) assert isinstance(o.y["b"], Item) def test_with_non_from_params_generics(self): T = TypeVar("T") class Item(Generic[T]): def __init__(self, x: T): self.x = x class 
ClassWithGenerics(FromParams): def __init__(self, item: Item[T]): self.item = item o = ClassWithGenerics.from_params({"item": {"x": 1}}) assert isinstance(o.item, Item) @pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python 3.10 or higher") def test_with_union_pipe(self): class Item(FromParams): def __init__(self, a: int) -> None: self.a = a class ClassWithUnionType(FromParams): def __init__(self, x: Item | str): # type: ignore[syntax] self.x = x o = ClassWithUnionType.from_params({"x": {"a": 1}}) assert isinstance(o.x, Item) def test_from_params_with_function(self): """ Tests that a function registered as a constructor for a registrable class will properly construct arguments. """ class MyRegistrableClass(Registrable): def __init__(self, a: int, b: int): self.a = a self.b = b @dataclass class OptionsClass(FromParams): a: int b: int @MyRegistrableClass.register("func_constructor") # type: ignore def constructor(options: OptionsClass) -> MyRegistrableClass: assert isinstance(options, OptionsClass) return MyRegistrableClass(options.a, options.b) MyRegistrableClass.from_params({"type": "func_constructor", "options": {"a": 1, "b": 2}}) def test_from_params_passes_no_extra_args_in_factory_construction(self): class InnerBase(Registrable): pass from typing import Callable def innerbase_with_x_factory(cls) -> Callable[..., InnerBase]: def factory(x: int, **kwargs) -> InnerBase: return cls(x=x, **kwargs) return factory class Inner(InnerBase): def __init__(self, x: int): self.x = x InnerBase.register("inner")(innerbase_with_x_factory(Inner)) # type: ignore[arg-type] class OuterBase(Registrable): default_implementation = "default" def __init__(self, y: str, i: InnerBase, c: int): self.i = i self.y = y self.c = c OuterBase.register("default")(OuterBase) config = {"c": 4, "i": {"type": "inner", "x": 5}} outer_lazy = Lazy(OuterBase, Params(config)) outer = outer_lazy.construct(y="placeholder") assert outer.i.x == 5 # type: ignore[attr-defined] def 
test_lazy_from_params_with_version(self): class Gizmo(Registrable): pass @Gizmo.register("widget") class WidgetGizmo(Gizmo, DetHashWithVersion): VERSION = "001" def __init__(self, x: int): self.x = x @classmethod def default(cls): return WidgetGizmo(0) Gizmo.register("default_widget", "default")(WidgetGizmo) lazy = Lazy(Gizmo, params=Params({"type": "widget", "x": 1})) hash_before = det_hash(lazy) WidgetGizmo.VERSION = "001" assert hash_before == det_hash(lazy) WidgetGizmo.VERSION = "002" assert hash_before != det_hash(lazy) assert lazy.construct().x == 1 # type: ignore[attr-defined] default_lazy = Lazy( Gizmo, params=Params( { "type": "default_widget", } ), ) assert hash_before != det_hash(default_lazy) assert det_hash(lazy) != det_hash(default_lazy) hash_before = det_hash(default_lazy) WidgetGizmo.VERSION = "003" assert hash_before != det_hash(default_lazy) assert default_lazy.construct().x == 0 # type: ignore[attr-defined] def test_from_params_that_takes_step_directly(self): class FakeStepBase(Step): def run(self, test_input: int) -> int: # type: ignore return test_input @FakeStepBase.register("fake_step") class FakeStep(FakeStepBase): def run(self, test_input: int) -> int: # type: ignore return test_input * 2 class FromParamsWithStepInput(FromParams): def __init__(self, fake_step: FakeStepBase): self.fake_step = fake_step o = FromParamsWithStepInput.from_params( {"fake_step": {"type": "fake_step", "test_input": 1}} ) assert isinstance(o.fake_step, FakeStep) class MyClass(FromParams): def __init__(self, my_int: int, my_bool: bool = False) -> None: self.my_int = my_int self.my_bool = my_bool class Foo(FromParams): def __init__(self, a: int = 1) -> None: self.a = a class Bar(FromParams): def __init__(self, foo: Foo) -> None: self.foo = foo class Baz(FromParams): def __init__(self, bar: Lazy[Bar]) -> None: self._bar = bar @property def bar(self): return self._bar.construct() ================================================ FILE: tests/common/params_test.py 
================================================ import json import os import re from collections import OrderedDict import pytest from tango.common.exceptions import ConfigurationError from tango.common.params import ( Params, infer_and_cast, remove_keys_from_params, with_overrides, ) from tango.common.testing import TangoTestCase class TestParams(TangoTestCase): @pytest.mark.parametrize("extension", ["jsonnet", "yaml"]) def test_load_from_file(self, extension): filename = self.FIXTURES_ROOT / "common" / f"params_example.{extension}" params = Params.from_file(filename) assert params["model"]["type"] == "classifier" def test_replace_none(self): params = Params({"a": "None", "b": [1.0, "None", 2], "c": {"d": "None"}}) assert params["a"] is None assert params["b"][1] is None assert params["c"]["d"] is None def test_init_with_different_types(self): assert Params({"a": 1, "b": 2}) == Params(Params({"a": 1, "b": 2})) def test_bad_unicode_environment_variables(self): filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet" os.environ["BAD_ENVIRONMENT_VARIABLE"] = "\udce2" Params.from_file(filename) del os.environ["BAD_ENVIRONMENT_VARIABLE"] def test_with_overrides(self): original = { "foo": {"bar": {"baz": 3}, "x": 0}, "bar": ["a", "b", "c"], "baz": {"bar": 2, "y": 3, "x": [0, 1, 2]}, } overrides = { "foo.bar": {"z": 2}, "bar.0": "d", "baz.bar": 1, "baz.x": [0, 0], "z": 2, } assert with_overrides(original, overrides) == { "foo": {"bar": {"z": 2}, "x": 0}, "bar": ["d", "b", "c"], "baz": {"bar": 1, "y": 3, "x": [0, 0]}, "z": 2, } def test_bad_overrides(self): with pytest.raises(ValueError, match="contains unused keys"): with_overrides({"foo": [0, 1, 2]}, {"foo.3": 4}) with pytest.raises(ValueError, match="expected list or dict"): with_overrides({"foo": 3}, {"foo.x": 2}) @pytest.mark.parametrize("input_type", [dict, str]) def test_overrides(self, input_type): filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet" overrides = { "data_path": 
"train.txt", "model.type": "new_classifier", "model.layers.0.activation": "gelu", "model.layers.1": {"type": "classifier"}, } params = Params.from_file( filename, overrides if input_type == dict else json.dumps(overrides) ) assert params["data_path"] == "train.txt" assert params["model"]["type"] == "new_classifier" assert len(params["model"]["layers"]) == 2 assert params["model"]["layers"][0]["activation"] == "gelu" assert params["model"]["layers"][1]["type"] == "classifier" def test_as_flat_dict(self): params = Params({"a": 10, "b": {"c": 20, "d": "stuff"}}).as_flat_dict() assert params == {"a": 10, "b.c": 20, "b.d": "stuff"} def test_jsonnet_features(self): config_file = self.TEST_DIR / "config.jsonnet" with open(config_file, "w") as f: f.write( """{ // This example is copied straight from the jsonnet docs person1: { name: "Alice", welcome: "Hello " + self.name + "!", }, person2: self.person1 { name: "Bob" }, }""" ) params = Params.from_file(config_file) alice = params.pop("person1") bob = params.pop("person2") assert alice.as_dict() == {"name": "Alice", "welcome": "Hello Alice!"} assert bob.as_dict() == {"name": "Bob", "welcome": "Hello Bob!"} params.assert_empty("TestParams") def test_regexes_with_backslashes(self): bad_regex = self.TEST_DIR / "bad_regex.jsonnet" good_regex = self.TEST_DIR / "good_regex.jsonnet" with open(bad_regex, "w") as f: f.write(r'{"myRegex": "a\.b"}') with open(good_regex, "w") as f: f.write(r'{"myRegex": "a\\.b"}') with pytest.raises(RuntimeError): Params.from_file(bad_regex) params = Params.from_file(good_regex) regex = params["myRegex"] assert re.match(regex, "a.b") assert not re.match(regex, "a-b") # Check roundtripping good_regex2 = self.TEST_DIR / "good_regex2.jsonnet" with open(good_regex2, "w") as f: f.write(json.dumps(params.as_dict())) params2 = Params.from_file(good_regex2) assert params.as_dict() == params2.as_dict() def test_env_var_substitution(self): substitutor = self.TEST_DIR / "substitutor.jsonnet" key = 
"TEST_ENV_VAR_SUBSTITUTION" assert os.environ.get(key) is None with open(substitutor, "w") as f: f.write(f'{{"path": std.extVar("{key}")}}') # raises without environment variable set with pytest.raises(RuntimeError): Params.from_file(substitutor) os.environ[key] = "PERFECT" params = Params.from_file(substitutor) assert params["path"] == "PERFECT" del os.environ[key] def test_as_ordered_dict(self): # keyD > keyC > keyE; keyDA > keyDB; Next all other keys alphabetically preference_orders = [["keyD", "keyC", "keyE"], ["keyDA", "keyDB"]] params = Params( { "keyC": "valC", "keyB": "valB", "keyA": "valA", "keyE": "valE", "keyD": {"keyDB": "valDB", "keyDA": "valDA"}, } ) ordered_params_dict = params.as_ordered_dict(preference_orders) expected_ordered_params_dict = OrderedDict( { "keyD": {"keyDA": "valDA", "keyDB": "valDB"}, "keyC": "valC", "keyE": "valE", "keyA": "valA", "keyB": "valB", } ) assert json.dumps(ordered_params_dict) == json.dumps(expected_ordered_params_dict) def test_to_file(self): # Test to_file works with or without preference orders params_dict = {"keyA": "valA", "keyB": "valB"} expected_ordered_params_dict = OrderedDict({"keyB": "valB", "keyA": "valA"}) params = Params(params_dict) file_path = self.TEST_DIR / "config.jsonnet" # check with preference orders params.to_file(file_path, [["keyB", "keyA"]]) with open(file_path, "r") as handle: ordered_params_dict = OrderedDict(json.load(handle)) assert json.dumps(expected_ordered_params_dict) == json.dumps(ordered_params_dict) # check without preference orders doesn't give error params.to_file(file_path) def test_infer_and_cast(self): lots_of_strings = { "a": ["10", "1.3", "true"], "b": {"x": 10, "y": "20.1", "z": "other things"}, "c": "just a string", } casted = { "a": [10, 1.3, True], "b": {"x": 10, "y": 20.1, "z": "other things"}, "c": "just a string", } assert infer_and_cast(lots_of_strings) == casted contains_bad_data = {"x": 10, "y": int} with pytest.raises(ValueError, match="cannot infer type"): 
infer_and_cast(contains_bad_data) params = Params(lots_of_strings) assert params.as_dict() == lots_of_strings assert params.as_dict(infer_type_and_cast=True) == casted def test_pop_choice(self): choices = ["my_model", "other_model"] params = Params({"model": "my_model"}) assert params.pop_choice("model", choices) == "my_model" params = Params({"model": "non_existent_model"}) with pytest.raises(ConfigurationError): params.pop_choice("model", choices) params = Params({"model": "module.submodule.ModelName"}) assert params.pop_choice("model", choices) == "module.submodule.ModelName" params = Params({"model": "module.submodule.ModelName"}) with pytest.raises(ConfigurationError): params.pop_choice("model", choices, allow_class_names=False) def test_remove_keys_from_params(self): filename = self.FIXTURES_ROOT / "common" / "params_example.jsonnet" params = Params.from_file(filename) assert params["model"]["layers"][0]["activation"] == "relu" assert params["model"]["layers"][1]["activation"] == "softmax" remove_keys_from_params(params, keys=["activation"]) assert "activation" not in params["model"]["layers"][0] assert "activation" not in params["model"]["layers"][1] ================================================ FILE: tests/common/registrable_test.py ================================================ import pytest from tango.common.exceptions import ConfigurationError from tango.common.registrable import Registrable from tango.common.testing import TangoTestCase from tango.step import Step class TestRegistrable(TangoTestCase): def test_basic_functionality(self): class MockBaseClass(Registrable): pass assert "mock-1" not in MockBaseClass.list_available() @MockBaseClass.register("mock-1") class MockSubclass1(MockBaseClass): pass assert MockBaseClass in Registrable._registry assert MockBaseClass.by_name("mock-1") == MockSubclass1 # Verify that registering under a name that already exists # causes a ConfigurationError. 
with pytest.raises(ConfigurationError): @MockBaseClass.register("mock-1") class MockAlternate(MockBaseClass): pass # Registering under a name that already exists should overwrite # if exist_ok=True. @MockBaseClass.register("mock-1", exist_ok=True) class MockAlternate2(MockBaseClass): pass assert MockBaseClass.by_name("mock-1") == MockAlternate2 # Test that we get a suggestion when the name is close. with pytest.raises(ConfigurationError) as exc: MockBaseClass.by_name("mock_1") assert "did you mean 'mock-1'?" in str(exc.value) def test_registering_step_by_reserved_name(self): with pytest.raises(ConfigurationError, match="cannot use the name 'ref'"): @Step.register("ref") class BadStep(Step): pass def test_search_modules(self): Step.search_modules("foo-bar-baz-non-existent") ================================================ FILE: tests/common/sequences_test.py ================================================ import os from tempfile import TemporaryDirectory import pytest from tango.common.sequences import ( ConcatenatedSequence, MappedSequence, ShuffledSequence, SlicedSequence, SqliteSparseSequence, ) def assert_equal_including_exceptions(expected_fn, actual_fn): try: expected = expected_fn() except Exception as e: with pytest.raises(e.__class__): actual_fn() else: assert expected == actual_fn() def test_shuffled_sequence(): seq = ShuffledSequence(list(range(10))) assert 5 in seq assert len(seq) == 10 def test_sliced_sequence(): seq = SlicedSequence(list(range(10)), slice(10)) assert len(seq) == 10 assert seq[0] == 0 assert seq[-1] == 9 seq2 = seq[-2:] assert len(seq2) == 2 def test_concatenated_sequence(): l1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] l2 = ConcatenatedSequence([0, 1], [], [2, 3, 4], [5, 6, 7, 8, 9], []) # __len__() assert len(l1) == len(l2) # index() for item in l1 + [999]: # no indices assert_equal_including_exceptions(lambda: l1.index(item), lambda: l2.index(item)) # only start index for index in range(-15, 15): assert_equal_including_exceptions( lambda: 
l1.index(item, index), lambda: l2.index(item, index) ) # start and stop index for start_index in range(-15, 15): for end_index in range(-15, 15): assert_equal_including_exceptions( lambda: l1.index(item, start_index, end_index), lambda: l2.index(item, start_index, end_index), ) # __getitem__() for index in range(-15, 15): assert_equal_including_exceptions(lambda: l1[index], lambda: l2[index]) for start_index in range(-15, 15): for end_index in range(-15, 15): assert_equal_including_exceptions( lambda: l1[start_index:end_index], lambda: list(l2[start_index:end_index]) ) # count() for item in l1 + [999]: assert_equal_including_exceptions(lambda: l1.count(item), lambda: l2.count(item)) # __contains__() for item in l1 + [999]: assert_equal_including_exceptions(lambda: item in l1, lambda: item in l2) def test_sqlite_sparse_sequence(): with TemporaryDirectory(prefix="test_sparse_sequence-") as temp_dir: s = SqliteSparseSequence(os.path.join(temp_dir, "test.sqlite")) assert len(s) == 0 s.extend([]) assert len(s) == 0 s.append("one") assert len(s) == 1 s.extend(["two", "three"]) s.insert(1, "two") assert s[1] == "two" assert s.count("two") == 2 ss = s[1:3] assert list(ss) == ["two", "two"] del s[1:3] assert len(s) == 2 assert s[-1] == "three" s.clear() assert len(s) == 0 def test_mapped_sequence(): my_very_long_sequence = ["John", "Paul", "George", "Ringo"] m = MappedSequence(lambda x: len(x), my_very_long_sequence) assert m[0] == 4 assert len(m) == len(my_very_long_sequence) for i in range(len(m)): assert m[i] == m[i:][0] ================================================ FILE: tests/common/util_test.py ================================================ import os import time from pathlib import Path import pytest from flaky import flaky from tango.common.testing import TangoTestCase from tango.common.util import ( could_be_class_name, find_integrations, find_submodules, resolve_module_name, threaded_generator, ) class TestResolveModuleName(TangoTestCase): def 
setup_method(self): super().setup_method() self._work_dir_restore = os.getcwd() os.chdir(self.TEST_DIR) def teardown_method(self): super().teardown_method() os.chdir(self._work_dir_restore) def test_with_package_init_file(self): path = Path("fake_package/fake_module/__init__.py") (self.TEST_DIR / path.parent).mkdir(parents=True) open(path, "w").close() open(path.parent.parent / "__init__.py", "w").close() assert resolve_module_name(str(path)) == ("fake_package.fake_module", Path(".")) def test_with_submodule(self): path = Path("fake_package/fake_module") (self.TEST_DIR / path).mkdir(parents=True) open(path / "__init__.py", "w").close() open(path.parent / "__init__.py", "w").close() assert resolve_module_name(str(path)) == ("fake_package.fake_module", Path(".")) def test_with_module_in_child_directory(self): path = Path("some_dir/fake_module.py") (self.TEST_DIR / path.parent).mkdir(parents=True) open(path, "w").close() assert resolve_module_name(str(path)) == ("fake_module", Path("./some_dir")) def test_find_submodules(): assert "tango.version" in set(find_submodules()) assert "tango.common.registrable" in set(find_submodules()) assert "tango.common" in set(find_submodules(recursive=False)) assert "tango.common.registrable" not in set(find_submodules(recursive=False)) assert "tango.integrations.torch" in set(find_submodules("integrations")) assert "tango.integrations.torch" not in set(find_submodules(exclude={"tango.integrations*"})) def test_find_integrations(): integrations = set(find_integrations()) assert "tango.integrations.torch" in integrations assert "tango.integrations.torch.format" not in integrations @pytest.mark.parametrize( "name, result", [ ("", False), ("foo.Bar", True), ("foo.Bar.", False), ("1foo.Bar", False), ("lib.my_package.MyClass", True), ], ) def test_could_be_class_name(name: str, result: bool): assert could_be_class_name(name) is result @flaky(max_runs=3) def test_threaded_generator(): def generate_slowly(): for i in range(10): yield i 
time.sleep(0.1) start = time.time() for i in threaded_generator(generate_slowly()): time.sleep(0.1) end = time.time() assert end - start < 11 ================================================ FILE: tests/end_to_end/test_dataset_dict_from_separate_steps.py ================================================ from typing import Any, Sequence from tango import Format, JsonFormat, Step from tango.common import DatasetDict from tango.common.testing import run_experiment @Step.register("train_data") class TrainData(Step): DETERMINISTIC = True CACHEABLE = False def run(self) -> Sequence[int]: # type: ignore return list(range(10)) @Step.register("val_data") class ValData(Step): DETERMINISTIC = True CACHEABLE = False def run(self) -> Sequence[int]: # type: ignore return list(range(10, 20)) @Step.register("save_data") class SaveData(Step): DETERMINISTIC = True CACHEABLE = True FORMAT: Format = JsonFormat() def run(self, dataset_dict: DatasetDict) -> Any: # type: ignore return dataset_dict.splits def test_experiment(): with run_experiment( { "steps": { "train_data": { "type": "train_data", }, "val_data": { "type": "val_data", }, "saved_data": { "type": "save_data", "dataset_dict": { "splits": { "train": {"type": "ref", "ref": "train_data"}, "val": {"type": "ref", "ref": "val_data"}, } }, }, } } ) as run_dir: assert (run_dir / "saved_data").is_dir() fmt = JsonFormat() data = fmt.read(run_dir / "saved_data") assert data["train"] == list(range(10)) assert data["val"] == list(range(10, 20)) ================================================ FILE: tests/end_to_end/test_lazy_input_with_another_step.py ================================================ from dataclasses import dataclass from tango import Format, JsonFormat, Step from tango.common import FromParams, Lazy from tango.common.testing import run_experiment @dataclass class Foo(FromParams): number: float @Step.register("generate_number") class GenerateNumberStep(Step): DETERMINISTIC = True CACHEABLE = True FORMAT: Format = 
JsonFormat() def run(self) -> float: # type: ignore[override] return 1.0 @Step.register("lazy_input") class StepWithLazyInput(Step): DETERMINISTIC = True CACHEABLE = True FORMAT: Format = JsonFormat() def run(self, foo: Lazy[Foo]) -> float: # type: ignore[override] foo = foo.construct() assert isinstance(foo, Foo) assert isinstance(foo.number, float) return foo.number def test_experiment(): with run_experiment( { "steps": { "gen_number": { "type": "generate_number", }, "get_number": { "type": "lazy_input", "foo": { "number": { "type": "ref", "ref": "gen_number", } }, }, } } ) as run_dir: assert (run_dir / "get_number").is_dir() fmt: Format = JsonFormat() data = fmt.read(run_dir / "get_number") assert data == 1.0 ================================================ FILE: tests/end_to_end/test_multicore_cli.py ================================================ import pytest from tango.common.exceptions import CliRunError from tango.common.logging import initialize_logging, teardown_logging from tango.common.testing import TangoTestCase class TestExperiment(TangoTestCase): def setup_method(self): super().setup_method() initialize_logging() self.config = { "steps": { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 1, "fail": True, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 1, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": "This may or may not fail!", "seconds": 3, "fail": False, }, } } def teardown_method(self): super().teardown_method() teardown_logging() def test_experiment(self, caplog): with pytest.raises(CliRunError): self.run( self.config, multicore=True, parallelism=2, ) latest_outputs = self.TEST_DIR / "workspace" / "latest" num_executed = 0 for out in latest_outputs.iterdir(): if (out / "cache-metadata.json").exists(): num_executed += 1 assert num_executed == 1 def test_experiment_with_overrides(self, caplog): import json self.run( 
self.config, multicore=True, parallelism=2, overrides=json.dumps({"steps.step1.fail": False}), ) latest_outputs = self.TEST_DIR / "workspace" / "latest" num_executed = 0 for out in latest_outputs.iterdir(): if (out / "cache-metadata.json").exists(): num_executed += 1 assert num_executed == 3 ================================================ FILE: tests/end_to_end/test_non_cacheable_into_cacheable_multiple_runs.py ================================================ import random from tango import Step from tango.common.testing import TangoTestCase @Step.register("give_me_a_number") class GiveMeANumber(Step): DETERMINISTIC = True CACHEABLE = False def run(self, what_number: int) -> int: # type: ignore return what_number @Step.register("random_int") class RandomInt(Step): DETERMINISTIC = True CACHEABLE = True def run(self, lower_bound: int, upper_bound: int) -> int: # type: ignore return random.randint(lower_bound, upper_bound) class TestExperiment(TangoTestCase): def test_experiment(self, caplog): config = { "steps": { "a_number": { "type": "give_me_a_number", "what_number": 3, }, "final_number": { "type": "random_int", "lower_bound": 0, "upper_bound": {"type": "ref", "ref": "a_number"}, }, } } self.run(config) self.run(config, overrides={"steps.final_number.lower_bound": 1}) ================================================ FILE: tests/end_to_end/test_registered_runs.py ================================================ from tango import Format, JsonFormat, Step from tango.common.testing import TangoTestCase @Step.register("return_a_number") class ReturnANumber(Step): DETERMINISTIC = True CACHEABLE = True FORMAT: Format = JsonFormat() def run(self, what_number: int) -> int: # type: ignore return what_number class TestExperiment(TangoTestCase): def test_experiment_updates_latest_run_output(self, caplog): config = { "steps": { "a_number": { "type": "return_a_number", "what_number": 3, }, } } self.run(config) assert (self.TEST_DIR / "workspace" / "latest" / 
"a_number").exists() fmt: Format = JsonFormat() data = fmt.read(self.TEST_DIR / "workspace" / "latest" / "a_number") assert data == 3 config = { "steps": { "a_number": { "type": "return_a_number", "what_number": 5, }, } } self.run(config) data = fmt.read(self.TEST_DIR / "workspace" / "latest" / "a_number") assert data == 5 ================================================ FILE: tests/end_to_end/test_run_single_step.py ================================================ from tango.common.testing import TangoTestCase class TestRunSingleStep(TangoTestCase): def test_run_single_step(self): config = { "steps": { "strA": {"type": "string", "result": "Hello, "}, "strB": {"type": "string", "result": "World"}, "concatenated": { "type": "concat_strings", "string1": {"type": "ref", "ref": "strA"}, "string2": {"type": "ref", "ref": "strB"}, }, } } num_other_files = 2 # out.log and stepinfo.json # Regular run contains all step outputs. self.run(config) latest_outputs = self.TEST_DIR / "workspace" / "latest" assert len(list(latest_outputs.iterdir())) == num_other_files + 3 # Running a single step with no dependencies should have a single output. self.run(config, step_name="strB") latest_outputs = self.TEST_DIR / "workspace" / "latest" assert len(list(latest_outputs.iterdir())) == num_other_files + 1 # Running a single step with one or more dependencies will also run the step's dependencies. 
self.run(config, step_name="concatenated") latest_outputs = self.TEST_DIR / "workspace" / "latest" assert len(list(latest_outputs.iterdir())) == num_other_files + 3 ================================================ FILE: tests/end_to_end/test_step_indexing.py ================================================ from tango.common.testing import TangoTestCase from tango.workspaces import LocalWorkspace class TestStepIndexing(TangoTestCase): def test_step_indexing(self): run_name = "run1" config = { "steps": { "list": {"type": "range_step", "start": 0, "end": 3}, "added": { "type": "add_numbers", "a_number": 2, "b_number": {"type": "ref", "ref": "list", "key": 1}, }, } } self.run(config, name=run_name) workspace = LocalWorkspace(self.TEST_DIR / "workspace") result = workspace.step_result_for_run(run_name, "added") assert result == 3 ================================================ FILE: tests/end_to_end/test_steps_that_fail.py ================================================ from collections import Counter from typing import MutableMapping import pytest from tango import Step from tango.common.exceptions import CliRunError from tango.common.testing import TangoTestCase step_execution_count: MutableMapping[str, int] = Counter() @Step.register("step_a") class StepA(Step): def run(self, what_number: int) -> int: # type: ignore global step_execution_count step_execution_count["a"] += 1 return what_number @Step.register("step_b") class StepB(Step): def run(self, what_number: int) -> int: # type: ignore global step_execution_count step_execution_count["b"] += 1 return what_number step_should_fail: bool = True @Step.register("step_fail") class StepFail(Step): def run(self, what_number: int) -> int: # type: ignore global step_execution_count step_execution_count["fail"] += 1 global step_should_fail if step_should_fail: raise RuntimeError("Step should fail") else: return what_number class TestExperiment(TangoTestCase): def test_experiment(self, caplog): global step_should_fail 
config = { "steps": { "a_number": { "type": "step_a", "what_number": 3, }, "fail_number": { "type": "step_fail", "what_number": {"type": "ref", "ref": "a_number"}, }, "b_number": { "type": "step_b", "what_number": {"type": "ref", "ref": "fail_number"}, }, } } global step_should_fail global step_execution_count step_should_fail = True with pytest.raises(CliRunError): self.run(config) assert step_execution_count["a"] == 1 assert step_execution_count["fail"] == 1 assert step_execution_count["b"] == 0 step_should_fail = False self.run(config) assert step_execution_count["a"] == 1 assert step_execution_count["fail"] == 2 assert step_execution_count["b"] == 1 ================================================ FILE: tests/end_to_end/test_uncacheable_leaf_steps.py ================================================ from tango import Step from tango.common.testing import TangoTestCase, run_experiment from tango.common.testing.steps import MakeNumber # noqa:F401 stored_number = None @Step.register("store_number_globally") class StoreNumberGlobally(Step): DETERMINISTIC = True CACHEABLE = False def run(self, number: int) -> None: # type: ignore global stored_number stored_number = number class TestExperiment(TangoTestCase): def test_experiment(self, caplog): config = { "steps": { "a_number": { "type": "make_number", "what_number": 3, }, "store_number": { "type": "store_number_globally", "number": {"type": "ref", "ref": "a_number"}, }, } } global stored_number assert stored_number is None self.run(config) assert stored_number == 3 class TestExperimentMulticore(TangoTestCase): def test_experiment(self, caplog): file_name = self.TEST_DIR / "number_file.txt" assert not file_name.exists() with run_experiment( { "steps": { "a_number": { "type": "make_number", "what_number": 3, }, "store_number": { "type": "store_number_in_file", "number": {"type": "ref", "ref": "a_number"}, "file_name": str(file_name), }, } }, multicore=True, ): with open(file_name) as file_ref: number = file_ref.read() 
assert int(number) == 3 ================================================ FILE: tests/executor_test.py ================================================ from tango.common.testing import TangoTestCase from tango.common.testing.steps import SleepPrintMaybeFail # noqa:F401 from tango.executor import Executor from tango.step import Step from tango.step_graph import StepGraph from tango.workspaces import LocalWorkspace @Step.register("sum_numbers") class AdditionStep(Step): DETERMINISTIC = True CACHEABLE = True def run(self, a: int, b: int) -> int: # type: ignore return a + b class TestExecutor(TangoTestCase): def test_executor(self): workspace = LocalWorkspace(self.TEST_DIR) step = AdditionStep(a=1, b=2) step_graph = StepGraph.from_params({"sum": {"type": "sum_numbers", "a": 1, "b": 2}}) executor = Executor(workspace) assert len(executor.workspace.step_cache) == 0 output = executor.execute_step_graph(step_graph) assert "sum" in output.successful assert len(executor.workspace.step_cache) == 1 assert executor.workspace.step_cache[step] == 3 def test_executor_with_failing_steps(self): workspace = LocalWorkspace(self.TEST_DIR) step_graph = StepGraph.from_params( { "successful_step": { "type": "sleep-print-maybe-fail", "string": "This ran perfectly.", "seconds": 0, "fail": False, }, "failing_step": { "type": "sleep-print-maybe-fail", "string": "This should fail.", "seconds": 0, "fail": True, }, "dependent_step": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "failing_step"}, "seconds": 0, "fail": False, }, } ) executor = Executor(workspace) assert len(executor.workspace.step_cache) == 0 output = executor.execute_step_graph(step_graph) assert "successful_step" in output.successful assert "failing_step" in output.failed assert "dependent_step" in output.not_run assert len(executor.workspace.step_cache) == 1 ================================================ FILE: tests/executors/__init__.py ================================================ 
================================================ FILE: tests/executors/multicore_executor_test.py ================================================ import time import pytest from tango.common.logging import initialize_logging from tango.common.testing import TangoTestCase from tango.common.testing.steps import SleepPrintMaybeFail from tango.executors.multicore_executor import MulticoreExecutor from tango.step_graph import StepGraph from tango.workspaces import LocalWorkspace class TestMulticoreExecutor(TangoTestCase): def setup_method(self): super().setup_method() initialize_logging() def test_simple_execution_in_parallel(self): step_graph = StepGraph( { "step1": SleepPrintMaybeFail(string="hello", seconds=5, fail=False), "step2": SleepPrintMaybeFail(string="hi", seconds=5, fail=False), } ) executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2) start_time = time.time() executor.execute_step_graph(step_graph) end_time = time.time() time_taken = end_time - start_time assert time_taken < 10 # TODO: will this be flaky? assert len(executor.workspace.step_cache) == 2 def test_more_processes_ready_than_parallelism(self): step_graph = StepGraph( { "step1": SleepPrintMaybeFail(string="hello", seconds=5, fail=False), "step2": SleepPrintMaybeFail(string="hi", seconds=5, fail=False), "step3": SleepPrintMaybeFail(string="howdy", seconds=5, fail=False), } ) executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2) start_time = time.time() executor.execute_step_graph(step_graph) end_time = time.time() time_taken = end_time - start_time assert 10 < time_taken < 20 # TODO: will this be flaky? 
assert len(executor.workspace.step_cache) == 3 @pytest.mark.parametrize("parallelism", [1, 2, 3]) def test_failing_step_no_downstream_task(self, parallelism): step_graph = StepGraph.from_params( { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 0, "fail": False, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 0, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": "This is going to fail!", "seconds": 0, "fail": True, }, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), parallelism=parallelism, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 2 @pytest.mark.parametrize("parallelism", [1, 2, 3]) def test_failing_step_with_downstream_task(self, parallelism): step_graph = StepGraph.from_params( { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 0, "fail": True, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 0, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": "This is going to fail!", "seconds": 0, "fail": False, }, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), parallelism=parallelism, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 1 @pytest.mark.parametrize("parallelism", [1, 2, 3]) def test_failing_step_with_further_downstream_task(self, parallelism): step_graph = StepGraph.from_params( { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 0, "fail": True, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 0, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step2"}, "seconds": 0, "fail": False, }, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), 
parallelism=parallelism, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 0 def test_uncacheable_failing_step_no_downstream_task(self): step_graph = StepGraph.from_params( { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 0, "fail": False, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 0, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": "This is going to fail!", "seconds": 0, "fail": True, "cache_results": False, }, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), parallelism=2, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 2 def test_uncacheable_failing_step_with_downstream_task(self): step_graph = StepGraph.from_params( { "step1": { "type": "sleep-print-maybe-fail", "string": "string_to_pass_down", "seconds": 0, "fail": True, "cache_results": False, }, "step2": { "type": "sleep-print-maybe-fail", "string": {"type": "ref", "ref": "step1"}, "seconds": 0, "fail": False, }, "step3": { "type": "sleep-print-maybe-fail", "string": "This is going to fail!", "seconds": 0, "fail": False, }, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), parallelism=2, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 1 @pytest.mark.parametrize("parallelism", [1, 2, 3]) def test_steps_with_their_own_multiprocessing(self, parallelism): step_graph = StepGraph.from_params( { "step1": {"type": "multiprocessing_step", "num_proc": 2}, "step2": {"type": "multiprocessing_step", "num_proc": 3}, "step3": {"type": "multiprocessing_step", "num_proc": 1}, } ) executor = MulticoreExecutor( workspace=LocalWorkspace(self.TEST_DIR), parallelism=parallelism, ) executor.execute_step_graph(step_graph) assert len(executor.workspace.step_cache) == 3 ================================================ FILE: tests/format_test.py 
================================================ from typing import Dict, Iterable, Optional import pytest from tango.common.testing import TangoTestCase from tango.format import _OPEN_FUNCTIONS, DillFormat, JsonFormat, TextFormat class TestFormat(TangoTestCase): @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys()) def test_dill_format(self, compress: Optional[str]): artifact = "Hello, World!" format = DillFormat[str](compress) format.write(artifact, self.TEST_DIR) assert format.read(self.TEST_DIR) == artifact assert "compress" in format.to_params() @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys()) def test_iterable_dill_format(self, compress: Optional[str]): r = (x + 1 for x in range(10)) format = DillFormat[Iterable[int]](compress) format.write(r, self.TEST_DIR) r2 = format.read(self.TEST_DIR) assert [x + 1 for x in range(10)] == list(r2) assert "compress" in format.to_params() @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys()) def test_json_format(self, compress: Optional[str]): artifact = {"Hello, World!": "Hi!"} format = JsonFormat[Dict[str, str]](compress) format.write(artifact, self.TEST_DIR) assert format.read(self.TEST_DIR) == artifact assert "compress" in format.to_params() @pytest.mark.parametrize("compress", _OPEN_FUNCTIONS.keys()) def test_iterable_json_format(self, compress: Optional[str]): r = (x + 1 for x in range(10)) format = JsonFormat[Iterable[int]](compress) format.write(r, self.TEST_DIR) r2 = format.read(self.TEST_DIR) assert [x + 1 for x in range(10)] == list(r2) assert "compress" in format.to_params() def test_iterable_text_format(self): numbers = ["ichi", "ni", "san"] l1 = iter(numbers) format = TextFormat() format.write(l1, self.TEST_DIR) l2 = format.read(self.TEST_DIR) assert list(l2) == numbers ================================================ FILE: tests/integrations/__init__.py ================================================ ================================================ FILE: 
tests/integrations/beaker/__init__.py ================================================ ================================================ FILE: tests/integrations/beaker/conftest.py ================================================ import os import uuid from pathlib import Path from typing import Generator import pytest from beaker import Beaker from tango.common import util from tango.integrations.beaker.common import Constants from tango.step import Step @pytest.fixture(autouse=True) def patched_cache_dir(tmp_path, monkeypatch) -> Path: monkeypatch.setattr(util, "tango_cache_dir", lambda: tmp_path) return tmp_path @pytest.fixture(autouse=True) def patched_unique_id_suffix(monkeypatch) -> str: UNIQUE_ID_SUFFIX = os.environ.get("GITHUB_SHA", "")[:6] + "-" + str(uuid.uuid1())[:6] monkeypatch.setattr(Step, "_UNIQUE_ID_SUFFIX", UNIQUE_ID_SUFFIX) return UNIQUE_ID_SUFFIX @pytest.fixture(autouse=True) def patched_constants_prefix(monkeypatch) -> str: PREFIX = os.environ.get("GITHUB_SHA", "A")[:6] + "-" + str(uuid.uuid1())[:6] + "-" monkeypatch.setattr(Constants, "STEP_ARTIFACT_PREFIX", "tango-step-" + PREFIX) monkeypatch.setattr(Constants, "RUN_ARTIFACT_PREFIX", "tango-run-" + PREFIX) monkeypatch.setattr(Constants, "ENTRYPOINT_DATASET_PREFIX", "tango-entrypoint-" + PREFIX) monkeypatch.setattr(Constants, "STEP_GRAPH_ARTIFACT_PREFIX", "tango-step-graph-" + PREFIX) monkeypatch.setattr(Constants, "STEP_EXPERIMENT_PREFIX", "tango-step-" + PREFIX) return PREFIX @pytest.fixture def beaker_workspace_name() -> str: return "ai2/tango-beaker-testing" @pytest.fixture def beaker_workspace( beaker_workspace_name: str, patched_unique_id_suffix: str, patched_constants_prefix: str ) -> Generator[str, None, None]: beaker = Beaker.from_env(default_workspace=beaker_workspace_name) yield beaker_workspace_name # Remove experiments. # for experiment in beaker.workspace.experiments(match=patched_constants_prefix): # beaker.experiment.delete(experiment) # Remove datasets. 
for dataset in beaker.workspace.datasets(match=patched_unique_id_suffix): beaker.dataset.delete(dataset) for dataset in beaker.workspace.datasets(match=patched_constants_prefix): beaker.dataset.delete(dataset) ================================================ FILE: tests/integrations/beaker/executor_test.py ================================================ import petname import pytest from beaker import DataMount from tango.common.exceptions import ConfigurationError from tango.common.testing import run_experiment from tango.executor import Executor from tango.integrations.beaker.executor import BeakerExecutor from tango.integrations.beaker.workspace import BeakerWorkspace from tango.settings import TangoGlobalSettings from tango.workspaces import default_workspace def test_from_params(beaker_workspace_name: str): executor = Executor.from_params( dict( type="beaker", beaker_workspace=beaker_workspace_name, beaker_image="ai2/conda", github_token="FAKE_TOKEN", datasets=[{"source": {"beaker": "some-dataset"}, "mount_path": "/input"}], budget="ai2/allennlp", ), workspace=BeakerWorkspace(workspace=beaker_workspace_name), clusters=["fake-cluster"], ) assert isinstance(executor, BeakerExecutor) assert executor.datasets is not None assert len(executor.datasets) == 1 assert isinstance(executor.datasets[0], DataMount) assert executor.datasets[0].source.beaker == "some-dataset" def test_init_with_mem_workspace(beaker_workspace_name: str): with pytest.raises(ConfigurationError, match="MemoryWorkspace"): BeakerExecutor( workspace=default_workspace, beaker_workspace=beaker_workspace_name, beaker_image="ai2/conda", github_token="FAKE_TOKEN", clusters=["fake-cluster"], budget="ai2/allennlp", ) @pytest.fixture def settings(beaker_workspace_name: str) -> TangoGlobalSettings: return TangoGlobalSettings( workspace={"type": "beaker", "beaker_workspace": beaker_workspace_name}, executor={ "type": "beaker", "beaker_workspace": beaker_workspace_name, "install_cmd": "pip install .[beaker]", 
"clusters": ["ai2/allennlp-cirrascale", "ai2/general-cirrascale"], "budget": "ai2/allennlp", }, ) def test_beaker_executor( settings: TangoGlobalSettings, beaker_workspace_name: str, patched_unique_id_suffix: str ): run_name = petname.generate() with run_experiment( {"steps": {"hello": {"type": "string", "result": "Hello, World!"}}}, settings=settings, workspace_url=f"beaker://{beaker_workspace_name}", name=run_name, multicore=None, ): workspace = BeakerWorkspace(workspace=beaker_workspace_name) assert "hello" in workspace.registered_run(run_name).steps ================================================ FILE: tests/integrations/beaker/step_cache_test.py ================================================ from tango.common.testing.steps import FloatStep from tango.integrations.beaker.step_cache import BeakerStepCache def test_step_cache(beaker_workspace: str): cache = BeakerStepCache(beaker_workspace=beaker_workspace) step = FloatStep(result=1.0) cache[step] = 1.0 assert step in cache assert len(cache) == 1 assert FloatStep(result=2.0) not in cache assert cache[step] == 1.0 ================================================ FILE: tests/integrations/beaker/workspace_test.py ================================================ import pytest from beaker import DatasetNotFound from tango.common.testing.steps import FloatStep from tango.integrations.beaker.workspace import BeakerWorkspace from tango.step_info import StepState from tango.workspace import Workspace def test_from_url(beaker_workspace: str): print(beaker_workspace) workspace = Workspace.from_url(f"beaker://{beaker_workspace}") assert isinstance(workspace, BeakerWorkspace) def test_direct_usage(beaker_workspace: str): workspace = BeakerWorkspace(beaker_workspace) step = FloatStep(step_name="float", result=1.0) run = workspace.register_run([step]) assert run.name in workspace.registered_runs() assert workspace.step_info(step).state == StepState.INCOMPLETE workspace.step_starting(step) assert 
workspace.step_info(step).state == StepState.RUNNING workspace.step_finished(step, 1.0) assert workspace.step_info(step).state == StepState.COMPLETED assert workspace.step_result_for_run(run.name, "float") == 1.0 def test_remove_step(beaker_workspace: str): beaker_workspace = "ai2/tango_remove_cache_test" workspace = BeakerWorkspace(beaker_workspace) step = FloatStep(step_name="float", result=1.0) workspace.step_starting(step) workspace.step_finished(step, 1.0) step_info = workspace.step_info(step) dataset_name = workspace.Constants.step_artifact_name(step_info) cache = workspace.step_cache assert workspace.beaker.dataset.get(dataset_name) is not None assert step in cache workspace.remove_step(step.unique_id) cache = workspace.step_cache dataset_name = workspace.Constants.step_artifact_name(step_info) with pytest.raises(DatasetNotFound): workspace.beaker.dataset.get(dataset_name) assert step not in cache ================================================ FILE: tests/integrations/datasets/__init__.py ================================================ ================================================ FILE: tests/integrations/datasets/dataset_test.py ================================================ import datasets from tango.common.sequences import MappedSequence from tango.common.testing import TangoTestCase from tango.integrations.datasets import ( DatasetRemixStep, DatasetsFormat, LoadDataset, convert_to_tango_dataset_dict, ) from tango.step import Step class TestDatasets(TangoTestCase): def test_from_params_and_convert_to_tango_dataset_dict(self): step: LoadDataset = Step.from_params( # type: ignore[assignment] { "type": "datasets::load", "path": "lhoestq/test", "cache_dir": str(self.TEST_DIR / "cache"), } ) hf_dataset_dict = step.result() assert "train" in hf_dataset_dict dataset_dict = convert_to_tango_dataset_dict(hf_dataset_dict) assert "train" in dataset_dict.splits def test_convert_to_tango_iterable_dataset_dict(self): def data_gen(): for x in range(100): yield 
{"x": x} hf_dataset_dict = datasets.IterableDatasetDict( train=datasets.iterable_dataset.IterableDataset.from_generator(data_gen) ) dataset_dict1 = convert_to_tango_dataset_dict(hf_dataset_dict) assert "train" in dataset_dict1.splits def test_load_concatenate_and_interleave(self): result_dir = self.run( self.FIXTURES_ROOT / "integrations" / "datasets" / "config.json", overrides={ "steps.train_data.cache_dir": str(self.TEST_DIR / "cache"), "steps.dev_data.cache_dir": str(self.TEST_DIR / "cache"), }, ) assert (result_dir / "train_data" / "data").is_dir() dataset = DatasetsFormat().read(result_dir / "train_data") assert len(dataset) == 2 def test_mapped_sequence_of_dataset(): ds = datasets.load_dataset("piqa", split="validation") mapped_ds = MappedSequence(lambda x: x["goal"], ds) # type: ignore[arg-type] assert len(ds) == len(mapped_ds) # type: ignore[arg-type] assert ds[0]["goal"] == mapped_ds[0] # type: ignore[index] assert ds[0]["goal"] == mapped_ds[:10][0] # type: ignore[index] def test_datasets_dataset_remix(): dataset_dict = datasets.load_dataset("lhoestq/test") step = DatasetRemixStep() result = step.run( input=dataset_dict, # type: ignore[arg-type] new_splits={ "all": "train + validation", "crossval_train": "train[:1] + validation[1:]", "crossval_test": "train[1:] + validation[:1]", }, ) assert len(result["all"]) == len(dataset_dict["train"]) + len(dataset_dict["validation"]) # type: ignore assert len(result["crossval_train"]) == 3 assert len(result["crossval_test"]) == 2 ================================================ FILE: tests/integrations/fairscale/__init__.py ================================================ ================================================ FILE: tests/integrations/fairscale/train_test.py ================================================ from typing import Any, Dict import pytest import torch from tango.common.logging import initialize_logging, teardown_logging from tango.common.testing import TangoTestCase class 
TestFairScaleTrain(TangoTestCase): def setup_method(self): super().setup_method() initialize_logging(log_level="info") def teardown_method(self): teardown_logging() @pytest.mark.parametrize( "fsdp", ( pytest.param( True, id="fsdp=True", marks=[ pytest.mark.gpu, pytest.mark.skipif( torch.cuda.device_count() < 2, reason="Requires CUDA devices" ), ], ), pytest.param(False, id="fsdp=False"), ), ) @pytest.mark.parametrize( "activation_checkpoint", ( pytest.param(True, id="checkpointing=True"), pytest.param(False, id="checkpointing=False"), ), ) @pytest.mark.parametrize( "amp", ( pytest.param( True, id="amp=True", marks=[ pytest.mark.gpu, pytest.mark.skipif( torch.cuda.device_count() < 2, reason="Requires CUDA devices" ), ], ), pytest.param(False, id="amp=False"), ), ) def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: bool): overrides: Dict[str, Any] = { "steps.trained_model.model.activation_checkpointing": activation_checkpoint, } training_engine: Dict[str, Any] = { "amp": amp, "optimizer": { "type": "torch::AdamW", "lr": 0.005, "betas": [0.9, 0.95], "eps": 1e-6, }, } if fsdp: training_engine["type"] = "fairscale" fsdp_config = {"reshard_after_forward": True, "mixed_precision": amp} training_engine["fsdp_config"] = fsdp_config overrides["steps.trained_model.model.fsdp_config"] = fsdp_config else: training_engine["type"] = "torch" overrides["steps.trained_model.model.fsdp_config"] = None overrides["steps.trained_model.training_engine"] = training_engine run_dir = self.run( self.FIXTURES_ROOT / "integrations" / "fairscale" / "config.jsonnet", include_package=["test_fixtures.integrations.fairscale.components"], overrides=overrides, ) assert (run_dir / "trained_model").is_dir() ================================================ FILE: tests/integrations/flax/__init__.py ================================================ ================================================ FILE: tests/integrations/flax/data_test.py 
================================================ from typing import Dict from transformers import AutoTokenizer from tango.common.testing import TangoTestCase from tango.integrations.flax import DataLoader, FlaxDataLoader from tango.integrations.flax.util import get_PRNGkey from tango.step import Step class TestDataStep(TangoTestCase): def test_dataloader(self) -> None: assert "flax::dataloader" in DataLoader.list_available() def test_sample_data(self) -> None: step = Step.from_params( # type: ignore[assignment] { "type": "datasets::load", "path": "lhoestq/demo1", "split": "train", "cache_dir": str(self.TEST_DIR / "cache"), } ) dataset = step.result() tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") column_names = dataset.column_names dataset = dataset.map( lambda e: tokenizer(e["review"], truncation=True, padding="max_length") ) dataset = dataset.remove_columns(column_names) data = FlaxDataLoader(dataset, batch_size=16) rng = get_PRNGkey() for batch in data(rng, do_distributed=False): assert isinstance(batch, Dict) ================================================ FILE: tests/integrations/flax/format_test.py ================================================ import os from tango import Format from tango.common.testing import TangoTestCase from tango.integrations.flax.format import FlaxFormat class TestTorchFormat(TangoTestCase): def test_read_write(self): flax_format: FlaxFormat = Format.by_name("flax")() # type: ignore[assignment] flax_format.write({"a": 1}, self.TEST_DIR) assert os.path.exists(self.TEST_DIR / "checkpoint_0") data = flax_format.read(self.TEST_DIR) assert data == {"a": 1} ================================================ FILE: tests/integrations/flax/optim_test.py ================================================ from tango.integrations.flax.optim import LRScheduler, Optimizer def test_all_optimizers_registered(): assert "optax::adafactor" in Optimizer.list_available() def test_all_lr_schedulers_registered(): assert 
"optax::constant_schedule" in LRScheduler.list_available() ================================================ FILE: tests/integrations/flax/train_test.py ================================================ from tango.common.logging import initialize_logging, teardown_logging from tango.common.testing import TangoTestCase class TestTrainStep(TangoTestCase): def setup_method(self): super().setup_method() initialize_logging(enable_cli_logs=True) def teardown_method(self): super().teardown_method() teardown_logging() def test_trainer(self): result_dir = self.run( self.FIXTURES_ROOT / "integrations" / "flax" / "config.jsonnet", include_package=[ "test_fixtures.integrations.common", "test_fixtures.integrations.flax", ], ) assert ( result_dir / "train" / "work" / "checkpoint_state_latest" / "checkpoint_0" / "checkpoint" ).is_file() ================================================ FILE: tests/integrations/gs/__init__.py ================================================ ================================================ FILE: tests/integrations/gs/step_cache_test.py ================================================ import os import pytest from tango.common.testing import TangoTestCase from tango.common.testing.steps import FloatStep from tango.integrations.gs.common import empty_bucket_folder from tango.integrations.gs.step_cache import GSStepCache GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket") GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1" class TestGSStepCache(TangoTestCase): def setup_method(self): super().setup_method() empty_bucket_folder(GS_BUCKET_NAME) empty_bucket_folder(GS_SUBFOLDER) def teardown_method(self): super().teardown_method() @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) def test_step_cache(self, gs_path): cache = GSStepCache(folder_name=gs_path) step = FloatStep(result=1.0) cache[step] = 1.0 assert step in cache assert len(cache) == 1 assert FloatStep(result=2.0) not in cache assert cache[step] == 1.0 
================================================ FILE: tests/integrations/gs/workspace_test.py ================================================ import os import pytest from tango.common.testing import TangoTestCase from tango.common.testing.steps import FloatStep from tango.integrations.gs.common import empty_bucket_folder, empty_datastore from tango.integrations.gs.workspace import GSWorkspace from tango.step_info import StepState from tango.workspace import Workspace GS_BUCKET_NAME = os.environ.get("GS_BUCKET_NAME", "allennlp-tango-bucket") GS_SUBFOLDER = f"{GS_BUCKET_NAME}/my-workspaces/workspace1" class TestGSWorkspace(TangoTestCase): def setup_method(self): super().setup_method() empty_bucket_folder(GS_BUCKET_NAME) empty_bucket_folder(GS_SUBFOLDER) empty_datastore(GS_BUCKET_NAME) empty_datastore(GS_SUBFOLDER) def teardown_method(self): super().teardown_method() @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) def test_from_url(self, gs_path: str): workspace = Workspace.from_url(f"gs://{gs_path}") assert isinstance(workspace, GSWorkspace) @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) def test_from_params(self, gs_path: str): workspace = Workspace.from_params({"type": "gs", "workspace": gs_path}) assert isinstance(workspace, GSWorkspace) @pytest.mark.parametrize("gs_path", [GS_BUCKET_NAME, GS_SUBFOLDER]) def test_direct_usage(self, gs_path: str): workspace = GSWorkspace(gs_path) step = FloatStep(step_name="float", result=1.0) run = workspace.register_run([step]) assert run.name in workspace.registered_runs() assert workspace.step_info(step).state == StepState.INCOMPLETE workspace.step_starting(step) assert workspace.step_info(step).state == StepState.RUNNING workspace.step_finished(step, 1.0) assert workspace.step_info(step).state == StepState.COMPLETED assert workspace.step_result_for_run(run.name, "float") == 1.0 def test_remove_step(self): workspace = GSWorkspace(GS_BUCKET_NAME) step = FloatStep(step_name="float", 
result=1.0) step_info = workspace.step_info(step) workspace.step_starting(step) workspace.step_finished(step, 1.0) bucket_artifact = workspace.Constants.step_artifact_name(step_info) ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id)) cache = workspace.step_cache assert workspace.client.artifacts(prefix=bucket_artifact) is not None assert ds_entity is not None assert step in cache workspace.remove_step(step.unique_id) cache = workspace.step_cache ds_entity = workspace._ds.get(key=workspace._ds.key("stepinfo", step_info.unique_id)) with pytest.raises(Exception) as excinfo: workspace.client.artifacts(prefix=bucket_artifact) assert "KeyError" in str(excinfo) assert ds_entity is None assert step not in cache ================================================ FILE: tests/integrations/torch/__init__.py ================================================ ================================================ FILE: tests/integrations/torch/data_test.py ================================================ import torch from tango.integrations.torch.data import DataLoader, Sampler def test_dataloader_from_params(): DataLoader.from_params( { "dataset": list(range(10)), "batch_size": 2, "shuffle": True, } ) def test_samplers_registered(): assert "torch::SequentialSampler" in Sampler.list_available() def test_dataloader_from_params_with_sampler(): dataloader = DataLoader.from_params( { "dataset": list(range(10)), "sampler": { "type": "torch::RandomSampler", "replacement": True, }, } ) assert isinstance(dataloader.sampler, torch.utils.data.RandomSampler) assert dataloader.sampler.replacement def test_dataloader_from_params_with_batch_sampler(): dataloader = DataLoader.from_params( { "dataset": list(range(10)), "sampler": { "type": "torch::BatchSampler", "sampler": { "type": "torch::RandomSampler", }, "batch_size": 2, "drop_last": True, }, } ) assert isinstance(dataloader.sampler, torch.utils.data.BatchSampler) ================================================ 
FILE: tests/integrations/torch/det_hash_test.py ================================================ import numpy import torch from tango.common import det_hash def test_numpy_det_hash(): a1 = numpy.array([[1, 2], [3, 4]], order="C") a2 = numpy.array([[1, 2], [3, 4]], order="K") assert det_hash(a1) == det_hash(a2) def test_torch_det_hash(): a1 = numpy.array([[1, 2], [3, 4]], order="C") a2 = numpy.array([[1, 2], [3, 4]], order="K") a1 = torch.tensor(a1) a2 = torch.tensor(a2) assert det_hash(a1) == det_hash(a2) ================================================ FILE: tests/integrations/torch/eval_test.py ================================================ from tango.common.testing import TangoTestCase class TestEvalStep(TangoTestCase): def test_basic_eval(self): result_dir = self.run( self.FIXTURES_ROOT / "integrations/torch/eval.jsonnet", include_package=[ "test_fixtures.integrations.common", "test_fixtures.integrations.torch", ], ) assert (result_dir / "eval" / "data.json").is_file() ================================================ FILE: tests/integrations/torch/format_test.py ================================================ import os from tango import Format from tango.common.testing import TangoTestCase from tango.integrations.torch.format import TorchFormat class TestTorchFormat(TangoTestCase): def test_read_write(self): torch_format: TorchFormat = Format.by_name("torch")() # type: ignore[assignment] torch_format.write({"a": 1}, self.TEST_DIR) assert os.path.exists(self.TEST_DIR / "data.pt") data = torch_format.read(self.TEST_DIR) assert data == {"a": 1} ================================================ FILE: tests/integrations/torch/optim_test.py ================================================ from tango.integrations.torch.optim import LRScheduler, Optimizer def test_all_optimizers_registered(): assert "torch::Adagrad" in Optimizer.list_available() def test_all_lr_schedulers_registered(): assert "torch::ExponentialLR" in LRScheduler.list_available() 
@pytest.mark.parametrize("with_validation", [True, False])
def test_basic_train(self, with_validation: bool):
    """Run the ``train.jsonnet`` fixture and check the files it leaves behind.

    When ``with_validation`` is false, the validation split and the
    ``validate_every`` setting are overridden to ``None``.
    """
    if with_validation:
        overrides = ""
    else:
        overrides = json.dumps(
            {"steps.train.validation_split": None, "steps.train.validate_every": None}
        )

    config_path = self.FIXTURES_ROOT / "integrations" / "torch" / "train.jsonnet"
    result_dir = self.run(
        config_path,
        include_package=[
            "test_fixtures.integrations.common",
            "test_fixtures.integrations.torch",
        ],
        overrides=overrides,
    )

    step_dir = result_dir / "train"
    work_dir = step_dir / "work"
    expected_files = (
        step_dir / "data.pt",
        work_dir / "weights.pt",
        work_dir / "checkpoint_state_latest" / "worker0_model.pt",
        work_dir / "checkpoint_state_best" / "worker0_optimizer.pt",
        work_dir / "checkpoint_state_best" / "worker0_trainer.pt",
    )
    for expected in expected_files:
        assert expected.is_file()
def test_train_distributed(self):
    """Run the ``train_dist.jsonnet`` fixture and check per-worker checkpoints.

    Besides the step result and the final weights, both the "latest" and the
    "best" checkpoint directories must contain a model file for worker 0 and
    for worker 1.
    """
    result_dir = self.run(
        self.FIXTURES_ROOT / "integrations" / "torch" / "train_dist.jsonnet",
        include_package=[
            "test_fixtures.integrations.common",
            "test_fixtures.integrations.torch",
        ],
    )

    step_dir = result_dir / "train"
    work_dir = step_dir / "work"

    assert (step_dir / "data.pt").is_file()
    assert (work_dir / "weights.pt").is_file()
    for worker in ("worker0", "worker1"):
        for checkpoint in ("checkpoint_state_latest", "checkpoint_state_best"):
            assert (work_dir / checkpoint / f"{worker}_model.pt").is_file()
class WorseningModel(Model):
    """Toy model whose reported loss grows with wall-clock time.

    The loss returned by ``forward`` is the MSE of a linear layer plus the
    number of seconds elapsed since the model was constructed, so later calls
    report a larger loss than earlier ones.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(7, 1)
        self.loss = MSELoss()
        # Reference point for the time-based penalty added in ``forward``.
        self.start_time = time.time()

    def forward(self, x, y):
        """Return ``{"loss": mse + seconds_since_construction}``."""
        prediction = self.linear(x)
        # Sleep briefly so that elapsed time visibly increases between calls.
        time.sleep(0.01)
        base_loss = self.loss(prediction, y)
        elapsed = time.time() - self.start_time
        return {"loss": base_loss + elapsed}
def test_with_increasing_loss():
    """Train ``WorseningModel`` with a callback that raises ``StopEarly`` at step 9.

    The test passes if ``step.result()`` completes without an error.
    """
    model = WorseningModel()

    # 100 random 7-dimensional inputs; the target is the input shifted by 0.1.
    examples = []
    for _ in range(100):
        features = torch.randn(7)
        examples.append({"x": features, "y": features + 0.1})
    dataset = DatasetDict(splits={"train": examples, "validation": examples}, metadata={})

    engine = Lazy(TorchTrainingEngine, optimizer=Lazy(torch.optim.AdamW, lr=1e-5))
    stop_callback = Lazy(StopOnStepCallback, stop_on_step=9)
    train_step = TorchTrainStep(
        model=model,
        training_engine=engine,
        dataset_dict=dataset,
        train_dataloader=Lazy(DataLoader),
        train_steps=10,
        validation_steps=10,
        train_split="train",
        validation_split="validation",
        callbacks=[stop_callback],
    )

    train_step.result()
class TestTokenizeText2TextData(TangoTestCase):
    """Tests for the ``TokenizeText2TextData`` step on a two-row toy dataset."""

    @staticmethod
    def _toy_data() -> DatasetDict:
        """Return a ``DatasetDict`` with one "train" split of two examples."""
        rows = {"field1": ["hello", "hi"], "field2": ["world", "me"], "meta_field": [1, 0]}
        return DatasetDict({"train": Dataset.from_dict(rows)})

    def test_tokenize_seq2seq(self):
        """Source and target are tokenized into ``input_ids`` and ``labels``."""
        data_dict = self._toy_data()
        tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")

        tokenized = TokenizeText2TextData().run(
            data=data_dict, tokenizer=tokenizer, source_field="field1", target_field="field2"
        )

        assert isinstance(tokenized, DatasetDict)
        train_split = tokenized["train"]
        assert len(train_split) == 2
        assert "input_ids" in train_split.column_names
        assert "labels" in train_split.column_names
        assert train_split[0]["input_ids"] == [21820, 1]

    def test_tokenize_concat(self):
        """With ``concat_source_target`` the source part of ``labels`` is -100."""
        data_dict = self._toy_data()
        tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")

        tokenized = TokenizeText2TextData().run(
            data=data_dict,
            tokenizer=tokenizer,
            source_field="field1",
            target_field="field2",
            concat_source_target=True,
        )

        assert isinstance(tokenized, DatasetDict)
        train_split = tokenized["train"]
        assert len(train_split) == 2
        assert "input_ids" in train_split.column_names
        assert "labels" in train_split.column_names
        first_example = train_split[0]
        assert first_example["input_ids"] == [31373, 50257, 6894, 50256]
        assert first_example["labels"] == [-100, -100, 6894, 50256]
def test_run_generation(self):
    """A ``transformers::run_generation`` step yields one result per prompt.

    The model is given by name here rather than as a nested model config.
    """
    prompts = ["Tango is the future of", "Everybody should be using Tango to"]
    step_config = {
        "type": "transformers::run_generation",
        "prompts": prompts,
        "model": "sshleifer/tiny-gpt2",
    }
    generation_step = Step.from_params(step_config)  # type: ignore[assignment]

    outputs = list(generation_step.result())
    assert len(outputs) == 2
def test_soft_prompt_twice():
    """Adding a second soft prompt to an already-prompted model changes its output.

    GPT-2 is prompted first with length 2 and then again with length 5; the
    text generated after each call must differ.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

    decoded_outputs = []
    for prompt_length in (2, 5):
        add_soft_prompt(model, prompt_length=prompt_length)
        model.eval()
        token_ids = model.generate(
            tokenizer.encode("It was the best of times.", return_tensors="pt")
        )
        decoded_outputs.append(tokenizer.decode(token_ids[0]))

    after_first_prompt, after_second_prompt = decoded_outputs
    assert after_first_prompt != after_second_prompt
@pytest.mark.parametrize(
    "protocol",
    # range(5) covers protocols 0-4, including protocol 4 (the default pickle
    # protocol since Python 3.8). Protocol 5 is listed separately because it
    # needs a version guard.
    [pytest.param(protocol, id=f"protocol={protocol}") for protocol in range(5)]
    + [
        pytest.param(
            5,
            id="protocol=5",
            marks=pytest.mark.skipif(
                sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
            ),
        ),
    ],
)
def test_pickling(protocol: int):
    """Round-trip a ``WandbStepCache`` through pickle with the given protocol.

    The test only checks that dumping and loading do not raise.
    """
    step_cache = WandbStepCache(project=WANDB_PROJECT, entity=WANDB_ENTITY)
    pickle.loads(pickle.dumps(step_cache, protocol=protocol))
@pytest.mark.parametrize(
    "protocol",
    # range(5) covers protocols 0-4, including protocol 4 (the default pickle
    # protocol since Python 3.8). Protocol 5 is listed separately because it
    # needs a version guard.
    [pytest.param(protocol, id=f"protocol={protocol}") for protocol in range(5)]
    + [
        pytest.param(
            5,
            id="protocol=5",
            marks=pytest.mark.skipif(
                sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
            ),
        ),
    ],
)
def test_pickle_workspace(self, protocol):
    """Round-trip a ``WandbWorkspace`` through pickle with the given protocol.

    The unpickled workspace must have a W&B client and the same project,
    entity and steps directory as the original.
    """
    workspace = WandbWorkspace(project=WANDB_PROJECT, entity=WANDB_ENTITY)
    unpickled_workspace = pickle.loads(pickle.dumps(workspace, protocol=protocol))
    assert unpickled_workspace.wandb_client is not None
    assert unpickled_workspace.project == workspace.project
    assert unpickled_workspace.entity == workspace.entity
    assert unpickled_workspace.steps_dir == workspace.steps_dir
for wandb_run in self.workspace.wandb_client.runs( f"{WANDB_ENTITY}/{WANDB_PROJECT}", ): if ( self.UNIQUE_ID_SUFFIX in wandb_run.id or self.UNIQUE_ID_SUFFIX in wandb_run.config.get("_run_suite_id", "") ): wandb_run.delete(delete_artifacts=True) teardown_logging() def test_direct_usage(self): params = Params.from_file(self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet") step_graph = StepGraph.from_params(params.pop("steps", keep_as_dict=True)) tango_run = self.workspace.register_run(step for step in step_graph.values()) # Test 'registered_run()' and 'registered_runs()' methods. assert self.workspace.registered_run(tango_run.name) == tango_run assert self.workspace.registered_runs()[tango_run.name] == tango_run hello_step = step_graph["hello"] hello_world_step = step_graph["hello_world"] # Test getting step info. step_info = self.workspace.step_info(hello_step) assert step_info.unique_id.endswith(self.UNIQUE_ID_SUFFIX) assert step_info.step_name == "hello" assert step_info.state == StepState.INCOMPLETE # Mark the "hello" step as starting. self.workspace.step_starting(hello_step) assert self.workspace.step_info(hello_step).state == StepState.RUNNING # Mark the "hello" step as finished. self.workspace.step_finished(hello_step, "hello") assert self.workspace.step_info(hello_step).state == StepState.COMPLETED # Make sure the result is in the cache, exists locally, and on W&B. cache = self.workspace.cache assert hello_step in cache assert cache.step_dir(hello_step).is_dir() assert cache.get_step_result_artifact(hello_step) is not None # Now make sure we can fetch the item from the cache, even if it's not in memory # or in the cache directory. 
@pytest.mark.parametrize(
    "multicore", [pytest.param(True, id="multicore"), pytest.param(False, id="single-core")]
)
@pytest.mark.parametrize(
    "distributed",
    [
        pytest.param(True, id="distributed"),
        pytest.param(False, id="single-device"),
    ],
)
def test_with_wandb_train_callback(self, multicore: bool, distributed: bool):
    """Run a torch training fixture against a W&B workspace with the ``wandb::log`` callback.

    ``distributed`` selects ``train_dist.jsonnet`` over ``train.jsonnet``;
    ``multicore`` is forwarded to ``self.run``. The test passes if the run
    completes without raising.
    """
    config_name = "train.jsonnet" if not distributed else "train_dist.jsonnet"
    self.run(
        self.FIXTURES_ROOT / "integrations" / "torch" / config_name,
        include_package=[
            "test_fixtures.integrations.common",
            "test_fixtures.integrations.torch",
        ],
        overrides=json.dumps({"steps.train.callbacks": [{"type": "wandb::log"}]}),
        workspace_url=f"wandb://{WANDB_ENTITY}/{WANDB_PROJECT}",
        multicore=multicore,
    )
def check_logs(
    self,
    run_dir: Path,
    process_result: subprocess.CompletedProcess,
    file_friendly_logging: bool = False,
) -> Tuple[List[str], List[str]]:
    """Check that what the process printed also ended up in the run's log file.

    :param run_dir: Run directory; it must contain an ``out.log`` file.
    :param process_result: Completed process whose ``stdout`` was captured as bytes.
    :param file_friendly_logging: Forwarded to ``clean_log_lines`` for the stdout lines.
    :returns: ``(log_lines, cleaned_log_lines)`` - the raw lines of ``out.log`` and
        the same lines after ``clean_log_lines``.
    :raises AssertionError: If ``out.log`` is missing or a checked stdout line is
        not present in the cleaned log lines.
    :raises StopIteration: If no stdout line contains ``"Starting new run"``.
    """
    # Carriage returns (e.g. from progress bars) are treated as line breaks.
    stdout_lines = process_result.stdout.decode().replace("\r", "\n").split("\n")
    cleaned_stdout_lines = self.clean_log_lines(stdout_lines, file_friendly_logging)

    log_file = run_dir / "out.log"
    assert log_file.is_file()

    # Read through a context manager so the file handle is closed right away
    # instead of being left for the garbage collector.
    with open(log_file) as log_handle:
        log_lines = log_handle.read().split("\n")
    cleaned_log_lines = self.clean_log_lines(log_lines)

    # NOTE(review): the start index is located in the raw ``stdout_lines`` but is
    # applied to ``cleaned_stdout_lines``, which can be shorter because
    # ``clean_log_lines`` drops empty lines. Behavior is kept as-is; confirm
    # whether the index should be computed on the cleaned list instead.
    start = next(i for i, line in enumerate(stdout_lines) if "Starting new run" in line)
    for line in cleaned_stdout_lines[start:]:
        assert line in cleaned_log_lines

    return log_lines, cleaned_log_lines
def test_deterministic_experiment(self):
    """Run the hello-world experiment twice against the same workspace directory.

    The first run must populate the cache with two entries and create a run
    directory with the expected step folders and logs. The second run must add
    a second run directory without adding anything to the cache.
    """
    experiment = self.FIXTURES_ROOT / "experiment" / "hello_world.jsonnet"
    command = ["tango", "run", str(experiment), "-w", str(self.TEST_DIR)]
    cache_dir = self.TEST_DIR / "cache"
    runs_dir = self.TEST_DIR / "runs"

    first_run = subprocess.run(command, capture_output=True)
    assert first_run.returncode == 0
    assert len(os.listdir(cache_dir)) == 2

    run_dir = next(runs_dir.iterdir())
    hello_dir = run_dir / "hello"
    assert hello_dir.is_dir()
    assert (hello_dir / "cache-metadata.json").is_file()
    assert (run_dir / "hello_world").is_dir()

    # Everything printed to stdout should also be in the run's log file.
    self.check_logs(run_dir, first_run)

    # A second run must reuse the cache rather than add new entries to it.
    second_run = subprocess.run(command)
    assert second_run.returncode == 0
    assert len(os.listdir(cache_dir)) == 2

    # Both runs should now be registered.
    assert len(os.listdir(runs_dir)) == 2
def setup_method(self):
    """Switch into the test directory and create a fresh ``tango.yml`` there.

    The previous working directory is remembered in ``self._wd_backup`` so
    that ``teardown_method`` can restore it.
    """
    super().setup_method()
    self._wd_backup = os.getcwd()
    os.chdir(self.TEST_DIR)
    init_command = ["tango", "settings", "init", "-p", "./tango.yml"]
    subprocess.run(init_command, check=True)
class TestLocalStepCache(TangoTestCase):
    """Pickling behavior of ``LocalStepCache``."""

    @pytest.mark.parametrize(
        "protocol",
        # range(5) covers protocols 0-4, including protocol 4 (the default pickle
        # protocol since Python 3.8). Protocol 5 is listed separately because it
        # needs a version guard.
        [pytest.param(protocol, id=f"protocol={protocol}") for protocol in range(5)]
        + [
            pytest.param(
                5,
                id="protocol=5",
                marks=pytest.mark.skipif(
                    sys.version_info < (3, 8), reason="Protocol 5 requires Python 3.8 or newer"
                ),
            ),
        ],
    )
    def test_pickling(self, protocol: int):
        """A pickled-and-restored cache still contains the step, but not in ``strong_cache``.

        After storing a result, the step's unique ID is in the original cache's
        ``strong_cache``. After a pickle round trip it must be absent from
        ``strong_cache`` while ``step in cache`` still holds.
        """
        step = DummyStep(step_name="dummy", x=1)
        step_cache = LocalStepCache(self.TEST_DIR)
        step_cache[step] = 1
        assert step in step_cache
        assert step.unique_id in step_cache.strong_cache

        pickled_step_cache = pickle.dumps(step_cache, protocol=protocol)
        unpickled_step_cache = pickle.loads(pickled_step_cache)
        assert step.unique_id not in unpickled_step_cache.strong_cache
        assert step in unpickled_step_cache
def test_ordered_steps(self):
    """``StepGraph.ordered_steps`` returns the parsed steps as B, C, A.

    ``stepC`` refers to ``stepB``; ``stepA`` has no references and is declared last.
    """
    graph_config = {
        "stepB": {
            "type": "add_numbers",
            "a_number": 2,
            "b_number": 3,
        },
        "stepC": {
            "type": "add_numbers",
            "a_number": {"type": "ref", "ref": "stepB"},
            "b_number": 5,
        },
        "stepA": {
            "type": "add_numbers",
            "a_number": 3,
            "b_number": 1,
        },
    }
    step_graph = StepGraph.from_params(graph_config)

    ordered = StepGraph.ordered_steps(step_graph.parsed_steps)
    ordered_names = [parsed_step.name for parsed_step in ordered]
    assert ordered_names == ["stepB", "stepC", "stepA"]
AddNumbersStep(a_number=3, b_number=2, step_name="stepA", cache_results=False) step_b = AddNumbersStep( a_number=step_a, b_number=2, step_name="stepB", step_format=JsonFormat("gz") ) step_graph = StepGraph({"stepA": step_a, "stepB": step_b}) with NamedTemporaryFile( prefix="test-step-graph-to-file-without-config", suffix=".jsonnet", dir=self.TEST_DIR ) as file_ref: step_graph.to_file(file_ref.name) new_step_graph = StepGraph.from_file(file_ref.name) assert step_graph == new_step_graph def test_with_step_indexer(self): config = { "list": {"type": "range_step", "start": 0, "end": 3}, "added": { "type": "add_numbers", "a_number": 2, "b_number": {"type": "ref", "ref": "list", "key": 1}, }, } step_graph = StepGraph.from_params(deepcopy(config)) # type: ignore[arg-type] assert [s.name for s in step_graph["added"].dependencies] == ["list"] assert step_graph.to_config() == config def test_with_forced_dependencies(self): config = { "some_string": { "type": "string", "result": "I should run second", "step_extra_dependencies": [{"type": "ref", "ref": "other_string"}], }, "other_string": {"type": "string", "result": "I should run first"}, "added": { "type": "concat_strings", "string1": "Some string:", "string2": {"type": "ref", "ref": "some_string"}, }, } step_graph = StepGraph.from_params(deepcopy(config)) # type: ignore[arg-type] assert step_graph["some_string"].dependencies == {step_graph["other_string"]} assert step_graph["added"].recursive_dependencies == { step_graph["other_string"], step_graph["some_string"], } ================================================ FILE: tests/step_info_test.py ================================================ import json from pathlib import Path from typing import Any from tango.common.testing.steps import FloatStep from tango.step import Step from tango.step_graph import StepGraph from tango.step_info import StepInfo def test_step_info(): step = FloatStep(step_name="float", result=1.0) step_info = StepInfo.new_from_step(step) # Check Git 
metadata. if (Path.cwd() / ".git").exists(): assert step_info.environment.git is not None assert step_info.environment.git.commit is not None assert step_info.environment.git.remote is not None assert "allenai/tango" in step_info.environment.git.remote # Check pip requirements. assert step_info.environment.packages is not None # Test serialization / deserialization. serialized = json.dumps(step_info.to_json_dict()) deserialized = StepInfo.from_json_dict(json.loads(serialized)) assert deserialized == step_info def test_step_info_with_step_dependency(): """Checks that the StepInfo config is not parsed to a Step if it has dependencies on upstream steps""" @Step.register("foo", exist_ok=True) class FooStep(Step): def run(self, bar: Any) -> str: # type: ignore return "foo" + bar @Step.register("bar", exist_ok=True) class BarStep(Step): def run(self) -> str: # type: ignore return "Hey!" graph = StepGraph.from_params( { "foo": { "type": "foo", "bar": {"type": "ref", "ref": "bar"}, }, "bar": { "type": "bar", }, } ) step = graph["foo"] step_info = StepInfo.new_from_step(step) step_info_json = json.dumps(step_info.to_json_dict()) step_info = StepInfo.from_json_dict(json.loads(step_info_json)) assert isinstance(step_info.config, dict) ================================================ FILE: tests/step_test.py ================================================ import collections from typing import Any, Dict, Mapping, MutableMapping import pytest from tango import StepGraph from tango.common import Params, Registrable from tango.common.exceptions import ConfigurationError from tango.common.from_params import FromParams from tango.common.testing import TangoTestCase from tango.step import FunctionalStep, Step, step from tango.workspaces import MemoryWorkspace class TestStep(TangoTestCase): def test_from_params(self): step = Step.from_params({"type": "float", "result": 3}) result = step.result() assert result == 3 def test_from_params_wrong_type(self): with pytest.raises(TypeError): 
Step.from_params({"type": "float", "result": "not a float"}) def test_step_with_from_params_input(self): class Bar(FromParams): def __init__(self, x: int): self.x = x @Step.register("foo", exist_ok=True) class FooStep(Step): def run(self, bar: Bar) -> Bar: # type: ignore return bar step = Step.from_params({"type": "foo", "bar": {"x": 1}}) assert step.result().x == 1 def test_no_hash_arguments(self): @Step.register("no_hash_step") class SkipArgStep(Step): SKIP_ID_ARGUMENTS = {"arg"} def run(self, arg: str) -> int: # type: ignore return 5 step1 = SkipArgStep(arg="foo") step2 = SkipArgStep(arg="bar") assert step1.unique_id == step2.unique_id def test_skip_default_arguments(self): class SkipArgStep(Step): def run(self) -> int: # type: ignore return 5 old_hash = SkipArgStep().unique_id class SkipArgStep(Step): SKIP_DEFAULT_ARGUMENTS = {"arg": 5} def run(self, arg: int = 5) -> int: # type: ignore return arg assert SkipArgStep().unique_id == old_hash assert SkipArgStep(arg=5).unique_id == old_hash assert SkipArgStep(arg=6).unique_id != old_hash def test_massage_kwargs(self): class CountLettersStep(Step): @classmethod def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: kwargs = kwargs.copy() kwargs["text"] = kwargs["text"].lower() return kwargs def run(self, text: str) -> Mapping[str, int]: # type: ignore text = text.lower() counter: MutableMapping[str, int] = collections.Counter() for c in text: counter[c] += 1 return counter upper = CountLettersStep(text="FOO") lower = CountLettersStep(text="foo") assert upper.unique_id == lower.unique_id assert upper.result() == lower.result() def test_default_args(self): class DefaultArgStep(Step[int]): def run(self, left: int, right: int = 0) -> int: # type: ignore return left + right explicit = DefaultArgStep(left=1, right=0) implicit = DefaultArgStep(left=1) assert explicit.unique_id == implicit.unique_id assert explicit.result() == implicit.result() def test_steps_in_params(self): class Widget(Registrable): def 
__init__(self, x: int): self.x = x @Widget.register("gizmo") class GizmoWidget(Widget): def __init__(self, x: int): super().__init__(x * x) @Step.register("consumer") class WidgetConsumerStep(Step): def run(self, widget: Widget): # type: ignore return widget.x @Step.register("producer") class WidgetProducerStep(Step): def run(self, x: int) -> Widget: # type: ignore return GizmoWidget(x) config = { "widget_producer": Params({"type": "producer", "x": 4}), "widget_consumer": Params( {"type": "consumer", "widget": {"type": "ref", "ref": "widget_producer"}} ), } sg = StepGraph.from_params(config) assert len(sg["widget_consumer"].dependencies) > 0 class WidgetHolder(Registrable): def __init__(self, widget: Widget): self.widget = widget @WidgetHolder.register("gizmo") class GizmoWidgetHolder(WidgetHolder): def __init__(self, gizmo: GizmoWidget): super().__init__(gizmo) @Step.register("holder_consumer") class WidgetHolderConsumerStep(Step): def run(self, widget_holder: WidgetHolder) -> int: # type: ignore return widget_holder.widget.x config = { "widget_producer": Params({"type": "producer", "x": 4}), "holder_consumer": Params( { "type": "holder_consumer", "widget_holder": { "type": "gizmo", "gizmo": {"type": "ref", "ref": "widget_producer"}, }, } ), } sg = StepGraph.from_params(config) assert len(sg["holder_consumer"].dependencies) > 0 def test_functional_step(self): class Bar(FromParams): def __init__(self, x: int): self.x = x @step(exist_ok=True) def foo(bar: Bar) -> int: return bar.x assert issubclass(foo, FunctionalStep) assert foo().run(Bar(x=1)) == 1 foo_step = Step.from_params({"type": "foo", "bar": {"x": 1}}) assert isinstance(foo_step, FunctionalStep) assert isinstance(foo_step.kwargs["bar"], Bar) def test_bound_functional_step(self): class Bar(FromParams): def __init__(self, x: int): self.x = x @step(exist_ok=True, bind=True) def foo(self, bar: Bar) -> int: assert self.work_dir.is_dir() return bar.x foo_step = Step.from_params({"type": "foo", "bar": {"x": 1}}) 
assert isinstance(foo_step, FunctionalStep) assert foo_step.result(MemoryWorkspace()) == 1 def test_bound_functional_step_missing_self(self): @step(exist_ok=True, bind=True) def foo(x: int) -> int: return x with pytest.raises(ConfigurationError): Step.from_params({"type": "foo", "x": 1}) @step(exist_ok=True, bind=True) def bar(s, x: int) -> int: return x with pytest.raises(ConfigurationError): Step.from_params({"type": "bar", "x": 1}) ================================================ FILE: tests/steps/__init__.py ================================================ ================================================ FILE: tests/steps/dataset_remix_test.py ================================================ from tango.common.dataset_dict import DatasetDict from tango.steps.dataset_remix import DatasetRemixStep def test_dataset_remix_step(): step = DatasetRemixStep("remix") dataset_dict = DatasetDict( { "train": list(range(10)), "dev": list(range(10, 15)), "test": list(range(15, 20)), } ) result = step.run( input=dataset_dict, new_splits={ "all_train": "train + dev", "cross_val_train": "train[:8]", "cross_val_dev": "train[-2:]", }, ) assert len(result["all_train"]) == len(dataset_dict["train"]) + len(dataset_dict["dev"]) ================================================ FILE: tests/steps/shell_step_test.py ================================================ import pytest from tango.common.testing import TangoTestCase from tango.steps.shell_step import ShellStep, make_registrable class TestShellStep(TangoTestCase): def test_shell_step(self): step = ShellStep() result = step.run("echo hello") assert isinstance(result, str) assert result == "hello\n" def test_shell_step_failure(self): step = ShellStep() with pytest.raises(RuntimeError): step.run("ls -l non_existent_path") def test_shell_step_with_output_path(self, caplog): output_path = self.TEST_DIR / "test-folder" step = ShellStep() step.run(f"mkdir {output_path}", output_path=output_path) assert f"Output found at: {output_path}" in 
caplog.text def test_shell_step_different_validation(self, caplog): @make_registrable(exist_ok=True) def validate_func(path): """ Validates that the file contents of the `path` are a json string. """ import json with open(path) as f: json.load(f) output_path = self.TEST_DIR / "hello.json" command = f"python3 -c \"import json; print(json.dumps({{'a': 23}}))\" > {output_path}" step = ShellStep() step.run(command, output_path=output_path, validate_output=validate_func, shell=True) assert f"Output found at: {output_path}" in caplog.text def test_shell_step_in_config(self, caplog): output_path = str(self.TEST_DIR / "test-folder") config = { "steps": { "create_dir": { "type": "shell_step", "shell_command": f"mkdir {output_path}", "output_path": output_path, "validate_output": {"type": "check_path_existence"}, }, } } # Regular run contains all step outputs. self.run(config) assert f"Output found at: {output_path}" in caplog.text ================================================ FILE: tests/workspaces/__init__.py ================================================ ================================================ FILE: tests/workspaces/local_workspace_test.py ================================================ from shutil import copytree import pytest from sqlitedict import SqliteDict from tango import Step from tango.common.testing import TangoTestCase from tango.step_info import StepState from tango.workspaces import LocalWorkspace class AdditionStep(Step): def run(self, a: int, b: int) -> int: # type: ignore return a + b class TestLocalWorkspace(TangoTestCase): def test_local_workspace_one_step(self): workspace = LocalWorkspace(self.TEST_DIR) step = AdditionStep(a=1, b=2) with pytest.raises(KeyError): # This can't possibly work because the workspace has never seen that step before. 
step_info = workspace.step_info(step.unique_id) assert step_info.state == StepState.INCOMPLETE step_info = workspace.step_info(step) assert step_info.state == StepState.INCOMPLETE result = step.result(workspace) assert result == 3 step_info = workspace.step_info(step.unique_id) assert step_info.state == StepState.COMPLETED step_info = workspace.step_info(step) assert step_info.state == StepState.COMPLETED def test_local_workspace_two_steps(self): workspace = LocalWorkspace(self.TEST_DIR) step1 = AdditionStep(a=1, b=2) step2 = AdditionStep(a=step1, b=3) step_info = workspace.step_info(step2) assert step_info.state == StepState.INCOMPLETE step_info = workspace.step_info(step2.unique_id) assert step_info.state == StepState.INCOMPLETE assert step1.unique_id in step_info.dependencies step_info = workspace.step_info(step1.unique_id) assert step_info.state == StepState.INCOMPLETE step_info = workspace.step_info(step1) assert step_info.state == StepState.INCOMPLETE result = step2.result(workspace) assert result == 6 for step in [step1, step2]: step_info = workspace.step_info(step.unique_id) assert step_info.state == StepState.COMPLETED step_info = workspace.step_info(step) assert step_info.state == StepState.COMPLETED def test_local_workspace_upgrade_v1_to_v2(self): workspace_dir = self.TEST_DIR / "workspace" copytree( self.FIXTURES_ROOT / "v1_local_workspace", workspace_dir, symlinks=True, ) workspace = LocalWorkspace(workspace_dir) step_info = workspace.step_info("SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz") assert step_info.state == StepState.COMPLETED dependencies = list(step_info.dependencies) # Make sure all the dependencies are there. 
while len(dependencies) > 0: step_info = workspace.step_info(dependencies.pop()) dependencies.extend(step_info.dependencies) def test_remove_step(self): workspace = LocalWorkspace(self.TEST_DIR) step = AdditionStep(a=1, b=2) workspace.step_starting(step) workspace.step_finished(step, 1.0) with SqliteDict(workspace.step_info_file) as d: assert step.unique_id in d cache = workspace.step_cache assert step in cache workspace.remove_step(step.unique_id) with SqliteDict(workspace.step_info_file) as d: assert step.unique_id not in d cache = workspace.step_cache assert step not in cache ================================================ FILE: tests/workspaces/memory_workspace_test.py ================================================ from tango.common.testing.steps import FloatStep from tango.workspaces import MemoryWorkspace def test_remove_step(): workspace = MemoryWorkspace() step = FloatStep(step_name="float", result=1.0) workspace.step_starting(step) workspace.step_finished(step, 1.0) cache = workspace.step_cache assert step.unique_id in workspace.unique_id_to_info assert step in cache workspace.remove_step(step.unique_id) cache = workspace.step_cache assert step.unique_id not in workspace.unique_id_to_info assert step not in cache