[
  {
    "path": ".dockerignore",
    "content": ".dockerignore\n**.pyc\n**/__pycache__\n.gitignore\n.git\n.coverage\n.mypy_cache\ndocs\nexamples\ntests\ntest_fixtures\nintegration_tests\ndist\n*.egg-info\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# Contributing\n\nThanks for considering contributing! Please read this document to learn the various ways you can contribute to this project and how to go about doing it.\n\n## Bug reports and feature requests\n\n### Did you find a bug?\n\nFirst, do [a quick search](https://github.com/allenai/tango/issues) to see whether your issue has already been reported.\nIf your issue has already been reported, please comment on the existing issue.\n\nOtherwise, open [a new GitHub issue](https://github.com/allenai/tango/issues). Be sure to include a clear title\nand description. The description should include as much relevant information as possible. The description should\nexplain how to reproduce the erroneous behavior as well as the behavior you expect to see. Ideally you would include a\ncode sample or an executable test case demonstrating the expected behavior.\n\n### Do you have a suggestion for an enhancement or new feature?\n\nWe use GitHub issues to track feature requests. Before you create a feature request:\n\n- Make sure you have a clear idea of the enhancement you would like. If you have a vague idea, consider discussing\n  it first on a GitHub issue.\n- Check the documentation to make sure your feature does not already exist.\n- Do [a quick search](https://github.com/allenai/tango/issues) to see whether your feature has already been suggested.\n\nWhen creating your request, please:\n\n- Provide a clear title and description.\n- Explain why the enhancement would be useful. It may be helpful to highlight the feature in other libraries.\n- Include code examples to demonstrate how the enhancement would be used.\n\n## Making a pull request\n\nWhen you're ready to contribute code to address an open issue, please follow these guidelines to help us be able to review your pull request (PR) quickly.\n\n1.  
**Initial setup** (only do this once)\n\n    <details><summary>Expand details 👇</summary><br/>\n\n    If you haven't already done so, please [fork](https://help.github.com/en/enterprise/2.13/user/articles/fork-a-repo) this repository on GitHub.\n\n    Then clone your fork locally with\n\n        git clone https://github.com/USERNAME/tango.git\n\n    or\n\n        git clone git@github.com:USERNAME/tango.git\n\n    At this point the local clone of your fork only knows that it came from _your_ repo, github.com/USERNAME/tango.git, but doesn't know anything about the _main_ repo, [https://github.com/allenai/tango.git](https://github.com/allenai/tango). You can see this by running\n\n        git remote -v\n\n    which will output something like this:\n\n        origin https://github.com/USERNAME/tango.git (fetch)\n        origin https://github.com/USERNAME/tango.git (push)\n\n    This means that your local clone can only track changes from your fork, but not from the main repo, and so you won't be able to keep your fork up-to-date with the main repo over time. Therefore you'll need to add another \"remote\" to your clone that points to [https://github.com/allenai/tango.git](https://github.com/allenai/tango). To do this, run the following:\n\n        git remote add upstream https://github.com/allenai/tango.git\n\n    Now if you do `git remote -v` again, you'll see\n\n        origin https://github.com/USERNAME/tango.git (fetch)\n        origin https://github.com/USERNAME/tango.git (push)\n        upstream https://github.com/allenai/tango.git (fetch)\n        upstream https://github.com/allenai/tango.git (push)\n\n    Finally, you'll need to create a Python 3 virtual environment suitable for working on this project. 
There are a number of tools out there that make working with virtual environments easier.\n    The most direct way is with the [`venv` module](https://docs.python.org/3.8/library/venv.html) in the standard library, but if you're new to Python or you don't already have a recent Python 3 version installed on your machine,\n    we recommend [Miniconda](https://docs.conda.io/en/latest/miniconda.html).\n\n    On Mac, for example, you can install Miniconda with [Homebrew](https://brew.sh/):\n\n        brew install miniconda\n\n    Then you can create and activate a new Python environment by running:\n\n        conda create -n tango python=3.9\n        conda activate tango\n\n    Once your virtual environment is activated, you can install your local clone in \"editable mode\" with\n\n        pip install -U pip setuptools wheel\n        pip install -e '.[dev,all]'\n\n    The \"editable mode\" comes from the `-e` argument to `pip`, and essentially just creates a symbolic link from the site-packages directory of your virtual environment to the source code in your local clone. That way any changes you make will be immediately reflected in your virtual environment.\n\n    To test your installation, just run\n\n        tango info\n\n    </details>\n\n2.  **Ensure your fork is up-to-date**\n\n    <details><summary>Expand details 👇</summary><br/>\n\n    Once you've added an \"upstream\" remote pointing to [https://github.com/allenai/tango.git](https://github.com/allenai/tango), keeping your fork up-to-date is easy:\n\n        git checkout main  # if not already on main\n        git pull --rebase upstream main\n        git push\n\n    </details>\n\n3.  **Create a new branch to work on your fix or enhancement**\n\n    <details><summary>Expand details 👇</summary><br/>\n\n    Committing directly to the main branch of your fork is not recommended. 
It will be easier to keep your fork clean if you work on a separate branch for each contribution you intend to make.\n\n    You can create a new branch with\n\n        # replace BRANCH with whatever name you want to give it\n        git checkout -b BRANCH\n        git push -u origin BRANCH\n\n    </details>\n\n4.  **Test your changes**\n\n    <details><summary>Expand details 👇</summary><br/>\n\n    Our continuous integration (CI) testing runs [a number of checks](https://github.com/allenai/tango/actions) for each pull request on [GitHub Actions](https://github.com/features/actions). You can run most of these tests locally, which is something you should do _before_ opening a PR to help speed up the review process and make it easier for us.\n\n    First, you should run [`isort`](https://github.com/PyCQA/isort) and [`black`](https://github.com/psf/black) to make sure your code is formatted consistently.\n    Many IDEs support code formatters as plugins, so you may be able to set up isort and black to run automatically every time you save.\n    For example, [`black.vim`](https://github.com/psf/black/tree/master/plugin) will give you this functionality in Vim. But both `isort` and `black` are also easy to run directly from the command line.\n    Just run this from the root of your clone:\n\n        isort .\n        black .\n\n    Our CI also uses [`ruff`](https://github.com/charliermarsh/ruff) to lint the code base and [`mypy`](http://mypy-lang.org/) for type-checking. You should run both of these next with\n\n        ruff check .\n\n    and\n\n        mypy .\n\n    We also strive to maintain high test coverage, so most contributions should include additions to [the unit tests](https://github.com/allenai/tango/tree/main/tests). 
These tests are run with [`pytest`](https://docs.pytest.org/en/latest/), which you can use to locally run any test modules that you've added or changed.\n\n    For example, if you've fixed a bug in `tango/a/b.py`, you can run the tests specific to that module with\n\n        pytest -v tests/a/b_test.py\n\n    If your contribution involves additions to any public part of the API, we require that you write docstrings\n    for each function, method, class, or module that you add.\n    See the [Writing docstrings](#writing-docstrings) section below for details on the syntax.\n    You should test to make sure the API documentation can build without errors by running\n\n        make docs\n\n    If the build fails, it's most likely due to small formatting issues. If the error message isn't clear, feel free to comment on this in your pull request.\n\n    And finally, please update the [CHANGELOG](https://github.com/allenai/tango/blob/main/CHANGELOG.md) with notes on your contribution in the \"Unreleased\" section at the top.\n\n    After all of the above checks have passed, you can now open [a new GitHub pull request](https://github.com/allenai/tango/pulls).\n    Make sure you have a clear description of the problem and the solution, and include a link to relevant issues.\n\n    We look forward to reviewing your PR!\n\n    </details>\n\n### Writing docstrings\n\nWe use [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to build our API docs, which automatically parses all docstrings\nof public classes and methods. All docstrings should adhere to the [Numpy styling convention](https://www.sphinx-doc.org/en/master/usage/extensions/example_numpy.html).\n\n## Adding a new integration\n\nIn order to add a new integration, there are several additional steps and guidelines you should follow\nin addition to everything listed in [Making a pull request](#making-a-pull-request).\n\n1. 
First start by creating a new submodule `tango.integrations.name_of_integration` and put all of the code for your integration in there.\n2. Then you must add a module docstring to the `__init__.py` file of the submodule which imports all of the public components of the integration,\n   and defines the [`__all__`](https://docs.python.org/3/tutorial/modules.html#importing-from-a-package) special variable to include all of those components.\n   This ensures all of the public components will show up in the documentation.\n3. Next, you should add unit tests of your code to `tests/integrations/name_of_integration/`.\n4. Then add a new file `docs/source/api/integrations/name_of_integration.rst`, and include the directive:\n\n   ```\n   .. automodule:: tango.integrations.name_of_integration\n      :members:\n   ```\n\n   Take a look at any of the other files in that folder to see how it should look exactly.\n\n5. And then add `name_of_integration` to the `toctree` in `docs/source/api/integrations/index.rst`.\n6. After that, add any additional requirements that your integration depends on to `requirements.txt`. Be sure to put those under the \"Extra dependencies for integrations\" section,\n   and add the special inline comment `# needed by: name_of_integration`.\n7. And finally, in the `checks` job definition in `.github/workflows/main.yml`, add a new object\n   to the matrix for your integration following the other examples there.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.yml",
    "content": "name: 🐛 Bug Report\ndescription: Create a report to help us reproduce and fix the bug\nlabels: 'bug'\n\nbody:\n- type: markdown\n  attributes:\n    value: >\n      #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/allenai/tango/issues?q=is%3Aissue+sort%3Acreated-desc+).\n- type: textarea\n  attributes:\n    label: 🐛 Describe the bug\n    description: |\n      Please provide a clear and concise description of what the bug is.\n\n      If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as succinct (minimal) as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports, etc. For example:\n\n      ```python\n      # All necessary imports at the beginning\n      import tango\n\n      # A succinct reproducing example trimmed down to the essential parts:\n      assert False is True, \"Oh no!\"\n      ```\n\n      If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.\n\n      Please also paste or describe the results you observe along with the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. 
It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.\n    placeholder: |\n      A clear and concise description of what the bug is.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Versions\n    description: |\n      Please run the following and paste the output below.\n      ```sh\n      python --version && pip freeze\n      ```\n  validations:\n    required: true\n- type: markdown\n  attributes:\n    value: >\n      Thanks for contributing 🎉!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation.yml",
    "content": "name: 📚 Documentation\ndescription: Report an issue related to https://ai2-tango.readthedocs.io/latest\nlabels: 'documentation'\n\nbody:\n- type: textarea\n  attributes:\n    label: 📚 The doc issue\n    description: >\n      A clear and concise description of what content in https://ai2-tango.readthedocs.io/latest is an issue.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Suggest a potential alternative/fix\n    description: >\n      Tell us how we could improve the documentation in this regard.\n- type: markdown\n  attributes:\n    value: >\n      Thanks for contributing 🎉!\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.yml",
    "content": "name: 🚀 Feature request\ndescription: Submit a proposal/request for a new feature\nlabels: 'feature request'\n\nbody:\n- type: textarea\n  attributes:\n    label: 🚀 The feature, motivation and pitch\n    description: >\n      A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *\"I'm working on X and would like Y to be possible\"*. If this is related to another GitHub issue, please link here too.\n  validations:\n    required: true\n- type: textarea\n  attributes:\n    label: Alternatives\n    description: >\n      A description of any alternative solutions or features you've considered, if any.\n- type: textarea\n  attributes:\n    label: Additional context\n    description: >\n      Add any other context or screenshots about the feature request.\n- type: markdown\n  attributes:\n    value: >\n      Thanks for contributing 🎉!\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "version: 2\nupdates:\n- package-ecosystem: pip\n  directory: \"/\"\n  schedule:\n    interval: \"daily\"\n  open-pull-requests-limit: 10\n- package-ecosystem: \"github-actions\"\n  directory: \"/\"\n  schedule:\n    interval: \"daily\"\n"
  },
  {
    "path": ".github/workflows/changelog.yml",
    "content": "name: Changelog\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\non:\n  pull_request:\n    branches:\n      - main\n    paths:\n      - 'tango/**'\n\njobs:\n  changelog:\n    name: CHANGELOG\n    runs-on: ubuntu-latest\n    if: github.event_name == 'pull_request'\n\n    steps:\n    - uses: actions/checkout@v3\n      with:\n        fetch-depth: 0\n\n    - name: Check that CHANGELOG has been updated\n      run: |\n        # If this step fails, this means you haven't updated the CHANGELOG.md\n        # file with notes on your contribution.\n        git diff --name-only $(git merge-base origin/main HEAD) | grep '^CHANGELOG.md$' && echo \"Thanks for helping keep our CHANGELOG up-to-date!\"\n"
  },
  {
    "path": ".github/workflows/docker.yml",
    "content": "name: Docker\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\non:\n  pull_request:\n    branches:\n      - main\n    paths:\n      - \"Dockerfile\"\n      - \".dockerignore\"\n      - \"pyproject.toml\"\n  push:\n    tags:\n      - \"v*.*.*\"\n\njobs:\n  build:\n    name: Build (${{ matrix.build.tag }})\n    runs-on: ubuntu-latest\n    strategy:\n      fail-fast: false\n      matrix:\n        build:\n          - base_image: ghcr.io/allenai/pytorch:1.12.1-cuda11.3-python3.9\n            tag: cuda11.3\n    env:\n      IMAGE_NAME: ghcr.io/allenai/tango\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Build Docker image\n        run: |\n          docker build --build-arg BASE_IMAGE=${{ matrix.build.base_image }} -t \"${IMAGE_NAME}:${{ matrix.build.tag }}\" .\n\n      - name: Test Docker image\n        run: |\n          docker run --rm \"${IMAGE_NAME}:${{ matrix.build.tag }}\" info\n\n      - name: Log in to ghcr.io\n        if: github.event_name != 'pull_request'\n        run: |\n          echo \"${{ secrets.GITHUB_TOKEN }}\" | docker login ghcr.io -u ${{ github.actor }} --password-stdin\n\n      - name: Push latest to ghcr.io\n        if: github.event_name != 'pull_request'\n        run: |\n          docker push \"${IMAGE_NAME}:${{ matrix.build.tag }}\"\n\n      - name: Push release version to ghcr.io\n        if: startsWith(github.ref, 'refs/tags/')\n        run: |\n          GITHUB_TAG=${GITHUB_REF#refs/tags/}\n          docker tag \"${IMAGE_NAME}:${{ matrix.build.tag }}\" \"${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}\"\n          docker push \"${IMAGE_NAME}:${GITHUB_TAG}-${{ matrix.build.tag }}\"\n"
  },
  {
    "path": ".github/workflows/docker_testing.yml",
    "content": "# This workflow is just for building our Docker image for GPU testing on Beaker,\n# and pushing it to Beaker. We only run it when the relevant Dockerfile (or .dockerignore) changes.\nname: Docker testing\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n  cancel-in-progress: true\n\non:\n  pull_request:\n    branches:\n      - main\n    paths:\n      - 'Dockerfile.test'\n      - '.dockerignore'\n      - 'scripts/entrypoint.sh'\n  push:\n    branches:\n      - main\n    paths:\n      - 'Dockerfile.test'\n      - '.dockerignore'\n      - 'scripts/entrypoint.sh'\n\njobs:\n  build:\n    name: Build\n    runs-on: ubuntu-latest\n    env:\n      BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}\n      BEAKER_WORKSPACE: ai2/tango-testing\n      IMAGE_NAME: tango-testing\n    steps:\n      - uses: actions/checkout@v3\n\n      - uses: allenai/setup-beaker@v2\n        with:\n          token: ${{ secrets.BEAKER_TOKEN }}\n          workspace: ${{ env.BEAKER_WORKSPACE }}\n\n      - name: Build Docker image\n        run: |\n          docker build -t \"$IMAGE_NAME\" -f Dockerfile.test .\n\n      - name: Determine current commit SHA (pull request)\n        if: github.event_name == 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=${{ github.event.pull_request.head.sha }}\" >> $GITHUB_ENV\n\n      - name: Determine current commit SHA (push)\n        if: github.event_name != 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=$GITHUB_SHA\" >> $GITHUB_ENV\n\n      - name: Test Docker image\n        run: |\n          docker run --rm --env COMMIT_SHA=\"$COMMIT_SHA\" \"$IMAGE_NAME\" tango info\n\n       # In order to push a new version of an image to beaker, we have to delete the old version first.\n       # This doesn't actually delete the backing Docker image, so we'll still benefit from layer\n       # caching when we push new versions. 
But we have to be careful to minimize the amount\n       # of time between deletion and creation, because during that time any Beaker job trying to start\n       # that depends on that image will fail. So to minimize this downtime, we first push a\n       # \"temp\" version of the image, then delete the current one and quickly rename the \"temp\" one to take its place.\n       # The image might not exist yet though, so it's okay if the delete fails.\n\n      - name: Delete existing commit image\n        continue-on-error: true\n        run: |\n          beaker image delete petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }}\n\n      - name: Upload new commit image\n        run: |\n          beaker image create --workspace ${{ env.BEAKER_WORKSPACE }} --name ${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }}\n\n      - name: Delete existing image\n        if: github.event_name != 'pull_request'\n        continue-on-error: true\n        run: |\n          beaker image delete petew/${{ env.IMAGE_NAME }}\n\n      - name: Rename new commit image to final image\n        if: github.event_name != 'pull_request'\n        run: |\n          beaker image rename petew/${{ env.IMAGE_NAME }}-${{ env.COMMIT_SHA }} ${{ env.IMAGE_NAME }}\n"
  },
  {
    "path": ".github/workflows/integration_tests.yml",
    "content": "name: Integration tests\n\non:\n  workflow_dispatch:\n    inputs:\n      test:\n        description: the integration test to run\n        default: fairscale_benchmarks\n        required: true\n        type: choice\n        options:\n          - fairscale_benchmarks\n      cluster:\n        description: the beaker cluster to run the test on\n        default: ai2/tango-integration-tests\n        required: true\n        type: choice\n        options:\n          - ai2/tango-integration-tests\n          - ai2/allennlp-cirrascale\n  # Uncomment this trigger to test changes on a pull request.\n  # You also have to uncomment the lines below that mention 'for pull request checks'\n  # pull_request:\n  #   branches:\n  #     - '*'\n\njobs:\n  run_test:\n    name: ${{ github.event.inputs.test }}\n    # name: fairscale_benchmarks  # for pull request checks\n    runs-on: [ubuntu-latest]\n    timeout-minutes: 60\n    env:\n      TEST_NAME: ${{ github.event.inputs.test }}\n      # TEST_NAME: fairscale_benchmarks  # for pull request checks\n      BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}\n      BEAKER_WORKSPACE: ai2/tango-integration-tests\n      BEAKER_CLUSTER: ${{ github.event.inputs.cluster }}\n      # BEAKER_CLUSTER: ai2/allennlp-cirrascale  # for pull request checks\n      IMAGE_NAME: petew/tango-testing\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Validate inputs\n        run: |\n          # The 'test' input should be a directory in `integration_tests/`\n          test -d \"integration_tests/${TEST_NAME}\"\n\n      - name: Determine current commit SHA (pull request)\n        if: github.event_name == 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=${{ github.event.pull_request.head.sha }}\" >> $GITHUB_ENV\n\n      - name: Determine current commit SHA (push)\n        if: github.event_name != 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=$GITHUB_SHA\" >> $GITHUB_ENV\n\n      - name: Install beaker client\n        
shell: bash\n        run: |\n          mkdir -p \"$HOME/bin\"\n\n          # Download and install from latest GitHub release.\n          curl -s https://api.github.com/repos/allenai/beaker/releases/latest \\\n            | grep 'browser_download_url.*linux' \\\n            | cut -d '\"' -f 4 \\\n            | wget -qi - \\\n          && tar -xvzf beaker_linux.tar.gz -C \"$HOME/bin\"\n\n          # Add to path.\n          echo \"$HOME/bin\" >> \"$GITHUB_PATH\"\n\n      - name: Verify beaker install\n        run: |\n          beaker account whoami\n\n      - name: Create beaker experiment config\n        run: |\n          cat >beaker_config.yml << EOL\n          version: v2-alpha\n          description: ${{ env.TEST_NAME }}\n          tasks:\n            - name: test\n              image:\n                beaker: ${{ env.IMAGE_NAME }}\n              command: [\"/entrypoint.sh\", \"integration_tests/${{ env.TEST_NAME }}/run.sh\"]\n              envVars:\n                - name: COMMIT_SHA\n                  value: $COMMIT_SHA\n                - name: WANDB_API_KEY\n                  secret: WANDB_API_KEY\n                - name: FILE_FRIENDLY_LOGGING\n                  value: \"true\"\n                - name: TOKENIZERS_PARALLELISM  # set this to avoid warnings\n                  value: \"true\"\n                - name: PYTHONUNBUFFERED\n                  value: \"true\"\n              result:\n                path: '/results'\n              resources:\n                gpuCount: 4\n              context:\n                cluster: ${{ env.BEAKER_CLUSTER }}\n                priority: normal\n          EOL\n          cat beaker_config.yml\n\n      - name: Submit beaker job\n        run: |\n          TIMESTAMP=$(date +%H%M%S)\n          EXPERIMENT=$(beaker experiment create beaker_config.yml --workspace $BEAKER_WORKSPACE --name \"${TEST_NAME}-${{ github.run_number }}-${TIMESTAMP}\" | awk '{print $2}')\n          if [ -z \"$EXPERIMENT\" ]; then\n            exit 1\n        
  else\n            echo \"EXPERIMENT=$EXPERIMENT\" >> $GITHUB_ENV\n            echo \"Experiment $EXPERIMENT submitted. See progress at https://beaker.org/ex/$EXPERIMENT\"\n          fi\n\n      - name: Wait for job to finish\n        run: |\n          beaker experiment await $EXPERIMENT test finalized --timeout 60m\n          # Check the job's exit code.\n          test $(beaker experiment get $EXPERIMENT --format=json | jq '.[0].jobs[0].status.exitCode') -eq 0\n\n      - name: Get logs\n        if: always()\n        run: |\n          # EXPERIMENT could be empty if the submission step failed.\n          # We'll exit right away if that's the case.\n          if [ -z \"$EXPERIMENT\" ]; then\n            echo \"No logs to show\"\n            exit 0\n          fi\n\n          # Download logs from beaker.\n          beaker experiment results $EXPERIMENT --prefix out.log --output results\n\n          # If the experiment failed during startup, there might not be any logs.\n          if [ -f results/test/out.log ]; then\n            echo \"\"\n            echo \">>> Logs:\"\n            echo \"\"\n            cat results/test/out.log\n          else\n            echo \"No logs to show\"\n          fi\n\n      - name: Stop job\n        if: cancelled()\n        run: |\n          if [ ! -z \"$EXPERIMENT\" ]; then\n            beaker experiment stop $EXPERIMENT\n          fi\n"
  },
  {
    "path": ".github/workflows/main.yml",
    "content": "name: Main\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.ref }}\n\non:\n  pull_request:\n    branches:\n      - \"*\"\n  push:\n    branches:\n      - main\n    tags:\n      - \"v*.*.*\"\n\nenv:\n  CACHE_PREFIX: v5 # Change this to invalidate existing cache.\n  PYTHON_PATH: ./\n  DEFAULT_PYTHON: 3.9\n  WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}\n  BEAKER_TOKEN: ${{ secrets.BEAKER_TOKEN }}\n  BEAKER_WORKSPACE: ai2/tango-testing\n  BEAKER_IMAGE: petew/tango-testing\n  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}\n\njobs:\n  checks:\n    name: python ${{ matrix.python }} - ${{ matrix.task.name }}\n    runs-on: [ubuntu-latest]\n    timeout-minutes: 30\n    permissions:\n      contents: \"read\"\n      id-token: \"write\"\n    strategy:\n      fail-fast: false\n      matrix:\n        python: [\"3.9\"]\n        task:\n          - name: Lint\n            extras: dev,all\n            requires_torch: true\n            run: |\n              ruff check .\n\n          - name: Type check\n            extras: dev,all\n            requires_torch: true\n            run: |\n              mypy --check-untyped-defs .\n\n          - name: Build\n            extras: dev,all\n            requires_torch: true\n            run: |\n              tango --version\n              python -m build\n\n          - name: Style\n            extras: dev\n            requires_torch: false\n            run: |\n              isort --check .\n              black --check .\n\n          - name: Docs\n            extras: dev,all\n            requires_torch: true\n            run: |\n              cd docs && make html SPHINXOPTS=\"-W --keep-going\"\n\n          - name: Test\n            extras: dev\n            requires_torch: false\n            run: |\n              pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/\n\n          - name: Datasets integration\n            extras: dev,datasets\n            
requires_torch: false\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/datasets tests/integrations/datasets\n\n          - name: PyTorch integration\n            extras: dev,torch\n            requires_torch: true\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch\n\n          - name: Transformers integration\n            extras: dev,flax,transformers\n            requires_torch: true\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers\n\n          - name: FairScale integration\n            extras: dev,fairscale\n            requires_torch: true\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/fairscale tests/integrations/fairscale\n\n          - name: W&B integration\n            extras: dev,torch,flax,wandb\n            requires_torch: true\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/wandb tests/integrations/wandb\n\n          - name: Beaker integration\n            extras: dev,beaker\n            requires_torch: false\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/beaker tests/integrations/beaker\n\n          - name: Flax integration\n            extras: dev,flax,transformers\n            requires_torch: false\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/flax tests/integrations/flax\n\n          - name: GS integration\n            extras: dev,gs\n            requires_torch: false\n            run: |\n              pytest -v --color=yes --doctest-modules tango/integrations/gs tests/integrations/gs\n\n          - name: Example - train_lm\n            extras: dev,all\n            requires_torch: true\n            run: |\n              cd examples/train_lm\n              
pytest -v --color=yes test.py\n\n        include:\n          # Run the core tests on other Python versions as well.\n          - task:\n              name: Test\n              extras: dev\n              requires_torch: false\n              run: |\n                pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/\n            python: \"3.8\"\n\n          - task:\n              name: Test\n              extras: dev\n              requires_torch: false\n              run: |\n                pytest -v --durations=10 --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/\n            python: \"3.10\"\n\n    steps:\n      - uses: \"actions/checkout@v3\"\n      - name: Checkout\n        if: github.event_name != 'pull_request'\n        uses: actions/checkout@v3\n\n      # For pull requests we need to checkout the HEAD commit instead of the merge\n      # commit since some tests depend on having an existing commit.\n      - name: Checkout (pull request)\n        if: github.event_name == 'pull_request'\n        uses: actions/checkout@v3\n        with:\n          ref: ${{ github.event.pull_request.head.sha }}\n\n      - name: Setup Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ matrix.python }}\n\n      - name: Install prerequisites\n        run: |\n          pip install --upgrade pip setuptools wheel virtualenv\n\n      - name: Set build variables\n        shell: bash\n        run: |\n          set -e\n          # Get the exact Python version to use in the cache key.\n          echo \"PYTHON_VERSION=$(python --version)\" >> $GITHUB_ENV\n          echo \"RUNNER_ARCH=$(uname -m)\" >> $GITHUB_ENV\n          # Use week number in cache key so we can refresh the cache weekly.\n          echo \"WEEK_NUMBER=$(date +%V)\" >> $GITHUB_ENV\n          echo \"EXTRAS_HASH=$(python scripts/hash_extras.py ${{ matrix.task.extras 
}})\" >> $GITHUB_ENV\n\n      - uses: actions/cache@v3\n        id: virtualenv-cache\n        with:\n          path: .venv\n          key: ${{ env.CACHE_PREFIX }}-${{ env.WEEK_NUMBER }}-${{ runner.os }}-${{ env.RUNNER_ARCH }}-${{ env.PYTHON_VERSION }}-${{ env.EXTRAS_HASH }}-${{ hashFiles('pyproject.toml') }}\n\n      - name: Setup virtual environment (no cache hit)\n        if: steps.virtualenv-cache.outputs.cache-hit != 'true'\n        run: |\n          test -d .venv || virtualenv -p $(which python) --copies --reset-app-data .venv\n\n      # Reference: https://github.com/marketplace/actions/authenticate-to-google-cloud#setup\n      - name: Authenticate to Google Cloud\n        if: matrix.task.name == 'GS integration'\n        uses: \"google-github-actions/auth@v1\"\n        with:\n          workload_identity_provider: \"projects/10554368204/locations/global/workloadIdentityPools/tango-ci-pool/providers/tango-ci-provider\"\n          service_account: \"tango-service@ai2-allennlp.iam.gserviceaccount.com\"\n\n      - name: Pre-install torch\n        if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'torch') || contains(matrix.task.extras, 'all') || matrix.task.requires_torch)\n        run: |\n          . .venv/bin/activate\n          pip install torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cpu\n\n      - name: Pre-install flax\n        if: steps.virtualenv-cache.outputs.cache-hit != 'true' && (contains(matrix.task.extras, 'flax') || contains(matrix.task.extras, 'all'))\n        run: |\n          . .venv/bin/activate\n          pip install flax jax jaxlib \"tensorflow-cpu>=2.9.1\" optax\n\n      - name: Install editable (no cache hit)\n        if: steps.virtualenv-cache.outputs.cache-hit != 'true'\n        run: |\n          . 
.venv/bin/activate\n          pip install -e .[${{ matrix.task.extras }}]\n\n      - name: Install editable (cache hit)\n        if: steps.virtualenv-cache.outputs.cache-hit == 'true'\n        run: |\n          . .venv/bin/activate\n          pip install --no-deps -e .[${{ matrix.task.extras }}]\n\n      - name: Show environment info\n        run: |\n          . .venv/bin/activate\n          echo \"========= Python location ===========\"\n          which python\n          echo \"========= Python version ============\"\n          python --version\n          echo \"========= Python packages ===========\"\n          pip freeze\n          echo \"========= Tango installation ========\"\n          tango info\n\n      - name: ${{ matrix.task.name }}\n        run: |\n          . .venv/bin/activate\n          ${{ matrix.task.run }}\n\n      - name: Upload package distribution files\n        if: matrix.task.name == 'Build' && matrix.python == env.DEFAULT_PYTHON\n        uses: actions/upload-artifact@v3\n        with:\n          name: package\n          path: dist\n\n      - name: Upload docs build\n        if: matrix.task.name == 'Docs' && matrix.python == env.DEFAULT_PYTHON\n        uses: actions/upload-artifact@v3\n        with:\n          name: docs\n          path: docs/build\n\n      - name: Clean up\n        if: always()\n        run: |\n          . 
.venv/bin/activate\n          pip uninstall -y ai2-tango\n\n  gpu_tests:\n    name: GPU Tests\n    runs-on: ubuntu-latest\n    steps:\n      - name: Determine current commit SHA (pull request)\n        if: github.event_name == 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=${{ github.event.pull_request.head.sha }}\" >> $GITHUB_ENV\n\n      - name: Determine current commit SHA (push)\n        if: github.event_name != 'pull_request'\n        run: |\n          echo \"COMMIT_SHA=$GITHUB_SHA\" >> $GITHUB_ENV\n\n      - name: GPU Tests\n        uses: allenai/beaker-run-action@v1.2\n        with:\n          spec: |\n            version: v2\n            description: GPU Tests\n            budget: ai2/oe-training\n            tasks:\n              - name: tests\n                image:\n                  beaker: ${{ env.BEAKER_IMAGE }}\n                context:\n                  preemptible: true \n                resources:\n                  gpuCount: 2\n                envVars:\n                  - name: COMMIT_SHA\n                    value: ${{ env.COMMIT_SHA }}\n                command: [\"/entrypoint.sh\", \"pytest\", \"-v\", \"-m\", \"gpu\", \"tests/\"]\n                result:\n                  path: /unused\n          token: ${{ secrets.BEAKER_TOKEN }}\n          workspace: ${{ env.BEAKER_WORKSPACE }}\n          clusters: ai2/general-cirrascale,ai2/allennlp-cirrascale,ai2/aristo-cirrascale,ai2/mosaic-cirrascale,ai2/s2-cirrascale\n\n  release:\n    name: Release\n    runs-on: ubuntu-latest\n    needs: [gpu_tests, checks]\n    if: startsWith(github.ref, 'refs/tags/')\n    steps:\n      - uses: actions/checkout@v3\n        with:\n          fetch-depth: 0\n\n      - name: Setup Python\n        uses: actions/setup-python@v4\n        with:\n          python-version: ${{ env.DEFAULT_PYTHON }}\n\n      - name: Install requirements\n        run: |\n          pip install -e .[dev]\n\n      - name: Prepare environment\n        run: |\n          echo 
\"RELEASE_VERSION=${GITHUB_REF#refs/tags/v}\" >> $GITHUB_ENV\n          echo \"TAG=${GITHUB_REF#refs/tags/}\" >> $GITHUB_ENV\n\n      - name: Download package distribution files\n        uses: actions/download-artifact@v3\n        with:\n          name: package\n          path: dist\n\n      - name: Generate release notes\n        run: |\n          python scripts/release_notes.py > ${{ github.workspace }}-RELEASE_NOTES.md\n\n      - name: Publish package to PyPI\n        run: |\n          twine upload -u __token__ -p ${{ secrets.PYPI_PASSWORD }} dist/*\n\n      - name: Publish GitHub release\n        uses: softprops/action-gh-release@v1\n        with:\n          body_path: ${{ github.workspace }}-RELEASE_NOTES.md\n          prerelease: ${{ contains(env.TAG, 'rc') }}\n          files: |\n            dist/*\n"
  },
  {
    "path": ".github/workflows/update_dependency_pr.yml",
    "content": "name: Update dependency PR\n\non:\n  pull_request:\n    types:\n      - opened\n    paths:\n      - \"pyproject.toml\"\n\npermissions:\n  pull-requests: write\n\njobs:\n  torch:\n    name: torch\n    runs-on: ubuntu-latest\n    if: startsWith(github.head_ref, 'dependabot/pip/torch-')\n    steps:\n      - uses: actions/github-script@v6\n        with:\n          script: |\n            github.rest.issues.createComment({\n              issue_number: context.issue.number,\n              owner: context.repo.owner,\n              repo: context.repo.repo,\n              body: 'Hello! This is a [PyTorch](https://pytorch.org/) upgrade, which means you will also need to update:\\n- [ ] The base image in `Dockerfile`\\n- [ ] The base image in `Dockerfile.test`\\n- [ ] The torch version hard-coded in `.github/workflows/main.yml`'\n            })\n"
  },
  {
    "path": ".gitignore",
    "content": "# build artifacts\n\n.eggs/\n.mypy_cache\nai2_tango.egg-info/\nbuild/\ndist/\npip-wheel-metadata/\nruns/\nworkspace/\n\n# dev tools\n\n.envrc\n.python-version\n.idea\n.venv/\n.vscode/\n/*.iml\n\n\n# jupyter notebooks\n\n.ipynb_checkpoints\n\n\n# miscellaneous\n\n.cache/\ndoc/_build/\n*.swp\n.DS_Store\n\n\n# python\n\n*.pyc\n*.pyo\n__pycache__\n\n\n# testing and continuous integration\n\n.coverage\n.pytest_cache/\n.benchmarks\n\n# documentation build artifacts\n\ndocs/build\nsite/\n\n# internal experiment configs\n*-internal.jsonnet\n*-internal.json\n*-internal.yaml\n*-internal.yml\n"
  },
  {
    "path": ".readthedocs.yaml",
    "content": "version: 2\n\nsphinx:\n  configuration: docs/source/conf.py\n  fail_on_warning: true\n\nbuild:\n  os: ubuntu-22.04\n  tools:\n    python: \"3.10\"\n\npython:\n  install:\n    - method: pip\n      path: .\n      extra_requirements:\n        - dev\n        - all\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project will be documented in this file.\n\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),\nand this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\n\n## Unreleased\n\n### Fixed\n\n- Fixed a bunch of dependencies\n- Upgraded to new version of wandb\n\n## [v1.3.2](https://github.com/allenai/tango/releases/tag/v1.3.2) - 2023-10-27\n\n### Fixed\n\n- Fix issues with gcloud auth in beaker executor.\n\n## [v1.3.1](https://github.com/allenai/tango/releases/tag/v1.3.1) - 2023-10-25\n\n### Fixed\n\n- Minor bugs in the `GSWorkspace()`.\n\n### Changed\n\n- Added CLI-style execution functions for experiments defined in Python.\n- Added `display()` to `ExecutorOutput` for producing a table that summarizes the run.\n\n## [v1.3.0](https://github.com/allenai/tango/releases/tag/v1.3.0) - 2023-10-13\n\n### Added\n - Added the `Workspace.remove_step()` method to safely remove steps.\n- The `GSWorkspace()` can now be initialized with google cloud bucket subfolders.\n\n### Changed\n\n- The `BeakerExecutor` now uses the HEAD commit at the time the executor is instantiated to executor a step instead of the HEAD commit at the time the step is run.\n\n### Fixed\n\n- Removed unnecessary code coverage dev requirements.\n- Fixed issue where new version of torch caused no LR schedulers to be registered.\n- Updated pinned versions of jax, jaxlib, and flax.\n\n## [v1.2.1](https://github.com/allenai/tango/releases/tag/v1.2.1) - 2023-04-06\n\n### Added\n\n- Added the following workspace methods to support the Tango viz UI: `Workspace.search_registered_runs()`, `Workspace.search_step_info()`, `Workspace.num_registered_runs()`, and `Workspace.num_steps()`.\n\n### Fixed\n\n- Fixes a bug where `FromParams` would fail to parse when an object takes a `Step` argument directly.\n- Changed a name so we don't override the built-in name `set`.\n- Fixed a bug that would cause O(n^2) 
memory consumption in dense step graphs.\n\n\n## [v1.2.0](https://github.com/allenai/tango/releases/tag/v1.2.0) - 2023-02-10\n\n### Added\n\n- You can now add arguments to steps without invalidating the cache. See `Step.SKIP_DEFAULT_ARGUMENTS`.\n- Fixed integration status messages in `tango info` command.\n- Added abstractions for `RemoteClient`, `RemoteStepCache`, and `RemoteWorkspace`.\n- Added a GS integration that comes with `GSWorkspace`, a remote `Workspace` implementation that uses Google Cloud Storage.\n- You can now bind functional steps to the underlying `Step` instance with `@step(bind=True)`, meaning the first argument to the function will be a `Step`.\n- Added `ShellStep` for running arbitrary shell commands.\n- Added `@make_registrable` decorator to make arbitrary functions registrable, to make it easier to refer to them in tango configurations.\n\n### Fixed\n\n- Jsonnet parsing is now much faster and works on Windows.\n- Warnings about locks are now reliably printed every 30 seconds.\n- We now make sure Beaker jobs have the latest version of beaker-py, so that we're compatible with the latest API changes.\n- Stopping early now works when the metric doesn't change at all.\n- Fixed bug with `FromParams` which didn't handle variable length tuples correctly.\n\n### Changed\n\n- The default log level for Tango is now `warning`.\n- You can specify multiple steps with `-s` from the `tango run` command.\n\n\n## [v1.1.0](https://github.com/allenai/tango/releases/tag/v1.1.0) - 2022-12-01\n\n### Added\n\n- Added `gpu_type` field to `StepResources`. The `BeakerExecutor` can use this to determine which clusters to submit a step to.\n- Added `machine` field to `StepResources`. 
You can set this to \"local\" when using the `BeakerExecutor` to force it to run the step locally.\n- Added `--ext-var` argument to `tango run` for setting JSONNET external variables\n  when loading the experiment config.\n- Added `@step()` decorator to create `Step` classes from functions.\n- Added the `transformers::with_soft_prompt` integration, to make soft-prompted prefix transformers easy.\n\n### Removed\n\n- Removed PyTorch Lightning integration.\n- Removed `tango server` command and `--serve/--no-serve` option for `tango run`.\n- Removed `source_release.py`, which was checked in by accident.\n\n### Fixed\n\n- Fixed issue where Executor `parallelism` option in a Tango settings file would be ignored.\n- Fixed a bug where the unique ID of a step that depends on a key-value of the result of another step could change if the name of the other step changes.\n- Fixed a bug where importing certain libraries (like torchmetrics) would mess with our exception handling because they set `sys.excepthook` for some reason. 
Now we always reset `sys.excepthook` after importing.\n- The type hints for the flax trainer suggested that the training split is optional when in fact it's mandatory.\n- Made `BeakerWorkspace` / `BeakerStepLock` more robust when a job is preempted.\n- Minor performance improvements for the Beaker executor and workspace.\n- Fixed bug with `step_extra_dependencies` where uncacheable dependencies wouldn't be run.\n\n\n## [v1.0.2](https://github.com/allenai/tango/releases/tag/v1.0.2) - 2022-11-14\n\n### Changed\n\n- `BeakerScheduler` can now return a list of clusters.\n\n## [v1.0.1](https://github.com/allenai/tango/releases/tag/v1.0.1) - 2022-10-20\n\n### Fixed\n\n- `LightningTrainStep` now can take a `Lazy` model object which results in a guaranteed deterministic hash.\n- Fixed issue where remote `Workspace` implementations like `WandbWorkspace` and `BeakerWorkspace` would use the same local cache regardless of the W&B / Beaker workspace\n  being used.\n- Fixed bug with `TorchEvalStep` when constructing callbacks.\n- Fixed some import error issues caused when an integration is not installed.\n- Fix incorrect reporting of final results in `MulticoreExecutor`.\n\n### Changed\n\n- Wandb step cache retries API calls in case of a timeout.\n- `beaker-py >= 1.11` required.\n\n## [v1.0.0](https://github.com/allenai/tango/releases/tag/v1.0.0) - 2022-10-05\n\n### Added\n\n- Added `step_extra_dependencies` input field to `Step` class that can be used to force a dependency on another step even if the current step doesn't directly depend on the output of the other step. 
See [#418](https://github.com/allenai/tango/issues/418) for more context.\n\n### Changed\n\n- `beaker-py >= 1.10` required.\n\n### Fixed\n\n- Long log lines will be soft-wrapped to ensure that links are clickable.\n- Fixed a bug where some workspaces could be left in a bad state if a step's `Format` failed to serialize the step's result in `Workspace.step_finished()`.\n- Sometimes functions and methods end up as arguments to steps, which means we have to hash them. Instead of taking\n  a hash of the function, we now take a hash of the function's module and name.\n- Fixed a bug with the Beaker executor where it would hang at the end of a run if a step failed that is a dependency of another step.\n- Fixed tests to work with new version of transformers.\n- Fixed `Executor.execute_sub_graph_for_step()` to be able to run the step's dependencies in parallel.\n\n\n## [v0.14.0](https://github.com/allenai/tango/releases/tag/v0.14.0) - 2022-09-20\n\n### Added\n\n- Added a function to modify a Hugging Face transformer with IA3 adaptors.\n- Added a `BeakerScheduler` registrable class, specified as the argument `scheduler` to `BeakerExecutor`, which controls the resources assigned to steps run on Beaker.\n  Users can implement their own `BeakerScheduler` subclasses to customize the resource assignment behavior.\n\n### Changed\n\n- In the `tango run` command, `--no-server` is now the default. 
Use `--server` to start the server.\n\n### Fixed\n\n- Made `BeakerExecutor` more robust to connection, timeout, SSL, and other recoverable HTTP errors.\n- Made the `BeakerStepLock` more robust, and as a result `BeakerWorkspace` is more\n  robust and should require less manual intervention for locks in a bad state.\n- Fixed a bug with the internal scheduling logic of the `BeakerExecutor` which\n  could delay submitting some steps in parallel.\n- Fixed a bug where creating a `StepInfo` object from params might result in unnecessary imports.\n- Fixed a bug where canceling the Beaker executor might not work properly.\n- Fixed a bug where the trainer trains too much when `train_epochs` is set and you're using gradient accumulation.\n- Fixed a bug where included modules might not be found when using multiprocessing when they're not on `sys.path` / `PYTHONPATH`.\n- Fixed how the results of uncacheable steps are displayed by `tango run`.\n- Beaker executor won't run duplicate cacheable steps at the same time.\n\n## [v0.13.0](https://github.com/allenai/tango/releases/tag/v0.13.0) - 2022-09-07\n\n### Added\n\n- You can now reference into a particular index of the result of another step in a config. For example: `{type: \"ref\", ref: \"some_previous_step\", key: 0}`.\n  The key field can be an integer if the result of the referenced step is a list or tuple, or a string if the result of the referenced step is a dictionary.\n- Added `priority` parameter to Beaker executor for setting the default task priority for Beaker jobs.\n- Added `Workspace.step_result()` method for getting a step's result from the latest\n  run.\n- `tango run` will now display a URL to the logs for failed steps when you use the `BeakerExecutor`.\n\n### Changed\n\n- The `TorchTrainStep` now enables monitoring arbitrary model outputs during training. 
`TorchTrainEngine.forward_train` now returns a tuple `loss, model_outputs` for each micro batch and the list of model outputs for all micro batches in a batch is passed to the `TrainCallback.log_batch` and `TrainCallback.post_batch`.\n- Tango will now automatically search Python modules in the current working directory\n  for registered classes so that you don't always need to use the `--include-package` setting.\n- The minimum supported Python version is now 3.8.\n- Added support for PyTorch Lightning 1.7.x\n- The Beaker Executor will no-longer live-stream logs from Beaker jobs, but logs will be viewable on Beaker and more readable.\n- Only the Beaker executor requires a clean working directory\n\n### Fixed\n\n- Fixed a bug that did not allow a wandb artifact's type to be set from a step's metadata dictionary. \n- Fixed a bug with how the Beaker executor streams log lines from Beaker which sometimes resulted in messages missing some starting characters, and tqdm lines being duplicated.\n- Fixed a bug in the Beaker workspace where the lock dataset wouldn't be removed if the step\n  was found to be in an invalid state.\n- Improved cluster choice logic in `BeakerExecutor` to ensure greater diversity of clusters when submitting many steps at once.\n- Fixed bug where sub-processes of the multicore executor would use the wrong executor if `executor` was defined in a `tango.yml` file.\n- Deterministic hashes for numpy and torch tensors were not deterministic. Now they are.\n\n\n## [v0.12.0](https://github.com/allenai/tango/releases/tag/v0.12.0) - 2022-08-23\n\n### Added\n\n- **Step resources:**\n  - Added a `step_resources` parameter to the `Step` class which should be used to describe the computational resources required to run a step.\n    `Executor` implementations can use this information. 
For example, if your step needs 2 GPUs, you should set\n    `step_resources=StepResources(gpu_count=2)` (`\"step_resources\": {\"gpu_count\": 2}` in the configuration language).\n  - Added a `Step.resources()` property method. By default this returns the value specified by the `step_resources` parameter.\n    If your step implementation always requires the same resources, you can just override this method so you don't have to provide\n    the `step_resources` parameter.\n- **Step execution:**\n  - Added an `executor` field to the `tango.yml` settings. You can use this to define the executor you want to use by default.\n  - Added a Beaker `Executor` to the Beaker integration, registered as an `Executor` with the name \"beaker\".\n    To use this executor, add these lines to your `tango.yml` file:\n    ```yaml\n    executor:\n      type: beaker\n      beaker_workspace: ai2/my-workspace\n      clusters:\n        - ai2/general-cirrascale\n    ```\n    See the docs for the `BeakerExecutor` for more information on the input parameters.\n- **Step class:**\n  - Added a metadata field to the step class API. 
This can be set through the class\n    variable `METADATA` or through the constructor argument `step_metadata`.\n- **Weights & Biases integration:**\n  - You can now change the artifact kind for step result artifacts by adding a field\n    called \"artifact_kind\" to a step's metadata.\n    For models, setting \"artifact_kind\" to \"model\" will add the corresponding artifact to W&B's new model zoo.\n\n### Changed\n\n- **CLI:**\n  - The `tango run` command will throw an error if you have uncommitted changes in your repository, unless\n    you use the `--allow-dirty` flag.\n  - The `tango run` command will use the lightweight base executor (single process) by default.\n    To use the multi-process executor, set `-j/--parallelism` to 1 or higher or -1 to use all available CPU cores.\n\n### Fixed\n\n- Fixed bug where `StepInfo` environment and platform metadata could be out-of-date if a step is run again due to failure.\n- Fixed a bug where an unfortunate combination of early stopping and decreasing model performance could result in a crash in the torch trainer.\n\n## [v0.11.0](https://github.com/allenai/tango/releases/tag/v0.11.0) - 2022-08-04\n\n### Added\n\n- Added a [Flax](https://flax.readthedocs.io/en/latest/) integration along with an example config.\n\n## [v0.10.1](https://github.com/allenai/tango/releases/tag/v0.10.1) - 2022-07-26\n\n### Fixed\n\n- Fixed issue where the StepInfo config argument could be parsed into a Step. \n- Restored capability to run tests out-of-tree.\n\n## [v0.10.0](https://github.com/allenai/tango/releases/tag/v0.10.0) - 2022-07-07\n\n### Changed\n\n- Renamed `workspace` parameter of `BeakerWorkspace` class to `beaker_workspace`.\n- `Executor` class is now a `Registrable` base class. `MulticoreExecutor` is registered as \"multicore\".\n\n### Removed\n\n- Removed `StepExecutionMetadata`. 
Its fields have been absorbed into `StepInfo`.\n\n### Fixed\n\n- Improved `Step.ensure_result()` such that the step's result doesn't have to be read from the cache.\n- Fixed an issue with the output from `MulticoreExecutor` such that it's now consistent with the default `Executor` for steps that were found in the cache.\n- One of our error messages referred to a configuration file that no longer exists.\n- Improved performance of `BeakerWorkspace`.\n\n### Added\n\n- Added the ability to train straight `Model` instead of just `Lazy[Model]`\n\n\n## [v0.9.1](https://github.com/allenai/tango/releases/tag/v0.9.1) - 2022-06-24\n\n### Fixed\n\n- Fixed non-deterministic behavior in `TorchTrainStep`.\n- Fixed bug in `BeakerWorkspace` where `.step_info(step)` would raise a `KeyError` if the step hasn't been registered as part of a run yet.\n- Fixed a bug in `BeakerWorkspace` where it would send too many requests to the beaker service.\n- Fixed a bug where `WandbWorkspace.step_finished()` or `.step_failed()` would crash if called\n  from a different process than `.step_starting()`.\n- Fixed a bug in `WandbWorkspace.step_finished()` which led to a `RuntimeError` sometimes while\n  caching the result of a step.\n\n\n## [v0.9.0](https://github.com/allenai/tango/releases/tag/v0.9.0) - 2022-06-01\n\n### Added\n\n- Added a [Beaker](https://beaker.org) integration that comes with `BeakerWorkspace`, a remote `Workspace` implementation that uses Beaker Datasets under the hood.\n- Added a `datasets::dataset_remix` step that provides the split remixing functionality of `tango.steps.datasest_remix.DatasetRemixStep` now for Huggingface `DatasetDict`.\n- Added a config and code example of Registrable to the First Step docs with edits for clarity.\n\n### Changed\n\n- If you try to import something from a tango integration that is not fully installed due to missing dependencies, an `IntegrationMissingError` will be raised\ninstead of `ModuleNotFound`.\n- You can now set `-j 0` in `tango run` 
to disable multicore execution altogether.\n\n### Fixed\n\n- Improved how steps and workspaces handle race conditions when different processes are competing to execute the same step. This would result in a `RuntimeError` before with most workspaces, but now it's handled gracefully.\n- Fixed bug which caused GradScaler state to not be saved and loaded with checkpoints.\n\n## [v0.8.0](https://github.com/allenai/tango/releases/tag/v0.8.0) - 2022-05-19\n\n### Added\n\n- Added a Weights & Biases remote `Workspace` implementation: `WandbWorkspace`, registered as \"wandb\".\n  This can be instantiated from a workspace URL in the form \"wandb://entity/project\".\n- Added a method `Workspace.step_result_for_run` which gives the result of a step given the run name and step name within that run.\n- Added property `Workspace.url`, which returns a URL for the workspace that can be used to instantiate the exact same workspace using `Workspace.from_url()`. Subclasses must implement this.\n\n### Changed\n\n- `StepInfo` start and end times will always be in UTC now.\n- `WandbTrainCallback` now logs system metrics from each worker process in distributed training.\n- `StepCache.__contains__()` and `StepCache.__getitem__()` now accept either a `Step` or `StepInfo` as an argument (`Union[Step, StepInfo]`).\n- Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`.\n- `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures.\n\n### Fixed\n\n- Fixed bug with `LocalWorkspace.from_parsed_url()` ([#278](https://github.com/allenai/tango/issues/278)).\n- Deprecation warnings will now be logged from `tango` CLI.\n- Fixed the text format in the case of serializing an iterator of strings.\n- Added missing default value of `None` to `TangoGlobalSettings.find_or_default()`.\n- Mypy has become incompatible with transformers and datasets, so we have to disable the checks in some places.\n- The `VERSION` member of step 
arguments that were wrapped in `Lazy` were not respected. Now they are.\n\n\n## [v0.7.0](https://github.com/allenai/tango/releases/tag/v0.7.0) - 2022-04-19\n\n### Added\n\n- Added the \"-n/--name\" option to `tango run`. This option allows the user to give the run an arbitrary name.\n- Added a convenience property `.workspace` to `Step` class that can be called from a step's `.run()` method to get the current `Workspace` being used.\n- Gave `FromParams` objects (which includes all `Registrable` objects) the ability to version themselves.\n- Added CLI option to run a single step in a config using `--step-name` or `-s`.\n- Added a `MultiCoreExecutor` that executes steps in parallel.\n- Added an `ExecutorOutput` dataclass that is returned by `Executor.execute_step_graph()`.\n- `StepGraph` now prints itself in a readable way.\n- Tango now automatically detects when it's running under a debugger, and disables multicore support accordingly. Many debuggers can't properly follow sub-processes, so this is a convenience for people who love debuggers.\n- Added more models to the stuff we can import from the transformers library.\n- Added new example for finetuning text-to-text models.\n\n### Changed\n\n- Renamed `click_logger` to `cli_logger`, and we now use [rich](https://github.com/Textualize/rich)'s logging `Handler` as the default handler, which means prettier output, better tracebacks, and you can use rich's markup syntax with the `cli_logger` to easily add style to text.\n- Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`.\n- `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures.\n- Upgraded PyTorch version in `tango` Docker image to latest `v1.11.0+cu113`.\n- `RunGeneration` now allows model object as input.\n\n### Fixed\n\n- Fixed bug that mistakenly disallowed fully-qualified names containing `\"_\"` (underscores) in the config.\n- Fixed bug where `TorchTrainStep` working directory 
would be left in an unrecoverable state if training failed after saving the final model weights.\n- Fixed bug in `FromParams` where `**kwargs` might be passed down to the constructors of arguments.\n- Fixed bug in the way dependencies are tracked between steps.\n- Fixed bug that caused `MulticoreExecutor` to hang in case of a failing step that was required recursively (not directly) downstream.\n- Compatibility with PyTorch Lightning 1.6\n\n\n## [v0.6.0](https://github.com/allenai/tango/releases/tag/v0.6.0) - 2022-02-25\n\n### Added\n\n- New example that finetunes a pre-trained ResNet model on the Cats & Dogs dataset.\n- Added a '@requires_gpus' decorator for marking tests as needing GPUs. Tests marked with this will be run in the \"GPU Tests\" workflow\n  on dual k80 GPUs via Beaker.\n- Added the \"-w/--workspace\" option to `tango run` and `tango server` commands. This option takes a path or URL, and instantiates the workspace from the URL using the newly added `Workspace.from_url()` method.\n- Added the \"workspace\" field to `TangoGlobalSettings`.\n- Added the \"environment\" field to `TangoGlobalSettings` for setting environment variables each\n  time `tango` is run.\n- Added a utility function to get a `StepGraph` directly from a file.\n- Added `tango.settings` module and `tango settings` group of commands.\n- A format for storing sequences as `SqliteSparseSequence`\n- A way to massage kwargs before they determine the unique ID of a `Step`\n\n### Changed\n\n- `local_workspace.ExecutorMetadata` renamed to `StepExecutionMetadata` and now saved as `execution-metadata.json`.\n- `tango run` without the option \"-w/--workspace\" or \"-d/--workspace-dir\" will now use a `MemoryWorkspace` instead of a `LocalWorkspace` in a temp directory, unless you've specified\n  a default workspace in a `TangoGlobalSettings` file.\n- Moved `tango.workspace.MemoryWorkspace` and `tango.local_workspace.LocalWorkspace` to 
`tango.workspaces.*`.\n- Moved `tango.step_cache.MemoryStepCache` and `tango.step_cache.LocalStepCache` to `tango.step_caches.*`.\n- Deprecated the `-d/--workspace-dir` command-line option. Please use `-w/--workspace` instead.\n\n### Fixed\n\n- Fixed a small bug where `LocalWorkspace` would fail to capture the conda environment in our Docker image.\n- Fixed activation of `FILE_FRIENDLY_LOGGING` when set from the corresponding environment variable.\n- Fixed setting log level via the environment variable `TANGO_LOG_LEVEL`.\n- Use relative paths within the `work_dir` for symbolic links to the latest and the best checkpoints in `TorchTrainStep`.\n- Fixed some scenarios where Tango can hang after finishing all steps.\n- `distributed_port` and `log_every` parameters won't factor into `TorchTrainStep`'s unique ID.\n- `MappedSequence` now works with slicing.\n- `MappedSequence` now works with Huggingface `Dataset`.\n- Uncacheable steps are now visible in Tango UI.\n- Fixed bug in `Registrable.list_available()` where an error might be raised if the default implementation hadn't been explicitly imported.\n- Fixed issue where having a default argument to the `run()` method wasn't getting applied to the step's unique ID.\n\n\n## [v0.5.0](https://github.com/allenai/tango/releases/tag/v0.5.0) - 2022-02-09\n\n### Added\n\n- Added `TrainingEngine` abstraction to torch integration.\n- Added a [FairScale](https://fairscale.readthedocs.io/en/latest/) integration with a `FairScaleTrainingEngine`\n  that leverages FairScale's `FullyShardedDataParallel`. This is meant to be used within the `TorchTrainStep`.\n- All PyTorch components (such as learning rate schedulers, optimizers, data collators, etc) from the\n  transformers library are now registered under the corresponding class in the torch integration.\n  For example, transformers `Adafactor` optimizer is registered as an `Optimizer` under the name\n  \"transformers::Adafactor\". 
More details can be found in the documentation for the transformers integration.\n\n### Changed\n\n- Various changes to the parameters of the `TorchTrainStep` due to the introduction of the `TrainingEngine` class.\n- Params logged as `DEBUG` level instead of `INFO` to reduce noise in logs.\n- The waiting message for `FileLock` is now clear about which file it's waiting for.\n- Added an easier way to get the default Tango global config\n- Most methods to `TorchTrainCallback` also take an `epoch` parameter now.\n- `WandbTrainCallback` now logs peak GPU memory occupied by PyTorch tensors per worker. This is useful because W&B's system metrics only display the total GPU memory reserved by PyTorch, which is always higher than the actual amount of GPU memory occupied by tensors. So these new metrics give a more accurate view into how much memory your training job is actually using.\n- Plain old Python functions can now be used in `Lazy` objects.\n- `LocalWorkspace` now creates a symlink to the outputs of the latest run.\n- Tango is now better at guessing when a step has died and should be re-run.\n- Tango is now more lenient about registering the same class under the same name twice.\n- When you use `dict` instead of `Dict` in your type annotations, you now get a legible error message. 
Same for `List`, `Tuple`, and `Set`.\n\n### Fixed\n\n- Fixed a bug in `Registrable` and `FromParams` where registered function constructors would not properly construct\n  arguments that were classes.\n- Fixed a bug in `FromParams` that would cause a crash when an argument to the constructor had the name `params`.\n- Made `FromParams` more efficient by only trying to parse the params as a `Step` when it looks like it actually could be a step.\n- Fixed bug where `Executor` would crash if `git` command could not be found.\n- Fixed bug where validation settings were not interpreted the right way by the torch trainer.\n- When you register the same name twice using `Registrable`, you get an error message. That error message now contains the correct class name.\n\n\n## [v0.4.0](https://github.com/allenai/tango/releases/tag/v0.4.0) - 2022-01-27\n\n### Changed\n\n- Default log level is `WARNING` instead of `ERROR`.\n- The web UI now renders the step graph left-to-right.\n- The web UI now shows runs by date, with the most recent run at the top.\n- The web UI now shows steps in a color-coded way.\n- The `tango run` command now prints user-friendly paths if possible.\n- The `--include-package` flag now also accepts paths instead of module names.\n- `tango.common.sqlite_sparse_sequence.SqliteSparseSequence` now lives at `tango.common.sequences.SqliteSparseSequence`.\n\n### Fixed\n\n- Ensure tqdm log lines always make it into the log file `out.log` even when log level is `WARNING` or `ERROR`.\n- Numerous parts of Tango now have documentation when they didn't before.\n\n\n## [v0.4.0rc5](https://github.com/allenai/tango/releases/tag/v0.4.0rc5) - 2022-01-19\n\n### Added\n\n- Added `TorchEvalStep` to torch integration, registered as \"torch::eval\".\n\n### Changed\n\n- Renamed `aggregate_val_metric` to `auto_aggregate_val_metric` in `TorchTrainStep`.\n- `devices` parameter to `TorchTrainStep` replaced with `device_count: int`.\n- Run name printed at the end of a run so it's easier 
to find.\n- Type information added to package data. See [PEP 561](https://www.python.org/dev/peps/pep-0561) for more information.\n- A new integration, `transformers`, with two new steps for running seq2seq models.\n- Added `logging_tqdm`, if you don't want a progress bar, but you still want to see progress in the logs.\n- Added `threaded_generator()`, for wrapping generators so that they run in a separate thread from the generator's consumer.\n- Added a new example for evaluating the T0 model on XSum, a summarization task.\n- Added `MappedSequence` for functionally wrapping sequences.\n- Added `TextFormat`, in case you want to store the output of your steps in raw text instead of JSON.\n- Steps can now list arguments in `SKIP_ID_ARGUMENTS` to indicate that the argument should not affect a step's\n  unique id. This is useful for arguments that affect the execution of a step, but not the output.\n- `Step` now implements `__str__`, so steps look pretty in the debugger.\n- Added `DatasetCombineStep`, a step that combines multiple datasets into one.\n- Added `common.logging.initialize_worker_logging()` function for configuring logging from worker processes/threads.\n- Logs from `tango run ...` will be written to a file called `out.log` in the run directory.\n\n### Fixed\n\n- Fixed torch `StopEarlyCallback` state not being recovered properly on restarts.\n- Fixed file friendly logging by removing special styling characters.\n- Ensured exceptions are captured in logs.\n- `LocalWorkspace` now works properly with uncacheable steps.\n- When a Tango run got killed hard, with `kill -9`, or because the machine lost power, `LocalWorkspace` would\n  sometimes keep a step marked as \"running\", preventing further executions. This still happens sometimes, but it\n  is now much less likely (and Tango gives you instructions for how to fix it).\n- To make all this happen, `LocalWorkspace` now saves step info in a Sqlite database. 
Unfortunately that means that\n  the workspace format changes and existing workspace directories won't work properly with it.\n- Fixed premature cleanup of temporary directories when using `MemoryWorkspace`\n\n\n## [v0.4.0rc4](https://github.com/allenai/tango/releases/tag/v0.4.0rc4) - 2021-12-20\n\n### Fixed\n\n- Fixed a bug where `StepInfo` fails to deserialize when `error` is an exception that can't be pickled.\n\n\n## [v0.4.0rc3](https://github.com/allenai/tango/releases/tag/v0.4.0rc3) - 2021-12-15\n\n### Added\n\n- Added `DatasetsFormat` format and `LoadStreamingDataset` step to `datasets` integration.\n- `SqliteDictFormat` for datasets.\n- Added `pre_epoch()` and `post_epoch()` callback methods to PyTorch `TrainCallback`.\n\n### Changed\n\n- `LoadDataset` step from `datasets` integration is now cacheable, using the `DatasetsFormat` format by default.\n  But this only works with non-streaming datasets. For streaming datasets, you should use the `LoadStreamingDataset` step instead.\n\n### Fixed\n\n- Fixed bug where `KeyboardInterrupt` exceptions were not handled properly by steps and workspaces.\n- `WandbTrainCallback` now will use part of the step's unique ID as the name for the W&B run by default, to make\n  it easier to identify which tango step corresponds to each run in W&B.\n- `WandbTrainCallback` will save the entire `TrainConfig` object to the W&B config.\n\n\n## [v0.4.0rc2](https://github.com/allenai/tango/releases/tag/v0.4.0rc2) - 2021-12-13\n\n### Added\n\n- Sample experiment configurations that prove Euler's identity\n\n### Changed\n\n- Loosened `Click` dependency to include v7.0.\n- Loosened `datasets` dependency.\n- Tightened `petname` dependency to exclude next major release for safety.\n\n### Fixed\n\n- `Workspace`, `MemoryWorkspace`, and `LocalWorkspace` can now be imported directly from the `tango`\n  base module.\n- Uncacheable leaf steps would never get executed. 
This is now fixed.\n- We were accidentally treating failed steps as if they were completed.\n- The visualization had a problem with showing steps that never executed because a dependency failed.\n- Fixed a bug where `Lazy` inputs to a `Step` would fail to resolve arguments that come from the result\n  of another step.\n- Fixed a bug in `TorchTrainStep` where some arguments for distributed training (`devices`, `distributed_port`) weren't being set properly.\n\n\n## [v0.4.0rc1](https://github.com/allenai/tango/releases/tag/v0.4.0rc1) - 2021-11-30\n\n### Added\n\n- Introduced the concept of the `Workspace`, with `LocalWorkspace` and `MemoryWorkspace` as initial implementations.\n- Added a stub of a webserver that will be able to visualize runs as they happen.\n- Added separate classes for `LightningTrainingTypePlugin`, `LightningPrecisionPlugin`, `LightningClusterEnvironmentPlugin`, `LightningCheckpointPlugin` for compatibility with `pytorch-lightning>=1.5.0`.\n- Added a visualization of workspaces that can show step graphs while they're executing.\n\n### Removed\n\n- Removed old `LightningPlugin` class\n- Removed requirement of the `overrides` package\n\n### Changed\n\n- Made it possible to construct a step graph out of `Step` objects, instead of constructing it out of `StepStub` objects.\n- Removed dataset fingerprinting code, since we can now use `Step` to make sure things are cached.\n- Made steps deterministic by default.\n- Brought back `MemoryStepCache`, so we can run steps without configuring anything.\n- W&B `torch::TrainCallback` logs with `step=step+1` now so that training curves in the W&B dashboard\n  match up with checkpoints saved locally and are easier to read (e.g. 
step 10000 instead of 9999).\n- `filelock >= 3.4` required, parameter `poll_intervall`  to `tango.common.file_lock.FileLock.acquire` renamed\n  to `poll_interval`.\n\n### Fixed\n\n- Fixed bug in `FromParams` where a parameter to a `FromParams` class may not be instantiated correctly\n  if it's a class with a generic type parameter.\n\n## [v0.3.6](https://github.com/allenai/tango/releases/tag/v0.3.6) - 2021-11-12\n\n### Added\n\n- Added a `.log_batch()` method on `torch::TrainCallback` which is given the average loss across\n  distributed workers, but only called every `log_every` steps.\n\n### Removed\n\n- Removed `.pre_log_batch()` method on `torch::TrainCallback`.\n\n### Fixed\n\n- Fixed typo in parameter name `remove_stale_checkpoints` in `TorchTrainStep` (previously was `remove_state_checkpoints`).\n- Fixed bug in `FromParams` that would cause failures when `from __future__ import annotations`\n  was used with Python older than 3.10. See [PEP 563](https://www.python.org/dev/peps/pep-0563/)\n  for details.\n\n## [v0.3.5](https://github.com/allenai/tango/releases/tag/v0.3.5) - 2021-11-05\n\n### Fixed\n\n- Fixed a bug in `FromParams` where the \"type\" parameter was ignored in some cases\n  where the `Registrable` base class did not directly inherit from `Registrable`.\n\n## [v0.3.4](https://github.com/allenai/tango/releases/tag/v0.3.4) - 2021-11-04\n\n### Added\n\n- Added `StopEarlyCallback`, a `torch::TrainCallback` for early stopping.\n- Added parameter `remove_stale_checkpoints` to `TorchTrainStep`.\n\n### Changed\n\n- Minor changes to `torch::TrainCallback` interface.\n- Weights & Biases `torch::TrainCallback` now logs best validation metric score.\n\n## [v0.3.3](https://github.com/allenai/tango/releases/tag/v0.3.3) - 2021-11-04\n\n### Added\n\n- Added support for PEP 604 in `FromParams`, i.e. 
writing union types as \"X | Y\" instead of \"Union[X, Y]\".\n- [internals] Added a spot for miscellaneous end-to-end integration tests (not to be confused with \"tests of integrations\") in `tests/end_to_end/`.\n- [internals] Core tests now run on all officially supported Python versions.\n\n### Fixed\n\n- Fixed a bug in `FromParams` where non-`FromParams` class parameters were not instantiated\n  properly (or at all).\n- Fixed a bug in `FromParams` where kwargs were not passed on from a wrapper class to the wrapped class.\n- Fixed small bug where some errors from git would be printed when executor metadata is created\n  outside of a git repository.\n\n## [v0.3.2](https://github.com/allenai/tango/releases/tag/v0.3.2) - 2021-11-01\n\n### Fixed\n\n- Fixed a bug with `FromParams` that caused `.from_params()` to fail when the params contained\n  an object that was already instantiated.\n- tango command no longer installs a SIGTERM handler, which fixes some bugs with integrations that use multiprocessing.\n\n## [v0.3.1](https://github.com/allenai/tango/releases/tag/v0.3.1) - 2021-10-29\n\n### Changed\n- Updated the `LightningTrainStep` to optionally take in a `LightningDataModule` as input.\n\n## [v0.3.0](https://github.com/allenai/tango/releases/tag/v0.3.0) - 2021-10-28\n\n### Added\n\n- Added `IterableDatasetDict`, a version of `DatasetDict` for streaming-like datasets.\n- Added a [PyTorch Lightning](https://www.pytorchlightning.ai) integration with `LightningTrainStep`.\n\n### Fixed\n\n- Fixed bug with `FromParams` and `Lazy` where extra arguments would sometimes be passed down through\n  to a `Lazy` class when they shouldn't.\n\n## [v0.2.4](https://github.com/allenai/tango/releases/tag/v0.2.4) - 2021-10-22\n\n### Added\n\n- Added support for [torch 1.10.0](https://github.com/pytorch/pytorch/releases).\n\n### Changed\n\n- `--file-friendly-logging` flag is now an option to the main `tango` command, so needs\n  to be passed before `run`, e.g. 
`tango --file-friendly-logging run ...`.\n\n### Fixed\n\n- Fixed bug with `Step.from_params`.\n- Ensure logging is initialized in spawn processes during distributed training with `TorchTrainStep`.\n\n## [v0.2.3](https://github.com/allenai/tango/releases/tag/v0.2.3) - 2021-10-21\n\n### Added\n\n- Added support for global settings file, `tango.yml`.\n- Added 'include_package' (array of string) param to config spec.\n- Added a custom error `StopEarly` that a `TrainCallback` can raise within the `TorchTrainStep`\n  to stop training early without crashing.\n- Added step config, tango command, and tango version to executor metadata.\n- Executor now also saves pip dependencies and conda environment files to the run directory\n  for each step.\n\n### Fixed\n\n- Ensured `**kwargs` arguments are logged in `FromParams`.\n\n## [v0.2.2](https://github.com/allenai/tango/releases/tag/v0.2.2) - 2021-10-19\n\n### Added\n\n- Added new steps to `datasets` integration: `ConcatenateDatasets` (\"datasets::concatenate\") and `InterleaveDatasets` (datasets::interleave).\n- Added `__contains__` and `__iter__` methods on `DatasetDict` so that it is now a `Mapping` class.\n- Added `tango info` command that - among other things - displays which integrations are installed.\n\n## [v0.2.1](https://github.com/allenai/tango/releases/tag/v0.2.1) - 2021-10-18\n\n### Added\n\n- Added `convert_to_tango_dataset_dict()` function in the `datasets` integration.\n  It's important for step caching purposes to use this to convert a HF `DatasetDict`\n  to a native Tango `DatasetDict` when that `DatasetDict` is part of the input to another\n  step. Otherwise the HF `DatasetDict` will have to be pickled to determine its hash.\n\n### Changed\n\n- `Format.checksum()` is now an abstract method. 
Subclasses should only compute checksum\n  on the serialized artifact and nothing else in the directory.\n- [internals] Changed the relationship between `Executor`, `StepCache`, and `Step`.\n  `Executor` now owns the `StepCache`, and `Step` never interacts with `StepCache` directly.\n\n## [v0.2.0](https://github.com/allenai/tango/releases/tag/v0.2.0) - 2021-10-15\n\n### Added\n\n- Added a [Weights & Biases](https://wandb.ai) integration with a training callback (\"wandb::log\")\n  for `TorchTrainStep` (\"torch::train\") that logs training and validation metrics to W&B.\n\n### Fixed\n\n- Fixed `Format.checksum()` when there is a symlink to a directory in the cache folder.\n\n## [v0.1.3](https://github.com/allenai/tango/releases/tag/v0.1.3) - 2021-10-15\n\n### Added\n\n- Added the ability to track a metric other than \"loss\" for validation in `TorchTrainStep` (\"torch::train\").\n\n### Fixed\n\n- Final model returned from `TorchTrainStep` (\"torch::train\") will have best weights loaded.\n- Checkpoints are saved from `TorchTrainStep` (\"torch::train\") even when there is no validation loop.\n- Fixed `TorchTrainStep` (\"torch::train\") when `validation_split` is `None`.\n- Fixed distributed training with `TorchTrainStep` (\"torch::train\") on GPU devices.\n\n## [v0.1.2](https://github.com/allenai/tango/releases/tag/v0.1.2) - 2021-10-13\n\n### Added\n\n- Added support for YAML configuration files.\n\n## [v0.1.1](https://github.com/allenai/tango/releases/tag/v0.1.1) - 2021-10-12\n\n### Added\n\n- `TorchTrainStep` now displays a progress bar while saving a checkpoint to file.\n- The default executor now saves an \"executor-metadata.json\" file to the directory for each step.\n\n### Changed\n\n- Renamed `DirectoryStepCache` to `LocalStepCache` (registered as \"local\").\n- `LocalStepCache` saves metadata to `cache-metadata.json` instead of `metadata.json`.\n\n### Fixed\n\n- Fixed bug with `TorchTrainStep` during distributed training.\n- `FromParams` will automatically 
convert strings into `Path` types now when the annotation\n  is `Path`.\n\n## [v0.1.0](https://github.com/allenai/tango/releases/tag/v0.1.0) - 2021-10-11\n\n### Added\n\n- Added `StepGraph` and `Executor` abstractions.\n- Added a basic PyTorch training step registered as `\"torch::train\"`, along with other registrable\n  components, such as `Model`, `DataLoader`, `Sampler`, `DataCollator`, `Optimizer`, and `LRScheduler`.\n- Added `DatasetRemixStep` in `tango.steps`.\n- Added module `tango.common.sequences`.\n- Added `DatasetDict` class in `tango.common.dataset_dict`.\n- Added [🤗 Datasets](https://github.com/huggingface/datasets) integration.\n- Added command-line options to set log level or disable logging completely.\n\n### Changed\n\n- `Step.work_dir`, `Step.unique_id`, `Step.dependencies`, and `Step.recursive_dependencies`\n  are now properties instead of methods.\n- `tango run` command will acquire a lock on the directory to avoid race conditions.\n- Integrations can now be installed with `pip install tango[INTEGRATION_NAME]`. 
For example,\n  `pip install tango[torch]`.\n- Added method `Registrable.search_modules()` for automatically finding and importing the modules\n  where a given ``name`` might be registered.\n- `FromParams.from_params()` and `Registrable.resolve_class_name` will now call `Registrable.search_modules()` to automatically import modules where the type might be defined.\n  Thus for classes that are defined and registered within any `tango.*` submodules it is not necessary to explicitly import them.\n\n### Fixed\n\n- `Step` implementations can now take arbitrary `**kwargs` in their `run()` methods.\n\n## [v0.0.3](https://github.com/allenai/tango/releases/tag/v0.0.3) - 2021-09-27\n\n### Added\n\n- Added `tango` command.\n\n## [v0.0.2](https://github.com/allenai/tango/releases/tag/v0.0.2) - 2021-09-27\n\n### Added\n\n- Ported over core tango components from AllenNLP.\n\n## [v0.0.1](https://github.com/allenai/tango/releases/tag/v0.0.1) - 2021-09-22\n\n### Added\n\n- Added initial project boilerplate.\n"
  },
  {
    "path": "CITATION.cff",
    "content": "cff-version: 1.2.0\nmessage: \"If you use this software, please cite it as below.\"\nauthors:\n- family-names: \"Groeneveld\"\n  given-names: \"Dirk\"\n  affiliation: \"Allen Institute for Artificial Intelligence\"\n- family-names: \"Bhagia\"\n  given-names: \"Akshita\"\n  affiliation: \"Allen Institute for Artificial Intelligence\"\n- family-names: \"Walsh\"\n  given-names: \"Pete\"\n  affiliation: \"Allen Institute for Artificial Intelligence\"\ntitle: \"AI2 Tango\"\nabstract: \"Organize your experiments into discrete steps that can be cached and reused throughout the lifetime of your research project.\"\nversion: \"1.3.2\"\nrepository-code: \"https://github.com/allenai/tango\"\nlicense: \"Apache-2.0\"\ndate-released: \"2023-10-27\"\nrepository-code: \"https://github.com/allenai/tango\"\n"
  },
  {
    "path": "Dockerfile",
    "content": "# This Dockerfile can be used to build a Docker image suitable for tango projects.\n\nARG BASE_IMAGE=ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10\nFROM ${BASE_IMAGE}\n\nWORKDIR /stage\n\nCOPY . .\nRUN /opt/conda/bin/pip install --no-cache-dir .[all]\n\nWORKDIR /workspace\n\nRUN rm -rf /stage/\n\nENTRYPOINT [\"/opt/conda/bin/tango\"]\n"
  },
  {
    "path": "Dockerfile.test",
    "content": "# This Dockerfile is for building an image suitable for running tango's GPU tests and integration tests.\n# There are no instruction lines in this Dockerfile that install tango. Instead, the entrypoint\n# script handles installing tango from a particular commit at runtime, based on the environment\n# variable \"COMMIT_SHA\". That way we don't need to rebuild and push the image each time we run\n# tests, and we can be sure the dependencies are always up-to-date.\n\nFROM ghcr.io/allenai/pytorch:2.0.0-cuda11.7-python3.10\n\nCOPY scripts/entrypoint.sh /entrypoint.sh\nRUN chmod +x /entrypoint.sh\n\nWORKDIR /testing\n\nENTRYPOINT [\"/entrypoint.sh\"]\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY : docs\ndocs :\n\trm -rf docs/build/\n\tsphinx-autobuild -b html --watch tango/ --watch examples/ docs/source/ docs/build/\n\n.PHONY : run-checks\nrun-checks :\n\tisort --check .\n\tblack --check .\n\truff check .\n\tmypy --check-untyped-defs .\n\tCUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules --ignore=tests/integrations --ignore=tango/integrations tests/ tango/\n\tCUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/torch tests/integrations/torch\n\tCUDA_VISIBLE_DEVICES='' pytest -v --color=yes --doctest-modules tango/integrations/transformers tests/integrations/transformers\n"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n<br>\n<img src=\"https://raw.githubusercontent.com/allenai/tango/main/docs/source/_static/tango_final_horizontal.png\" width=\"600\"/>\n<br>\n<br>\n<p>\n<!-- start tagline -->\nAI2 Tango replaces messy directories and spreadsheets full of file versions by organizing experiments into discrete steps that can be cached and reused throughout the lifetime of a research project.\n<!-- end tagline -->\n</p>\n<hr/>\n<a href=\"https://github.com/allenai/tango/actions\">\n    <img alt=\"CI\" src=\"https://github.com/allenai/tango/workflows/CI/badge.svg?event=push&branch=main\">\n</a>\n<a href=\"https://pypi.org/project/ai2-tango/\">\n    <img alt=\"PyPI\" src=\"https://img.shields.io/pypi/v/ai2-tango\">\n</a>\n<a href=\"https://ai2-tango.readthedocs.io/en/latest/?badge=latest\">\n    <img src=\"https://readthedocs.org/projects/ai2-tango/badge/?version=latest\" alt=\"Documentation Status\" />\n</a>\n<a href=\"https://github.com/allenai/tango/blob/main/LICENSE\">\n    <img alt=\"License\" src=\"https://img.shields.io/github/license/allenai/tango.svg?color=blue&cachedrop\">\n</a>\n<br/>\n</div>\n\n## Quick links\n\n- [Documentation](https://ai2-tango.readthedocs.io/)\n- [PyPI Package](https://pypi.org/project/ai2-tango/)\n- [Contributing](https://github.com/allenai/tango/blob/main/CONTRIBUTING.md)\n- [License](https://github.com/allenai/tango/blob/main/LICENSE)\n\n## In this README\n\n- [Quick start](#quick-start)\n- [Installation](#installation)\n  - [Installing with PIP](#installing-with-pip)\n  - [Installing with Conda](#installing-with-conda)\n  - [Installing from source](#installing-from-source)\n  - [Checking your installation](#checking-your-installation)\n  - [Docker image](#docker-image)\n- [FAQ](#faq)\n- [Team](#team)\n- [License](#license)\n\n## Quick start\n\nCreate a Tango step:\n\n```python\n# hello.py\n\nfrom tango import step\n\n@step()\ndef hello(name: str) -> str:\n    message = f\"Hello, {name}!\"\n    print(message)\n  
  return message\n```\n\nAnd create a corresponding experiment configuration file:\n\n```jsonnet\n// hello.jsonnet\n\n{\n  steps: {\n    hello: {\n      type: \"hello\",\n      name: \"World\",\n    }\n  }\n}\n```\n\nThen run the experiment using a local workspace to cache the result:\n\n```bash\ntango run hello.jsonnet -w /tmp/workspace\n```\n\nYou'll see something like this in the output:\n\n```\nStarting new run expert-llama\n● Starting step \"hello\"...\nHello, World!\n✓ Finished step \"hello\"\n✓ Finished run expert-llama\n```\n\nIf you run this a second time the output will now look like this:\n\n```\nStarting new run open-crab\n✓ Found output for step \"hello\" in cache...\n✓ Finished run open-crab\n```\n\nYou won't see \"Hello, World!\" this time because the result of the step was found in the cache, so it wasn't run again.\n\nFor a more detailed introduction check out the [First Steps](https://ai2-tango.readthedocs.io/en/latest/first_steps.html) walk-through.\n\n## Installation\n\n<!-- start install -->\n\n**ai2-tango** requires Python 3.8 or later.\n\n### Installing with `pip`\n\n**ai2-tango** is available [on PyPI](https://pypi.org/project/ai2-tango/). Just run\n\n```bash\npip install ai2-tango\n```\n\nTo install with a specific integration, such as `torch` for example, run\n\n```bash\npip install 'ai2-tango[torch]'\n```\n\nTo install with all integrations, run\n\n```bash\npip install 'ai2-tango[all]'\n```\n\n### Installing with `conda`\n\n**ai2-tango** is available on conda-forge. 
You can install just the base package with\n\n```bash\nconda install tango -c conda-forge\n```\n\nYou can pick and choose from the integrations with one of these:\n\n```bash\nconda install tango-datasets -c conda-forge\nconda install tango-torch -c conda-forge\nconda install tango-wandb -c conda-forge\n```\n\nYou can also install everything:\n\n```bash\nconda install tango-all -c conda-forge\n```\n\nEven though **ai2-tango** itself is quite small, installing everything will pull in a lot of dependencies.\nDon't be surprised if this takes a while!\n\n### Installing from source\n\nTo install **ai2-tango** from source, first clone [the repository](https://github.com/allenai/tango):\n\n```bash\ngit clone https://github.com/allenai/tango.git\ncd tango\n```\n\nThen run\n\n```bash\npip install -e '.[all]'\n```\n\nTo install with only a specific integration, such as `torch` for example, run\n\n```bash\npip install -e '.[torch]'\n```\n\nOr to install just the base tango library, you can run\n\n```bash\npip install -e .\n```\n\n### Checking your installation\n\nRun\n\n```bash\ntango info\n```\n\nto check your installation.\n\n### Docker image\n\nYou can build a Docker image suitable for tango projects by using [the official Dockerfile](https://github.com/allenai/tango/blob/main/Dockerfile) as a starting point for your own Dockerfile, or you can simply use one of our [prebuilt images](https://github.com/allenai/tango/pkgs/container/tango) as a base image in your Dockerfile. 
For example:\n\n```Dockerfile\n# Start from a prebuilt tango base image.\n# You can choose the right tag from the available options here:\n# https://github.com/allenai/tango/pkgs/container/tango/versions\nFROM ghcr.io/allenai/tango:cuda11.3\n\n# Install your project's additional requirements.\nCOPY requirements.txt .\nRUN /opt/conda/bin/pip install --no-cache-dir -r requirements.txt\n\n# Install source code.\n# This instruction copies EVERYTHING in the current directory (build context),\n# which may not be what you want. Consider using a \".dockerignore\" file to\n# exclude files and directories that you don't want on the image.\nCOPY . .\n```\n\nMake sure to choose the right base image for your use case depending on the version of tango you're using and the CUDA version that your host machine supports.\nYou can see a list of all available image tags [on GitHub](https://github.com/allenai/tango/pkgs/container/tango/versions).\n\n<!-- end install -->\n\n## FAQ\n\n<!-- start faq -->\n\n### Why is the library named Tango?\n\nThe motivation behind this library is that we can make research easier by composing it into well-defined steps.  What happens when you choreograph a number of steps together?  Well, you get a dance.  And since our [team's leader](https://nasmith.github.io/) is part of a tango band, \"AI2 Tango\" was an obvious choice!\n\n### How can I debug my steps through the Tango CLI?\n\nYou can run the `tango` command through [pdb](https://docs.python.org/3/library/pdb.html). For example:\n\n```bash\npython -m pdb -m tango run config.jsonnet\n```\n\n### How is Tango different from [Metaflow](https://metaflow.org), [Airflow](https://airflow.apache.org), or [redun](https://github.com/insitro/redun)?\n\nWe've found that existing DAG execution engines like these tools are great for production workflows but not as well suited for messy, collaborative research projects\nwhere code is changing constantly. 
AI2 Tango was built *specifically* for these kinds of research projects.\n\n### How does Tango's caching mechanism work?\n\nAI2 Tango caches the results of steps based on the `unique_id` of the step. The `unique_id` is essentially a hash of all of the inputs to the step along with:\n\n1. the step class's fully qualified name, and\n2. the step class's `VERSION` class variable (an arbitrary string).\n\nUnlike other workflow engines like [redun](https://github.com/insitro/redun), Tango does *not* take into account the source code of the class itself (other than its fully qualified name) because we've found that using a hash of the source code bytes is way too sensitive and less transparent for users.\nWhen you change the source code of your step in a meaningful way you can just manually change the `VERSION` class variable to indicate to Tango\nthat the step has been updated.\n\n<!-- end faq -->\n\n## Team\n\n<!-- start team -->\n\n**ai2-tango** is developed and maintained by the AllenNLP team, backed by [the Allen Institute for Artificial Intelligence (AI2)](https://allenai.org/).\nAI2 is a non-profit institute with the mission to contribute to humanity through high-impact AI research and engineering.\nTo learn more about who specifically contributed to this codebase, see [our contributors](https://github.com/allenai/tango/graphs/contributors) page.\n\n<!-- end team -->\n\n## License\n\n<!-- start license -->\n\n**ai2-tango** is licensed under [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).\nA full copy of the license can be found [on GitHub](https://github.com/allenai/tango/blob/main/LICENSE).\n\n<!-- end license -->\n"
  },
  {
    "path": "RELEASE_PROCESS.md",
    "content": "# GitHub Release Process\n\n## Steps\n\n1. Update the version in `tango/version.py`.\n\n2. Run the release script:\n\n    ```bash\n    ./scripts/release.sh\n    ```\n\n    This will automatically update the CHANGELOG, commit the changes to the CHANGELOG and `version.py` (and any other files you might have changed),\n    and then create a new tag in git which will trigger a workflow on GitHub Actions that handles the rest.\n\n## Fixing a failed release\n\nIf for some reason the GitHub Actions release workflow failed with an error that needs to be fixed, you'll have to delete both the tag and corresponding release from GitHub. After you've pushed a fix, delete the tag from your local clone with\n\n```bash\ngit tag -l | xargs git tag -d && git fetch -t\n```\n\nThen repeat the steps above.\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "build\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=source\r\nset BUILDDIR=build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.https://www.sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/source/_static/css/custom.css",
    "content": ""
  },
  {
    "path": "docs/source/api/commands.rst",
    "content": "Commands\n========\n\n.. automodule:: tango.__main__\n"
  },
  {
    "path": "docs/source/api/components/executor.rst",
    "content": "Executor\n========\n\nBase class\n----------\n\n.. autoclass:: tango.executor.Executor\n   :members: \n\n.. autoclass:: tango.executor.ExecutorOutput\n   :members: \n\n.. autoclass:: tango.executor.ExecutionMetadata\n   :members: \n"
  },
  {
    "path": "docs/source/api/components/format.rst",
    "content": "Format\n======\n\nBase class\n----------\n\n.. autoclass:: tango.format.Format\n   :members: \n   :private-members:\n\nImplementations\n---------------\n\n.. automodule:: tango.format\n   :members:\n   :exclude-members: Format,read,write,checksum\n"
  },
  {
    "path": "docs/source/api/components/index.rst",
    "content": "Components\n==========\n\nThe core components of **AI2 Tango**.\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Components\n\n   step\n   step_info\n   step_graph\n   workspace\n   step_cache\n   format\n   executor\n"
  },
  {
    "path": "docs/source/api/components/step.rst",
    "content": "Step\n====\n\nBase class\n----------\n\n.. autoclass:: tango.step.Step\n   :members: \n   :special-members:\n   :exclude-members: from_params\n\n.. autofunction:: tango.step.step\n\n.. autoclass:: tango.step.WithUnresolvedSteps\n   :members:\n\n.. autoclass:: tango.step.StepResources\n   :members:\n\nImplementations\n---------------\n\n.. automodule:: tango.steps\n   :members:\n"
  },
  {
    "path": "docs/source/api/components/step_cache.rst",
    "content": "StepCache\n=========\n\nBase class\n----------\n\n.. autoclass:: tango.step_cache.StepCache\n   :members: \n   :special-members:\n\nImplementations\n---------------\n\n.. autoclass:: tango.step_caches.LocalStepCache\n   :members:\n\n.. autoclass:: tango.step_caches.MemoryStepCache\n\nMetadata\n--------\n\n.. autoclass:: tango.step_cache.CacheMetadata\n   :members:\n"
  },
  {
    "path": "docs/source/api/components/step_graph.rst",
    "content": "StepGraph\n=========\n\n.. autoclass:: tango.step_graph.StepGraph\n   :members: \n"
  },
  {
    "path": "docs/source/api/components/step_info.rst",
    "content": "StepInfo\n========\n\n.. autoclass:: tango.step_info.StepInfo\n   :member-order: bysource\n   :members:\n\n.. autoclass:: tango.step_info.StepState\n   :member-order: bysource\n   :members:\n\n.. autoclass:: tango.step_info.PlatformMetadata\n   :member-order: bysource\n   :members:\n\n.. autoclass:: tango.step_info.EnvironmentMetadata\n   :member-order: bysource\n   :members:\n\n.. autoclass:: tango.step_info.GitMetadata\n   :member-order: bysource\n   :members:\n\n.. autoclass:: tango.step_info.TangoMetadata\n   :member-order: bysource\n   :members:\n"
  },
  {
    "path": "docs/source/api/components/workspace.rst",
    "content": "Workspace\n=========\n\nBase class\n----------\n\n.. autoclass:: tango.workspace.Workspace\n   :members:\n\nImplementations\n---------------\n\n.. autoclass:: tango.workspaces.LocalWorkspace\n\n.. autoclass:: tango.workspaces.MemoryWorkspace\n\nMetadata\n--------\n\n.. autoclass:: tango.workspace.Run\n   :members:\n\n.. autoclass:: tango.workspace.RunInfo\n   :members:\n\nMiscellaneous\n-------------\n\n.. autoclass:: tango.workspace.RunSort\n   :members:\n\n.. autoclass:: tango.workspace.StepInfoSort\n   :members:\n"
  },
  {
    "path": "docs/source/api/det_hash.rst",
    "content": "Deterministic Hashing\n=====================\n\nIn order to detect whether a :class:`~tango.step.Step` has to be re-run or not, Tango relies on some tools to compute\ndeterministic hashes from the inputs to the :class:`~tango.step.Step`.\n\nThe center-piece of this module is the :func:`~tango.common.det_hash.det_hash` function, which computes a deterministic hash of an\narbitrary Python object. The other things in this module influence how that works in various ways.\n\n.. automodule:: tango.common.det_hash\n   :members:\n"
  },
  {
    "path": "docs/source/api/exceptions.rst",
    "content": "Exceptions\n==========\n\n.. autoexception:: tango.common.exceptions.TangoError\n   :members:\n\n.. automodule:: tango.common.exceptions\n   :members:\n   :exclude-members: TangoError\n"
  },
  {
    "path": "docs/source/api/integrations/beaker.rst",
    "content": "🧪 Beaker\n=========\n\n.. automodule:: tango.integrations.beaker\n\nReference\n---------\n\n.. autoclass:: tango.integrations.beaker.BeakerWorkspace\n\n.. autoclass:: tango.integrations.beaker.BeakerStepCache\n\n.. autoclass:: tango.integrations.beaker.BeakerExecutor\n   :members: DEFAULT_BEAKER_IMAGE\n\n.. autoclass:: tango.integrations.beaker.BeakerScheduler\n   :members:\n\n.. autoclass:: tango.integrations.beaker.SimpleBeakerScheduler\n\n.. autoclass:: tango.integrations.beaker.ResourceAssignment\n   :members:\n\n.. autoclass:: tango.integrations.beaker.ResourceAssignmentError\n"
  },
  {
    "path": "docs/source/api/integrations/datasets.rst",
    "content": "🤗 Datasets\n===========\n\n.. automodule:: tango.integrations.datasets\n\nReference\n---------\n\n.. autofunction:: tango.integrations.datasets.convert_to_tango_dataset_dict\n\n.. autoclass:: tango.integrations.datasets.DatasetsFormat\n\n.. autoclass:: tango.integrations.datasets.LoadDataset\n   :members:\n\n.. autoclass:: tango.integrations.datasets.LoadStreamingDataset\n   :members:\n\n.. autoclass:: tango.integrations.datasets.InterleaveDatasets\n   :members:\n\n.. autoclass:: tango.integrations.datasets.ConcatenateDatasets\n   :members:\n\n.. autoclass:: tango.integrations.datasets.DatasetRemixStep\n   :members:"
  },
  {
    "path": "docs/source/api/integrations/fairscale.rst",
    "content": "🔥 FairScale\n============\n\n.. automodule:: tango.integrations.fairscale\n\nReference\n---------\n\n.. autoclass:: tango.integrations.fairscale.FairScaleTrainingEngine\n\n.. autoclass:: tango.integrations.fairscale.FSDPConfig\n    :members:\n\n.. autofunction:: tango.integrations.fairscale.with_wrapped_modules\n"
  },
  {
    "path": "docs/source/api/integrations/flax.rst",
    "content": "Flax\n=======\n\n.. automodule:: tango.integrations.flax\n\nReference\n---------\n\nTrain step\n~~~~~~~~~~\n\n.. autoclass:: tango.integrations.flax.FlaxTrainStep\n   :members:\n\n.. autoclass:: tango.integrations.flax.TrainConfig\n   :members:\n\nEval step\n~~~~~~~~~\n\n.. autoclass:: tango.integrations.flax.FlaxEvalStep\n   :members:\n\nFlax format\n~~~~~~~~~~~~\n\n.. autoclass:: tango.integrations.flax.FlaxFormat\n\nModel\n~~~~~\n\n.. autoclass:: tango.integrations.flax.Model\n   :members:\n\nOptim\n~~~~~\n\n.. autoclass:: tango.integrations.flax.Optimizer\n   :members:\n\n.. autoclass:: tango.integrations.flax.LRScheduler\n   :members:\n\nData\n~~~~\n\n.. autoclass:: tango.integrations.flax.DataLoader\n   :members:\n\n.. autoclass:: tango.integrations.flax.FlaxDataLoader\n   :members:\n\nCallbacks\n~~~~~~~~~\n\n.. autoclass:: tango.integrations.flax.TrainCallback\n   :members:\n   :member-order: bysource\n\n.. autoclass:: tango.integrations.flax.EvalCallback\n   :members:\n   :member-order: bysource\n"
  },
  {
    "path": "docs/source/api/integrations/gs.rst",
    "content": "☁️ Google Cloud Storage\n=======================\n\n.. automodule:: tango.integrations.gs\n\nReference\n---------\n\n.. autoclass:: tango.integrations.gs.GSWorkspace\n\n.. autoclass:: tango.integrations.gs.GSStepCache\n"
  },
  {
    "path": "docs/source/api/integrations/index.rst",
    "content": "Integrations\n============\n\n.. automodule:: tango.integrations\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Integrations\n\n   torch\n   fairscale\n   datasets\n   transformers\n   wandb\n   beaker\n   flax\n   gs\n"
  },
  {
    "path": "docs/source/api/integrations/torch.rst",
    "content": "🔥 PyTorch\n==========\n\n.. automodule:: tango.integrations.torch\n\nReference\n---------\n\nTrain step\n~~~~~~~~~~\n\n.. autoclass:: tango.integrations.torch.TorchTrainStep\n   :members:\n\n.. autoclass:: tango.integrations.torch.TrainConfig\n   :members:\n\nEval step\n~~~~~~~~~\n\n.. autoclass:: tango.integrations.torch.TorchEvalStep\n   :members:\n\nTorch format\n~~~~~~~~~~~~\n\n.. autoclass:: tango.integrations.torch.TorchFormat\n\nModel\n~~~~~\n\n.. autoclass:: tango.integrations.torch.Model\n   :members:\n\nTrainingEngine\n~~~~~~~~~~~~~~\n\n.. autoclass:: tango.integrations.torch.TrainingEngine\n   :members:\n\n.. autoclass:: tango.integrations.torch.TorchTrainingEngine\n\nOptim\n~~~~~\n\n.. autoclass:: tango.integrations.torch.Optimizer\n   :members:\n\n.. autoclass:: tango.integrations.torch.LRScheduler\n   :members:\n\nData\n~~~~\n\n.. autoclass:: tango.integrations.torch.DataLoader\n   :members:\n\n.. autoclass:: tango.integrations.torch.Sampler\n   :members:\n\n.. autoclass:: tango.integrations.torch.DataCollator\n   :members:\n   :special-members: __call__\n\n.. autoclass:: tango.integrations.torch.ConcatTensorDictsCollator\n   :members:\n\nCallbacks\n~~~~~~~~~\n\n.. autoclass:: tango.integrations.torch.TrainCallback\n   :members:\n   :member-order: bysource\n\n.. autoclass:: tango.integrations.torch.EvalCallback\n   :members:\n   :member-order: bysource\n\n.. autoclass:: tango.integrations.torch.StopEarlyCallback\n\n.. autoclass:: tango.integrations.torch.StopEarly\n   :members:\n"
  },
  {
    "path": "docs/source/api/integrations/transformers.rst",
    "content": "🤗 Transformers\n===============\n\n.. automodule:: tango.integrations.transformers\n    :members:\n\n.. autofunction:: tango.integrations.transformers.ia3.modify_with_ia3"
  },
  {
    "path": "docs/source/api/integrations/wandb.rst",
    "content": "⚖️ Weights & Biases\n===================\n \n.. automodule:: tango.integrations.wandb\n\nReference\n---------\n\n.. autoclass:: tango.integrations.wandb.WandbWorkspace\n\n.. autoclass:: tango.integrations.wandb.WandbStepCache\n\n.. autoclass:: tango.integrations.wandb.WandbTrainCallback\n\n.. autoclass:: tango.integrations.wandb.WandbFlaxTrainCallback\n"
  },
  {
    "path": "docs/source/api/logging.rst",
    "content": "Logging\n=======\n\n.. automodule:: tango.common.logging\n\nReference\n---------\n\n.. autodata:: tango.common.logging.TANGO_LOG_LEVEL\n\n.. autodata:: tango.common.logging.FILE_FRIENDLY_LOGGING\n\n.. autodata:: tango.common.logging.cli_logger\n\n.. autofunction:: tango.common.logging.initialize_logging\n\n.. autofunction:: tango.common.logging.initialize_worker_logging\n\n.. autofunction:: tango.common.logging.initialize_prefix_logging\n\n.. autofunction:: tango.common.logging.teardown_logging\n\n.. autofunction:: tango.common.logging.file_handler\n"
  },
  {
    "path": "docs/source/api/sequences.rst",
    "content": "Sequences\n=========\n\nThis module contains some utilities to make sequences out of other sequences. All of these are lazy, so they\ntake minimal time and memory when you create them. These work particularly well when used together. For example,\nyou can concatenate two sequences (:class:`~tango.common.sequences.ConcatenatedSequence`), and then shuffle\nthem (:class:`~tango.common.sequences.ShuffledSequence`).\n\nThis module is not dependent on other Tango modules and can be used in isolation.\n\n.. automodule:: tango.common.sequences\n   :members:\n"
  },
  {
    "path": "docs/source/api/settings.rst",
    "content": "Global settings\n---------------\n\nSome command-line options can set globally in a ``tango.yml`` or ``tango.yaml`` settings file.\nTango will check the current directory and ``~/.config/``, in that order.\n\nThe full spec of this config is defined by the :class:`~tango.settings.TangoGlobalSettings` class.\n\n.. autoclass:: tango.settings.TangoGlobalSettings\n   :members:\n   :exclude-members: path,find_or_default\n   :member-order: bysource\n"
  },
  {
    "path": "docs/source/api/utilities.rst",
    "content": "Utilities\n=========\n\n.. automodule:: tango.common\n   :members:\n   :exclude-members: det_hash\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\nimport logging\nimport os\nimport sys\nfrom datetime import datetime\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nsys.path.insert(0, os.path.abspath(\"../../\"))\n\nfrom tango.version import VERSION, VERSION_SHORT  # noqa: E402\n\n# -- Project information -----------------------------------------------------\n\nproject = \"AI2 Tango\"\ncopyright = f\"{datetime.today().year}, Allen Institute for Artificial Intelligence\"\nauthor = \"Allen Institute for Artificial Intelligence\"\nversion = VERSION_SHORT\nrelease = VERSION\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    \"sphinx.ext.autodoc\",\n    \"sphinx.ext.napoleon\",\n    \"myst_parser\",\n    \"sphinx.ext.intersphinx\",\n    \"sphinx.ext.viewcode\",\n    \"sphinx.ext.doctest\",\n    \"sphinx_copybutton\",\n    \"sphinx_autodoc_typehints\",\n]\n\nsuppress_warnings = [\"myst.header\"]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = [\"_templates\"]\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = [\"_build\"]\n\nsource_suffix = [\".rst\", \".md\"]\n\n# -- Extension configuration -------------------------------------------------\n\nintersphinx_mapping = {\n    \"python\": (\"https://docs.python.org/3\", None),\n    \"rich\": (\"https://rich.readthedocs.io/en/latest\", None),\n    \"torch\": (\"https://pytorch.org/docs/stable\", None),\n    \"flax\": (\"https://flax.readthedocs.io/en/latest\", None),\n    \"fairscale\": (\"https://fairscale.readthedocs.io/en/latest/\", None),\n    \"datasets\": (\"https://huggingface.co/docs/datasets/master/en\", None),\n    \"transformers\": (\"https://huggingface.co/docs/transformers/master/en\", None),\n    \"beaker\": (\"https://beaker-py.readthedocs.io/en/latest/\", None),\n}\n\n# Tell myst-parser to assign header anchors for h1-h3.\nmyst_heading_anchors = 3\n\n# By default, sort documented members by type within classes and modules.\nautodoc_member_order = \"groupwise\"\n\npython_use_unqualified_type_names = True\n\n# Include default values when documenting parameter types.\ntypehints_defaults = \"comma\"\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = \"furo\"\n\nhtml_title = f\"ai2-tango v{VERSION}\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = [\"_static\"]\n\nhtml_css_files = [\"css/custom.css\"]\n\nhtml_favicon = \"_static/favicon.ico\"\n\nhtml_theme_options = {\n    \"light_css_variables\": {\n        \"color-announcement-background\": \"#1B4596\",\n        \"color-announcement-text\": \"#FFFFFF\",\n    },\n    \"dark_css_variables\": {},\n    \"light_logo\": \"tango_final_squareish.png\",\n    \"dark_logo\": \"tango_final_squareish.png\",\n    \"footer_icons\": [\n        {\n            \"name\": \"GitHub\",\n            \"url\": \"https://github.com/allenai/tango\",\n            \"html\": \"\"\"\n                <svg stroke=\"currentColor\" fill=\"currentColor\" stroke-width=\"0\" viewBox=\"0 0 16 16\">\n                    <path fill-rule=\"evenodd\" d=\"M8 0C3.58 0 0 3.58 0 8c0 3.54 2.29 6.53 5.47 7.59.4.07.55-.17.55-.38 0-.19-.01-.82-.01-1.49-2.01.37-2.53-.49-2.69-.94-.09-.23-.48-.94-.82-1.13-.28-.15-.68-.52-.01-.53.63-.01 1.08.58 1.23.82.72 1.21 1.87.87 2.33.66.07-.52.28-.87.51-1.07-1.78-.2-3.64-.89-3.64-3.95 0-.87.31-1.59.82-2.15-.08-.2-.36-1.02.08-2.12 0 0 .67-.21 2.2.82.64-.18 1.32-.27 2-.27.68 0 1.36.09 2 .27 1.53-1.04 2.2-.82 2.2-.82.44 1.1.16 1.92.08 2.12.51.56.82 1.27.82 2.15 0 3.07-1.87 3.75-3.65 3.95.29.25.54.73.54 1.48 0 1.07-.01 1.93-.01 2.2 0 .21.15.46.55.38A8.013 8.013 0 0 0 16 8c0-4.42-3.58-8-8-8z\"></path>\n                </svg>\n            \"\"\",  # noqa: E501\n            \"class\": \"\",\n        },\n    ],\n}\n\n# -- Hack to get rid of stupid warnings from sphinx_autodoc_typehints --------\n\n\nclass ShutupSphinxAutodocTypehintsFilter(logging.Filter):\n    def filter(self, record: logging.LogRecord) 
-> bool:\n        if \"Cannot resolve forward reference\" in record.msg:\n            return False\n        return True\n\n\nlogging.getLogger(\"sphinx.sphinx_autodoc_typehints\").addFilter(ShutupSphinxAutodocTypehintsFilter())\n"
  },
  {
    "path": "docs/source/examples/euler.md",
    "content": "```{include} ../../../examples/euler/README.md\n```\n\n## Running the experiment\n\nIf you haven't already, clone the [tango repository](https://github.com/allenai/tango) and then\nchange directories into `examples/euler`.\n\nYou can then run the experiment with:\n\n```bash\ntango run euler_general.jsonnet -i complex_arithmetic -w workspace\n```\n\nThis will leave its results in a subdirectory of `workspace/runs/` corresponding to the name of the run.\nThe output it prints should look something like this:\n```\nStarting new run comic-heron\nServer started at http://localhost:8080/run/comic-heron\n[step i_times_pi] ● Starting step \"i_times_pi\"...\n[step i_times_pi] ✓ Finished step \"i_times_pi\"\n[step cos] ● Starting step \"cos\"...\n[step cos] ✓ Finished step \"cos\"\n[step sin] ● Starting step \"sin\"...\n[step sin] ✓ Finished step \"sin\"\n[step pow_e] ✓ Found output for step \"i_times_pi\" in cache (needed by \"pow_e\")...\n[step pow_e] ● Starting step \"pow_e\"...\n[step pow_e] ✓ Finished step \"pow_e\"\n[step i_times_sin] ✓ Found output for step \"sin\" in cache (needed by \"i_times_sin\")...\n[step i_times_sin] ● Starting step \"i_times_sin\"...\n[step i_times_sin] ✓ Finished step \"i_times_sin\"\n[step sum] ✓ Found output for step \"cos\" in cache (needed by \"sum\")...\n[step sum] ✓ Found output for step \"i_times_sin\" in cache (needed by \"sum\")...\n[step sum] ● Starting step \"sum\"...\n[step sum] ✓ Finished step \"sum\"\n[step sub] ✓ Found output for step \"sum\" in cache (needed by \"sub\")...\n[step sub] ✓ Found output for step \"pow_e\" in cache (needed by \"sub\")...\n[step sub] ● Starting step \"sub\"...\n[step sub] ✓ Finished step \"sub\"\n[step print] ✓ Found output for step \"sub\" in cache (needed by \"print\")...\n[step print] ● Starting step \"print\"...\n[step print] 0j\n[step print] ✓ Finished step \"print\"\n✓ Finished run comic-heron\n\n 
┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n ┃ Step Name   ┃ Status      ┃ Cached Result                                                     ┃\n ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n │ cos         │ ✓ succeeded │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk       │\n │ i_times_pi  │ ✓ succeeded │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae     │\n │ i_times_sin │ ✓ succeeded │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf     │\n │ pow_e       │ ✓ succeeded │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │\n │ print       │ ✓ succeeded │ N/A                                                               │\n │ sin         │ ✓ succeeded │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk         │\n │ sub         │ ✓ succeeded │ workspace/cache/SubtractionStep-4ygj1UyLk6TCVBxN7DWTCccbMa7M1C5v  │\n │ sum         │ ✓ succeeded │ workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP     │\n └─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘\n                                                                 ✓ 8 succeeded\n\nUse your workspace to get the cached result of a step, e.g.\n\n >>> from tango import Workspace\n >>> workspace = Workspace.from_url(...)\n >>> workspace.step_result_for_run(\"comic-heron\", \"sum\")\n```\n\nA few things are of note here:\n 1. Tango assigns a name to your run. In this case, the name is \"comic-heron\".\n 2. In this configuration, the \"print\" step prints the output (\"`0j`\"). Most of the time though, you will look\n    for the output in the output directories that are given in the table.\n 3. You might notice that the \"print\" step produces no output. That's because it is uncacheable, and thus writes\n    out nothing.\n\n\n## Change a step\n\nLet's make an update to a step! 
Open `complex_arithmetic.py` and change `AdditionStep`. The actual change you make\nin the `run()` method does not matter, but the important thing is to update the `VERSION` member of the\n`AdditionStep` class. `AdditionStep` does not yet have a `VERSION`, so we will give it one:\n```Python\n@Step.register(\"cadd\")\nclass AdditionStep(Step):\n    VERSION = \"002\"     # This is the important change.\n    \n    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore\n        return make_complex(a) + make_complex(b)\n```\n\nNow run the config again with\n```bash\ntango run euler_general.jsonnet -i complex_arithmetic -w workspace\n```\n\nThis time, the output will look like this:\n```\nStarting new run right-amoeba\nServer started at http://localhost:8080/run/right-amoeba\n[step sum] ✓ Found output for step \"cos\" in cache (needed by \"sum\")...\n[step sum] ✓ Found output for step \"i_times_sin\" in cache (needed by \"sum\")...\n[step sum] ● Starting step \"sum\"...\n[step sum] ✓ Finished step \"sum\"\n[step sub] ✓ Found output for step \"sum\" in cache (needed by \"sub\")...\n[step sub] ✓ Found output for step \"pow_e\" in cache (needed by \"sub\")...\n[step sub] ● Starting step \"sub\"...\n[step sub] ✓ Finished step \"sub\"\n[step print] ✓ Found output for step \"sub\" in cache (needed by \"print\")...\n[step print] ● Starting step \"print\"...\n[step print] 0j\n[step print] ✓ Finished step \"print\"\n✓ Finished run right-amoeba\n\n ┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n ┃ Step Name   ┃ Status      ┃ Cached Result                                                     ┃\n ┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n │ cos         │ - not run   │ workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk       │\n │ i_times_pi  │ - not run   │ workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae     │\n │ i_times_sin │ - 
not run   │ workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf     │\n │ pow_e       │ - not run   │ workspace/cache/ExponentiateStep-1swPpNipP6HBSP5rKdNjEqbYAWNf4CdG │\n │ print       │ ✓ succeeded │ N/A                                                               │\n │ sin         │ - not run   │ workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk         │\n │ sub         │ ✓ succeeded │ workspace/cache/SubtractionStep-42mdcQBtrNAYvxYhmzdd1vj2uCG8N5Yf  │\n │ sum         │ ✓ succeeded │ workspace/cache/AdditionStep-002-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP │\n └─────────────┴─────────────┴───────────────────────────────────────────────────────────────────┘\n                                                           ✓ 3 succeeded, 5 not run\n\nUse your workspace to get the cached result of a step, e.g.\n\n >>> from tango import Workspace\n >>> workspace = Workspace.from_url(...)\n >>> workspace.step_result_for_run(\"right-amoeba\", \"sum\")\n```\n\nAs you can see, it re-used the cached results for several of the steps, and only ran three steps anew.\n\n```{eval-rst}\n:class:`tango.step.Step.VERSION` is just one of the ways in which you can change the behavior of a step. Head over to the\ndocumentation of the :class:`tango.step.Step` class to see the others.\n```\n"
  },
  {
    "path": "docs/source/examples/eval_p3.md",
    "content": "```{include} ../../../examples/eval_p3/README.md\n```\n\n## `RougeScoreStep`\n\n`RougeScoreStep` is defined in `eval.py`:\n\n```{literalinclude} ../../../examples/eval_p3/eval.py\n:language: py\n```\n\n## Config\n\nThe configuration file, `config.jsonnet`, uses some advanced [Jsonnet](https://jsonnet.org) concepts like `std.foldl`\nto create the same configuration for all 10 prompts:\n\n```{literalinclude} ../../../examples/eval_p3/config.jsonnet\n```\n\n## Run it\n\nYou can run the experiment with:\n\n```bash\ntango run config.jsonnet -i eval -d /tmp/workspace\n```\n"
  },
  {
    "path": "docs/source/examples/index.rst",
    "content": "Examples\n========\n\nReal-world examples of using Tango.\nYou can find all of these `on GitHub <https://github.com/allenai/tango/tree/main/examples>`_ as well.\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Examples\n\n   euler\n   train_lm\n   eval_p3\n"
  },
  {
    "path": "docs/source/examples/train_lm.md",
    "content": "# Fine-tuning a language model\n\n```{include} ../../../examples/train_lm/README.md\n:start-after: <!-- start overview -->\n:end-before: <!-- end overview -->\n```\n\n```{tip}\nYou can find the full code for this example on [GitHub](https://github.com/allenai/tango/tree/main/examples/train_lm).\n```\n\n## Components\n\nWe'll need to write a step for tokenizing the data and preparing it for language model training.\nAll of the other steps we need are provided by Tango integrations.\n\nSo, create a file called `tokenize_step.py` with the following contents:\n\n```{literalinclude} ../../../examples/train_lm/tokenize_step.py\n:language: py\n```\n\n## Configuration file\n\nNext you'll need to create a configuration file that defines the experiment. Just copy these contents into a file called `config.jsonnet`:\n\n\n```{literalinclude} ../../../examples/train_lm/config.jsonnet\n```\n\n## Run it\n\nNow we can run the experiment with:\n\n```bash\ntango run config.jsonnet -i tokenize_step.py -d /tmp/results\n```\n"
  },
  {
    "path": "docs/source/faq.md",
    "content": "# FAQ\n\n```{include} ../../README.md\n:start-after: <!-- start faq -->\n:end-before: <!-- end faq -->\n```\n"
  },
  {
    "path": "docs/source/first_steps.md",
    "content": "# First Steps\n\n## What is a Step?\n\nTango is a Python library for choreographing machine learning research experiments by executing\na series of steps.\nA step can do anything, really, such as [prepare a dataset](tango.integrations.datasets.LoadDataset), [train a model](tango.integrations.torch.TorchTrainStep), send an email to your mother wishing her happy birthday, *etc*.\n\nConcretely, each step is just a subclass of {class}`~tango.step.Step`, where the {meth}`~tango.step.Step.run` method in particular defines what the step actually does.\nSo anything that can be implemented in Python can be run as a step.\n\nSteps can also depend on other steps in that the output of one step can be part of the input to another step.\nTherefore, the steps that make up an experiment form a [directed graph](tango.step_graph.StepGraph).\n\nThe concept of the {class}`~tango.step.Step` is the bread and butter that makes Tango so general and powerful.\n*So* powerful, in fact, that you might be wondering if Tango is [Turing-complete](https://en.wikipedia.org/wiki/Turing_completeness)?\nWell, we don't know yet, but we can say at least that Tango is **Tango-complete** 😉\n\n## Configuration files\n\nExperiments themselves are defined through JSON, [Jsonnet](https://jsonnet.org/), or YAML configuration files.\nAt a minimum, these files must contain the \"steps\" field, which should be a mapping of arbitrary (yet unique) step names to the configuration of the corresponding step.\n\nFor example, let's create a config file called `config.jsonnet` with the following contents:\n\n```json\n{\n  \"steps\": {\n    \"random_name\": {\n      \"type\": \"random_choice\",\n      \"choices\": [\"Turing\", \"Tango\", \"Larry\"],\n    },\n    \"say_hello\": {\n      \"type\": \"concat_strings\",\n      \"string1\": \"Hello, \",\n      \"string2\": {\n        \"type\": \"ref\",\n        \"ref\": \"random_name\"\n      }\n    },\n    \"print\": {\n      \"type\": \"print\",\n      
\"input\": {\n        \"type\": \"ref\",\n        \"ref\": \"say_hello\"\n      }\n    }\n  }\n}\n```\n\n*Can you guess what this experiment does?*\n\nThere are three steps in this experiment graph: \"random_name\" is the name of one step, \"say_hello\" is the name of another, and \"print\" is the name of the last.\nThe \"type\" parameter within the config of each step tells Tango which {class}`~tango.step.Step` class implementation to use for that step.\n\nSo, within the \"random_name\" step config\n\n```json\n\"random_name\": {\n  \"type\": \"random_choice\",\n  \"choices\": [\"Turing\", \"Tango\", \"Larry\"],\n}\n```\n\nthe `\"type\": \"random_choice\"` part tells Tango to use the {class}`~tango.step.Step` subclass that is registered by the name \"random_choice\".\n\nBut wait... what do we mean by *registered*?\n\nTango keeps track of an internal registry for certain classes (such as the {class}`~tango.step.Step` class) that is just a mapping of arbitrary unique names to subclasses.\nWhen you look through Tango's source code, you'll see things like:\n\n```python\n@Step.register(\"foo\")\nclass Foo(Step):\n    ...\n```\n\nThis is how subclasses get added to the registry.\nIn this case the subclass `Foo` is added to the `Step` registry under the name \"foo\", so if you were to use `\"type\": \"foo\"` in your configuration file, Tango would understand\nthat you mean to use the `Foo` class for the given step.\n\n```{tip}\nAny class that inherits from {class}`~tango.common.registrable.Registrable` can have its own\nregistry.\n```\n\nNow back to our example.\nThe step classes referenced in our configuration file (\"random_choice\" and \"concat_strings\") don't actually exist in the Tango library (though the [\"print\" step](tango.steps.PrintStep) does),\nbut we can easily implement and register them on our own.\n\nLet's put them in a file called `components.py`:\n\n```python\n# file: components.py\n\nimport random\nfrom typing import List\n\nfrom tango import 
Step\n\n@Step.register(\"random_choice\")\nclass RandomChoiceStep(Step):\n    DETERMINISTIC = False\n\n    def run(self, choices: List[str]) -> str:\n        return random.choice(choices)\n\n@Step.register(\"concat_strings\")\nclass ConcatStringsStep(Step):\n    def run(self, string1: str, string2: str) -> str:\n        return string1 + string2\n```\n\n```{important}\nIt's important that you use type hints in your code so that Tango can properly construct Python objects from the corresponding serialized (JSON) objects\nand warn you when the types don't match up.\n```\n\nSo as long as Tango is able to import this module (`components.py`) these step implementations will be added to the registry\nand Tango will know how to instantiate and run them.\n\nThere's also a short-hand way of implementing steps, using the {func}`@step() <tango.step.step>` function decorator:\n\n```python\nfrom tango import step\n\n@step(deterministic=False)\ndef random_choice(choices: List[str]) -> str:\n    return random.choice(choices)\n\n@step()\ndef concat_strings(string1: str, string2: str) -> str:\n    return string1 + string2\n```\n\nThis will register these steps under the name of the corresponding function, i.e. 
\"random_choice\" and \"concat_strings\", by default, though that can be overridden by specifying the \"name\" parameter to the decorator:\n\n```python\n@step(name=\"random-string\", deterministic=False)\ndef random_choice(choices: List[str]) -> str:\n    return random.choice(choices)\n```\n\n## Executing an experiment\n\nAt this point we've implemented our custom steps (`components.py`) and created our configuration\nfile `config.jsonnet`, so we're ready to actually run this experiment.\n\nFor that, just use the `tango run` command:\n\n```\n$ tango run config.jsonnet -i components\n```\n\n```{tip}\n- The `-i` option is short for `--include-package`, which takes the name of a Python package which Tango will try to import.\nIn this case our custom steps are in `components.py`, so we need Tango to import this module to find those steps.\nAs long as `components.py` is in the current directory or somewhere else on the `PYTHONPATH`, Tango will be able to find and import\nthis module when you pass `-i components` (note the lack of the `.py` at the end).\n```\n\nYou should see something like this in the output:\n\n```\nStarting new run cute-kitten\n● Starting step \"random_name\"\n✓ Finished step \"random_name\"\n● Starting step \"say_hello\"\n✓ Finished step \"say_hello\"\n● Starting step \"print\"\nHello, Tango\n✓ Finished step \"print\"\n```\n\n## Step caching\n\nThis particular experiment didn't write any results to disk, but in many situations you'll want to save the output of at least some of your steps.\n\nFor example, if you're using the {class}`~tango.integrations.torch.TorchTrainStep` step, the output is a trained model, which is certainly a useful thing to keep around.\nIn other cases, you may not actually care about the direct result of a particular step, but it could still be useful to save it when possible so that Tango doesn't need to run the step\nagain unnecessarily.\n\nThis is where Tango's caching mechanism comes in.\n\nTo demonstrate this, let's look 
at another example that pretends to do some expensive computation.\nHere is the `config.jsonnet` file:\n\n```json\n{\n  \"steps\": {\n    \"add_numbers\": {\n      \"type\": \"really_inefficient_addition\",\n      \"num1\": 34,\n      \"num2\": 8\n    }\n  }\n}\n```\n\nAnd let's implement \"really_inefficient_addition\":\n\n```python\n# components.py\n\nimport time\n\nfrom tango import Step, JsonFormat\nfrom tango.common import Tqdm\n\n\n@Step.register(\"really_inefficient_addition\")\nclass ReallyInefficientAdditionStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT = JsonFormat()\n\n    def run(self, num1: int, num2: int) -> int:\n        for _ in Tqdm.tqdm(range(100), desc=\"Computing...\", total=100):\n            time.sleep(0.05)\n        return num1 + num2\n```\n\nOther than the obvious inefficiencies, there are a few things to note about this step, namely the class variables\nwe've defined: {attr}`~tango.step.Step.DETERMINISTIC`, {attr}`~tango.step.Step.CACHEABLE`, and\n{attr}`~tango.step.Step.FORMAT`.\n\n`DETERMINISTIC = True` tells Tango that, given particular inputs, the output of this step will always be the same\nevery time it is run, which has implications for caching.\nBy default, Tango assumes steps are deterministic.\nYou can override this by setting `DETERMINISTIC = False`.\nTango will warn you when you try to cache a non-deterministic step.\n\n`CACHEABLE = True` tells Tango that it can cache this step, and `FORMAT = JsonFormat()` defines which\n{class}`~tango.format.Format` Tango will use to serialize the result of the step.\n\nThis time when we run the experiment we'll designate a specific directory for Tango to use:\n\n```bash\n$ tango run config.jsonnet -i components -d workspace/\n```\n```\nStarting new run live-tarpon\n● Starting step \"add_numbers\"\nComputing...: 100%|##########| 100/100 [00:05<00:00, 18.99it/s]\n✓ Finished step \"add_numbers\"\n✓ The output for \"add_numbers\" is in 
workspace/runs/live-tarpon/add_numbers\n```\n\nThe last line in the output tells us where we can find the result of our \"add_numbers\" step. `live-tarpon` is\nthe name of the run. Run names are randomly generated and may be different on your machine. `add_numbers` is the\nname of the step in your config. The whole path is a symlink to a directory, which contains (among other things)\na file `data.json`:\n\n```bash\n$ cat workspace/runs/live-tarpon/add_numbers/data.json\n```\n```\n42\n```\n\nNow look what happens when we run this step again:\n\n```bash\n$ tango run config.jsonnet -i components -d workspace/\n```\n```\nStarting new run modest-shrimp\n✓ Found output for \"add_numbers\" in cache\n✓ The output for \"add_numbers\" is in workspace/runs/modest-shrimp/add_numbers\n```\n\nTango didn't have to run our really inefficient addition step this time because it found the previous cached\nresult. It put the results in the result directory for a different run (in our case, the `modest-shrimp` run),\nbut once again it is a symlink that links to the same results from our first run.\n\nIf we changed the inputs to the step in `config.jsonnet`:\n\n```diff\n     \"add_numbers\": {\n       \"type\": \"really_inefficient_addition\",\n       \"num1\": 34,\n-      \"num2\": 8\n+      \"num2\": 2\n     }\n   }\n }\n```\n\nAnd ran it again:\n\n```bash\n$ tango run config.jsonnet -i components -d workspace/\n```\n```\nStarting new run true-parrot\n● Starting step \"add_numbers\"\nComputing...: 100%|##########| 100/100 [00:05<00:00, 19.13it/s]\n✓ Finished step \"add_numbers\"\n✓ The output for \"add_numbers\" is in workspace/runs/true-parrot/add_numbers\n```\n\nYou'd see that Tango had to run our \"add_numbers\" step again.\n\nYou may have noticed that `workspace/runs/true-parrot/add_numbers` is now a symlink that points to a different\nplace than it did for the first two runs. That's because it produced a different result this time. 
All the\nresult symlinks point into the `workspace/cache/` directory, where the results of all steps are cached.\n\nThis means that if we ran the experiment again with the original inputs, Tango would still find the cached result\nand wouldn't need to rerun the step.\n\n## Arbitrary objects as inputs\n\n### `FromParams`\n\nSo far the inputs to all of the steps in our examples have been built-in Python types that can be deserialized from JSON (e.g. {class}`int`, {class}`str`, etc.),\nbut sometimes you need the input to a step to be an instance of an arbitrary Python class.\n\nTango allows this as well, because it can infer from type hints what the class is and how to instantiate it.\nWhen writing your own classes, it's recommended that you have your class inherit from the {class}`~tango.common.from_params.FromParams` class, which will guarantee that\nTango can instantiate it from a config file.\n\nFor example, suppose we had a step like this:\n\n```python\nfrom tango import Step\nfrom tango.common import FromParams\n\n\nclass Bar(FromParams):\n    def __init__(self, x: int) -> None:\n        self.x = x\n\n\n@Step.register(\"foo\")\nclass FooStep(Step):\n    def run(self, bar: Bar) -> int:\n        return bar.x\n```\n\n```{tip}\nIf you've used [AllenNLP](https://github.com/allenai/allennlp) before, this will look familiar!\nIn fact, it's the same system under the hood.\n```\n\nThen we could create a config like this:\n\n```json\n{\n  \"steps\": {\n    \"foo\": {\n      \"type\": \"foo\",\n      \"bar\": {\"x\": 1}\n    }\n  }\n}\n```\n\nAnd Tango will figure out how to deserialize `{\"x\": 1}` into a `Bar` instance.\n\nYou can also have `FromParams` objects nested within other `FromParams` objects or standard containers\nlike {class}`list`:\n\n```python\nfrom typing import List\n\nfrom tango import Step\nfrom tango.common import FromParams\n\n\nclass Bar(FromParams):\n    def __init__(self, x: int) -> None:\n        self.x = x\n\n\nclass Baz(FromParams):\n    def 
__init__(self, bar: Bar) -> None:\n        self.bar = bar\n\n\n@Step.register(\"foo\")\nclass FooStep(Step):\n    def run(self, bars: List[Bar], baz: Baz) -> int:\n        return sum([bar.x for bar in bars]) + baz.bar.x\n```\n\n### `Registrable`\n\nThe {class}`~tango.common.registrable.Registrable` class is a special kind of {class}`~tango.common.from_params.FromParams` class that allows you to specify from the config which subclass of an expected class to deserialize into.\n\nThis is actually how we've been instantiating specific `Step` subclasses. Because {class}`~tango.step.Step` inherits from {class}`~tango.common.registrable.Registrable`, we can use the `\"type\"` fields in the config file to specify a `Step` subclass.\n\nThis is also very useful when you're writing a step that requires a certain type as input, but you want to be able to change the exact subclass of the type from your config file. For example, the {class}`~tango.integrations.torch.TorchTrainStep` takes `Registrable` inputs such as {class}`~tango.integrations.torch.Model`. Model variants can then be subclasses that are specified in the config file by their registered names. A sketch of this might look like the following:\n\n```python\nimport torch\n\nfrom tango import Step\nfrom tango.common import Registrable\n\nclass Model(torch.nn.Module, Registrable):\n    ...\n\n@Model.register(\"variant1\")\nclass Variant1(Model):\n    ...\n\n@Model.register(\"variant2\")\nclass Variant2(Model):\n    ...\n\n@Step.register(\"torch::train\")\nclass TorchTrainStep(Step):\n    def run(self, model: Model, ...) -> Model:\n        ...\n```\n\nAnd a sketch of the config file would be something like this:\n\n```json\n{\n  \"steps\": {\n    \"train\": {\n      \"type\": \"torch::train\",\n      \"model\": {\n        \"type\": \"variant1\"\n      }\n    }\n  }\n}\n```\n\nAs in the `FromParams` example, the specifications can be nested, but now we also denote the subclass with the `\"type\": \"...\"` field. 
To swap models we need only change \"variant1\" to \"variant2\" in the config. The value for \"type\" can be either the name that the class is registered under (e.g. \"torch::train\" for `TorchTrainStep`) or the fully qualified class name (e.g. `tango.integrations.torch.TorchTrainStep`).\n\nYou'll see more examples of this in the [next section](examples/index).\n"
  },
  {
    "path": "docs/source/index.md",
    "content": "# **AI2 Tango**\n\n```{include} ../../README.md\n:start-after: <!-- start tagline -->\n:end-before: <!-- end tagline -->\n```\n\n```{toctree}\n:maxdepth: 2\n:hidden:\n:caption: Getting started\n\ninstallation\nfirst_steps\nexamples/index\nfaq\n```\n\n```{toctree}\n:maxdepth: 2\n:hidden:\n:caption: API Reference\n\napi/commands\napi/components/index\napi/integrations/index\napi/settings\napi/exceptions\napi/logging\napi/sequences\napi/det_hash\napi/utilities\n```\n\n```{toctree}\n:hidden:\n:caption: Development\n\nCONTRIBUTING\nCHANGELOG\nLicense <https://raw.githubusercontent.com/allenai/tango/main/LICENSE>\nGitHub Repository <https://github.com/allenai/tango>\n```\n\nTo learn about Tango in 5 minutes, head over to the [First Steps section](first_steps).\n\nIf you'd rather learn from examples, check out the [Examples section](examples/index).\n\n## Team\n\n```{include} ../../README.md\n:start-after: <!-- start team -->\n:end-before: <!-- end team -->\n```\n\n## License\n\n```{include} ../../README.md\n:start-after: <!-- start license -->\n:end-before: <!-- end license -->\n```\n\n## Indices and tables\n\n```{eval-rst}\n* :ref:`genindex`\n* :ref:`modindex`\n```\n"
  },
  {
    "path": "docs/source/installation.md",
    "content": "Installation\n============\n\n```{include} ../../README.md\n:start-after: <!-- start install -->\n:end-before: <!-- end install -->\n```\n"
  },
  {
    "path": "examples/euler/README.md",
    "content": "Euler\n=====\n\nThis is a toy example that proves Euler's identity using Tango. You can use this to play with the concept of a\n`Step` and see how Tango runs things without getting distracted by the details of what you're running."
  },
  {
    "path": "examples/euler/complex_arithmetic.py",
    "content": "import cmath\nfrom typing import Tuple, Union\n\nfrom tango import Step\n\nComplexOrTuple = Union[complex, Tuple[float, float]]\n\n\ndef make_complex(x: ComplexOrTuple) -> complex:\n    if isinstance(x, complex):\n        return x\n    elif isinstance(x, (int, float)):\n        return complex(x)\n    else:\n        return complex(*x)\n\n\n@Step.register(\"cadd\")\nclass AdditionStep(Step):\n    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore\n        return make_complex(a) + make_complex(b)\n\n\n@Step.register(\"csub\")\nclass SubtractionStep(Step):\n    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore\n        return make_complex(a) - make_complex(b)\n\n\n@Step.register(\"cexp\")\nclass ExponentiateStep(Step):\n    def run(self, x: ComplexOrTuple, base: ComplexOrTuple = cmath.e) -> complex:  # type: ignore\n        return make_complex(base) ** make_complex(x)\n\n\n@Step.register(\"cmul\")\nclass MultiplyStep(Step):\n    def run(self, a: ComplexOrTuple, b: ComplexOrTuple) -> complex:  # type: ignore\n        return make_complex(a) * make_complex(b)\n\n\n@Step.register(\"csin\")\nclass SineStep(Step):\n    def run(self, x: ComplexOrTuple) -> complex:  # type: ignore\n        return cmath.sin(make_complex(x))\n\n\n@Step.register(\"ccos\")\nclass CosineStep(Step):\n    def run(self, x: ComplexOrTuple) -> complex:  # type: ignore\n        return cmath.cos(make_complex(x))\n"
  },
  {
    "path": "examples/euler/euler.jsonnet",
    "content": "local i = [0.0, 1.0];\nlocal pi = [3.1415926535, 0.0];\n\n{\n    \"steps\": {\n        \"i_times_pi\": {\n            \"type\": \"cmul\",\n            \"a\": i,\n            \"b\": pi\n        },\n        \"pow_e\": {\n            \"type\": \"cexp\",\n            \"x\": { \"type\": \"ref\", \"ref\": \"i_times_pi\" }\n        },\n        \"plus_one\": {\n            \"type\": \"cadd\",\n            \"a\": { \"type\": \"ref\", \"ref\": \"pow_e\" },\n            \"b\": [1, 0]\n        },\n        \"print\": {\n            \"type\": \"print\",\n            \"input\": { \"type\": \"ref\", \"ref\": \"plus_one\" }\n        }\n    }\n}"
  },
  {
    "path": "examples/euler/euler_general.jsonnet",
    "content": "local i = [0.0, 1.0];\nlocal pi = [3.1415926535, 0.0];\n\n{\n    \"steps\": {\n        \"cos\": {\n            \"type\": \"ccos\",\n            \"x\": pi\n        },\n        \"sin\": {\n            \"type\": \"csin\",\n            \"x\": pi\n        },\n        \"i_times_sin\": {\n            \"type\": \"cmul\",\n            \"a\": i,\n            \"b\": { \"type\": \"ref\", \"ref\": \"sin\" }\n        },\n        \"sum\": {\n            \"type\": \"cadd\",\n            \"a\": { \"type\": \"ref\", \"ref\": \"cos\" },\n            \"b\": { \"type\": \"ref\", \"ref\": \"i_times_sin\" },\n        },\n\n        \"i_times_pi\": {\n            \"type\": \"cmul\",\n            \"a\": i,\n            \"b\": pi\n        },\n        \"pow_e\": {\n            \"type\": \"cexp\",\n            \"x\": { \"type\": \"ref\", \"ref\": \"i_times_pi\" }\n        },\n\n        \"sub\": {\n            \"type\": \"csub\",\n            \"a\": { \"type\": \"ref\", \"ref\": \"sum\" },\n            \"b\": { \"type\": \"ref\", \"ref\": \"pow_e\" },\n        },\n\n        \"print\": {\n            \"type\": \"print\",\n            \"input\": { \"type\": \"ref\", \"ref\": \"sub\" }\n        }\n    }\n}"
  },
  {
    "path": "examples/euler/run.sh",
    "content": "#!/bin/bash\n\ntango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\n"
  },
  {
    "path": "examples/eval_p3/README.md",
    "content": "# Evaluating T0\n\nThis example uses the `transformers::run_generation_dataset` step to run the\n[T0 model](https://api.semanticscholar.org/CorpusID:239009562). It runs the\n[XSum summarization data](https://github.com/EdinburghNLP/XSum), prompted in 10 different ways, and computes\nROUGE scores for all variants. Finally, it computes an overall ROUGE score.\n\nThis example uses mostly built-in Tango steps. You will need the `datasets` and `transformers` integrations.\nThe only custom step in this example is the `RougeScoreStep`, which computes ROUGE scores from the\ngenerated text."
  },
  {
    "path": "examples/eval_p3/config.jsonnet",
    "content": "local model = \"bigscience/T0_3B\";\nlocal batch_size = 8;\n\nlocal datasets = [\n    'xsum_DOC_boils_down_to_simple_idea_that',\n    'xsum_DOC_given_above_write_one_sentence',\n    'xsum_DOC_how_would_you_rephrase_few_words',\n    'xsum_DOC_tldr',\n    'xsum_DOC_write_summary_of_above',\n    'xsum_article_DOC_summary',\n    'xsum_college_roommate_asked_DOC_so_I_recap',\n    'xsum_read_below_DOC_write_abstract',\n    'xsum_summarize_DOC',\n    'xsum_summarize_this_DOC_summary'\n];\n\n# This creates three steps for each of the datasets:\n# 1. Load the dataset.\n# 2. Generate output based on the dataset.\n# 3. Evaluate the output against the gold answers.\nlocal dataset_steps = std.foldl(\n    function(x, dataset_name) x + {\n        [\"dataset_\" + dataset_name]: {\n            \"type\": \"datasets::load\",\n            \"path\": \"bigscience/P3\",\n            \"name\": dataset_name,\n        },\n        [\"generation_\" + dataset_name]: {\n            \"type\": \"transformers::run_generation_dataset\",\n            \"max_length\": 200,\n            \"input\": {\"ref\": \"dataset_\" + dataset_name},\n            \"batch_size\": batch_size,\n            \"model\": model,\n            \"prompt_field\": \"inputs_pretokenized\",\n            \"output_field\": \"generation\",\n            \"splits\": [\"validation\"]\n        },\n        [\"eval_\" + dataset_name]: {\n            \"type\": \"rouge_score\",\n            \"input\": {\"ref\": \"generation_\" + dataset_name},\n            \"input_split\": \"validation\",\n            \"target_field\": \"targets_pretokenized\",\n            \"prediction_field\": \"generation\"\n        }\n    },\n    datasets,\n    {}\n);\n\n# In addition to the three steps per dataset, we also combine all the generations and\n# evaluate them all together.\n{\n    \"steps\": dataset_steps + {\n        \"all_generations\": {\n            \"type\": \"dataset_combine\",\n            \"inputs\": std.map(\n                
function(dataset_name) {\"ref\": \"generation_\" + dataset_name},\n                datasets\n            )\n        },\n        \"all_evaluations\": {\n            \"type\": \"rouge_score\",\n            \"input\": {\"ref\": \"all_generations\"},\n            \"input_split\": \"validation\",\n            \"target_field\": \"targets_pretokenized\",\n            \"prediction_field\": \"generation\"\n        }\n    }\n}\n"
  },
  {
    "path": "examples/eval_p3/eval.py",
    "content": "import logging\nfrom typing import Dict\n\nfrom torch import Tensor\nfrom torchmetrics.text.rouge import ROUGEScore\n\nfrom tango import Format, JsonFormat, Step\nfrom tango.common import DatasetDict\nfrom tango.common.tqdm import Tqdm\n\nlogger = logging.getLogger(__name__)\n\n\n@Step.register(\"rouge_score\")\nclass RougeScoreStep(Step[Dict[str, Tensor]]):\n    VERSION = \"002\"\n    FORMAT: Format = JsonFormat()\n\n    def run(  # type: ignore\n        self,\n        input: DatasetDict,\n        input_split: str,\n        target_field: str,\n        prediction_field: str,\n        use_stemmer: bool = True,\n    ) -> Dict[str, Tensor]:\n        metric = ROUGEScore(\n            use_stemmer=use_stemmer,\n            rouge_keys=(\"rouge1\", \"rouge2\", \"rougeL\"),\n            accumulate=\"avg\",\n        )\n\n        for instance in Tqdm.tqdm(input[input_split], desc=\"Calculating scores\"):\n            target = instance[target_field]\n            for prediction in instance[prediction_field]:\n                metric.update(prediction, target)\n\n        return metric.compute()\n"
  },
  {
    "path": "examples/finetune/__init__.py",
    "content": ""
  },
  {
    "path": "examples/finetune/config.jsonnet",
    "content": "##################\n# Model settings #\n##################\n\nlocal pretrained_model = \"t5-base\";\nlocal load_with_low_cpu_mem_usage = false;\n\nlocal modules_to_wrap = [\"[a-zA-Z_.]+\\\\.[0-9]+\"];  # TODO: works for t5 and gpt2. confirm with other models too.\n\n####################\n# Trainer settings #\n####################\n\n# Trainer settings, adjust to your use-case.\nlocal training_steps = 20;  # total number of optimization steps to train for\nlocal validate_every = 5;  # how often to validate and save checkpoints\n\nlocal devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)\nlocal grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)\n# This is the batch size per GPU, ignoring gradient accumulation:\nlocal batch_size = 2;\n# So the effective batch size is `batch_size * grad_accum * devices`\n\nlocal activation_checkpointing = false;  # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2)\nlocal amp = false;  # use PyTorch's native automatic mixed precision\nlocal fsdp = false;  # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2)\nlocal cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.\n\n######################\n# Optimizer settings #\n######################\n\nlocal warmup_steps = 20;\nlocal learning_rate = 0.00005;  # you can probably use a higher LR for a small model like \"gpt2\"\n\n\nassert fsdp == true || cpu_offloading == false : \"cpu_offloading only available with fsdp\";\n\n# FullyShardedDataParallel config:\nlocal fsdp_config = if fsdp then {\n    reshard_after_forward: true,\n    move_params_to_cpu: cpu_offloading,\n    move_grads_to_cpu: cpu_offloading,\n    mixed_precision: amp,\n} else null;\n\nlocal training_engine = {\n    type: if fsdp then \"fairscale\" else \"torch\",\n    
optimizer: {\n        type: \"torch::AdamW\",\n        lr: learning_rate,\n        betas: [0.9, 0.95],\n        eps: 1e-6,\n    },\n    lr_scheduler: {\n        type: \"transformers::linear\",\n        num_warmup_steps: warmup_steps,\n        num_training_steps: training_steps,\n    },\n    amp: amp,\n    [if fsdp then \"fsdp_config\" else null]: fsdp_config,\n};\n\nlocal distributed_dataloader = {\n    batch_size: batch_size,\n    sampler: {\n        type: \"torch::DistributedSampler\",\n        shuffle: true,\n        drop_last: true,\n    },\n};\n\nlocal single_device_dataloader = {\n    shuffle: true,\n    batch_size: batch_size,\n};\n\nlocal dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;\n\n{\n    steps: {\n        raw_data: {\n            type: \"datasets::load\",\n            path: \"snli\",\n        },\n        /*\"subset_data\": {\n            type: \"subset-data\",\n            data: { type: \"ref\", ref: \"raw_data\" },\n            max_samples: 10,\n        },*/\n        processed_data: {\n            type: \"snli-text2text\",\n            data: { type: \"ref\", ref: \"raw_data\" },\n        },\n        trained_model: {\n            type: \"transformers::finetune\",\n            model: {\n                type: \"fairscale::with_wrapped_modules\",\n                model: {\n                    type: \"transformers::finetune::from_pretrained\",\n                    pretrained_model_name_or_path: pretrained_model,\n                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,\n                },\n                modules_to_wrap: modules_to_wrap,  # tell FairScale to wrap the transformer's blocks individually\n                fsdp_config: fsdp_config,\n                activation_checkpointing: activation_checkpointing,\n            },\n            tokenizer: {\n                pretrained_model_name_or_path: pretrained_model\n            },\n            dataset_dict: { type: \"ref\", ref: \"processed_data\" },\n  
          train_dataloader: dataloader,\n            validation_split: \"validation\",\n            grad_accum: grad_accum,\n            train_steps: training_steps,\n            validate_every: validate_every,\n            checkpoint_every: validate_every,\n            log_every: 1,\n            device_count: devices,\n            training_engine: training_engine,\n        },\n        generations: {\n            type: \"transformers::run_generation_dataset\",\n            max_length: 5,\n            input: {\"type\": \"ref\", \"ref\": \"processed_data\"},\n            batch_size: batch_size,\n            model: {\"type\": \"ref\", \"ref\": \"trained_model\"},\n            prompt_field: \"source\",\n            output_field: \"generation\",\n            splits: [\"validation\"]\n        }\n    }\n}\n"
  },
  {
    "path": "examples/finetune/snli_steps.py",
    "content": "from typing import Union\n\nimport datasets as ds\n\nfrom tango.integrations.datasets import DatasetsFormat\nfrom tango.step import Step\n\n\n@Step.register(\"subset-data\")\nclass SubsetData(Step):\n    \"\"\"\n    Creates a subset of the data; mostly to be used for testing/debugging.\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    VERSION = \"001\"\n\n    FORMAT = DatasetsFormat()\n\n    def run(  # type: ignore\n        self,\n        data: Union[ds.DatasetDict, ds.Dataset],\n        max_samples: int = 5,\n    ) -> Union[ds.DatasetDict, ds.Dataset]:\n        \"\"\"\n        Returns a copy of the `data` with number of samples limited to `max_samples` for\n        each split.\n\n        :param data:\n            The dataset or dataset dict object.\n        :param max_samples:\n            The maximum number of samples to return per split.\n        \"\"\"\n\n        # Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`.\n        def filter_fn(example, indices):\n            return indices < max_samples\n\n        return data.filter(filter_fn, with_indices=True)\n\n\n@Step.register(\"snli-text2text\")\nclass SnliText2Text(Step):\n    \"\"\"\n    Converts the snli dataset to a text-to-text format.\n\n    Examples\n    --------\n\n    original_instance = {\n        \"premise\": \"Two cats are sitting on a wall.\",\n        \"hypothesis\": \"The cats are chasing a mouse.\",\n        \"label\": 2  # contradiction\n    }\n\n    returned_instance = {\n        \"source\": \"nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. 
label: \"\n        \"target\": \"contradiction\"\n    }\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    VERSION = \"001\"\n\n    FORMAT = DatasetsFormat()\n\n    def run(  # type: ignore\n        self,\n        data: Union[ds.DatasetDict, ds.Dataset],\n        source_prefix: str = \"nli\",\n        premise_prefix: str = \"premise\",\n        hypothesis_prefix: str = \"hypothesis\",\n        label_prefix: str = \"label\",\n        num_workers: int = 1,\n    ) -> Union[ds.DatasetDict, ds.Dataset]:\n        \"\"\"\n        :param data:\n            The snli `Dataset` or `DatasetDict` object.\n        :param source_prefix:\n            The str to add before the start of the source sequence.\n        :param premise_prefix:\n            The str to add before the start of the `premise` in the source sequence.\n        :param hypothesis_prefix:\n            The str to add before the start of the `hypothesis` in the source sequence.\n        :param label_prefix:\n            The str to add as the prompt for the label.\n        :param num_workers:\n            The number of workers to use for processing the data.\n        \"\"\"\n\n        def filter_no_gold(example, indices):\n            if example[\"label\"] == -1:\n                return False\n            return True\n\n        data = data.filter(filter_no_gold, with_indices=True)\n\n        label_map = {0: \"entailment\", 1: \"neutral\", 2: \"contradiction\"}\n\n        def _mapper(example):\n            return {\n                \"source\": (\n                    f'{source_prefix} {premise_prefix}: {example[\"premise\"]} '\n                    f'{hypothesis_prefix}: {example[\"hypothesis\"]} {label_prefix}: '\n                ),\n                \"target\": f'{label_map[example[\"label\"]]}',\n            }\n\n        if isinstance(data, ds.Dataset):\n            old_cols = data.column_names\n        else:\n            old_cols = list(data.column_names.values())[0]\n\n        dataset = data.map(\n 
           _mapper,\n            batched=False,\n            num_proc=num_workers,\n            remove_columns=old_cols,  # remove all old columns\n            desc=\"Converting data to text-to-text format\",\n        )\n\n        return dataset\n"
  },
  {
    "path": "examples/finetune/test.py",
    "content": "import typing\n\nimport datasets as ds\nimport pytest\n\nfrom tango.common import Params\nfrom tango.common.testing import TangoTestCase, run_experiment\n\n\nclass TestFinetuneSNLI(TangoTestCase):\n    @pytest.mark.parametrize(\n        \"model, model_type\",\n        [(\"patrickvonplaten/t5-tiny-random\", \"t5\"), (\"sshleifer/tiny-gpt2\", \"gpt2\")],\n    )\n    @typing.no_type_check  # mypy has become incompatible with the datasets library\n    def test_config(self, model: str, model_type: str):\n        overrides = {\n            \"steps.trained_model.model.model.pretrained_model_name_or_path\": model,\n            \"steps.trained_model.tokenizer.pretrained_model_name_or_path\": model,\n            \"steps.subset_data\": {\n                \"type\": \"subset-data\",\n                \"data\": {\"type\": \"ref\", \"ref\": \"raw_data\"},\n                \"max_samples\": 10,\n            },\n            \"steps.processed_data.data.ref\": \"subset_data\",\n        }\n        config = Params.from_file(\"config.jsonnet\", params_overrides=overrides)\n        # Make sure we've overridden the model entirely.\n        flattened = config.as_flat_dict()\n        for key, value in flattened.items():\n            if \"model_name\" in key or (isinstance(value, str) and model_type in value):\n                assert value == model\n\n        with run_experiment(config, include_package=[\"snli_steps.py\"]) as run_dir:\n            assert (run_dir / \"processed_data\").is_dir()\n            processed = ds.load_from_disk(run_dir / \"processed_data\" / \"data\")\n            assert len(processed[\"train\"][0].keys()) == 2\n            assert \"source\" in processed[\"train\"][0].keys()\n            assert \"target\" in processed[\"train\"][0].keys()\n            assert processed[\"train\"][0][\"source\"].startswith(\"nli premise:\")\n\n            assert (run_dir / \"trained_model\").is_dir()\n"
  },
  {
    "path": "examples/finetune_resnet/.gitignore",
    "content": "data/\nresults/\nextra_testing.py\n"
  },
  {
    "path": "examples/finetune_resnet/config.jsonnet",
    "content": "local input_size = 224;\nlocal batch_size = 32;\nlocal num_classes = 2;\nlocal val_size = 0.05;\nlocal model = \"resnet\";\nlocal feature_extract = true;\nlocal distributed = false;\nlocal devices = if distributed then 2 else 1;\nlocal pretrained_model = \"resnet_ft\";\nlocal training_steps = 500;\nlocal validate_every = 50;\nlocal image_url = \"https://tinyurl.com/2p9xjvn9\";\n\nlocal distributed_dataloader = {\n    batch_size: batch_size,\n    sampler: {\n        type: \"torch::DistributedSampler\",\n        shuffle: true,\n        drop_last: true,\n    },\n    collate_fn: {\"type\": \"image_collator\"},\n};\n\nlocal single_device_dataloader = {\n    shuffle: true,\n    batch_size: batch_size,\n    collate_fn: {\"type\": \"image_collator\"},\n};\n\n{\n    steps: {\n        raw_data: {\n            type: \"datasets::load\",\n            path: \"nateraw/auto-cats-and-dogs\",\n            name: \"cats_and_dogs\",\n        },\n        transform_data: {\n            type: \"transform_data\",\n            dataset: { type: 'ref', ref: 'raw_data' },\n            input_size: input_size,\n            val_size: val_size,\n        },\n        trained_model: {\n            type: \"torch::train\",\n            model: {\n                type: pretrained_model,\n                num_classes: num_classes,\n                feature_extract: true,\n                use_pretrained: true,\n            },\n            training_engine: {\n                optimizer: {\n                    type: \"torch_adam\",\n                    lr: 0.001,\n                },\n            },\n            dataset_dict: {\"type\": \"ref\", \"ref\": \"transform_data\"},\n            train_dataloader: single_device_dataloader,\n            validation_split: \"val\",\n            val_metric_name: \"accuracy\",\n            train_steps: training_steps,\n            validate_every: validate_every,\n            checkpoint_every: validate_every,\n            log_every: 1,\n            
device_count: devices,\n            minimize_val_metric: false,\n        },\n        prediction: {\n            type: \"prediction\",\n            image_url: image_url,\n            input_size: input_size,\n            model: {\"type\": \"ref\", \"ref\": \"trained_model\"},\n        },\n    },\n}\n \n"
  },
  {
    "path": "examples/finetune_resnet/resnet_steps.py",
    "content": "from typing import Any, Dict, List, Optional\n\nimport datasets\nimport torch\nfrom cached_path import cached_path\nfrom PIL import Image\nfrom torch import nn\nfrom torch.optim import Adam\nfrom torchvision import models, transforms\n\nfrom tango import Format, JsonFormat, Step\nfrom tango.integrations.torch import DataCollator, Model, Optimizer\n\n# Register the Adam optimizer as an `Optimizer` so we can use it in the train step.\nOptimizer.register(\"torch_adam\")(Adam)\n\n\n# Wrapper class around the pre-trained ResNet-18 model that modifies the final layer.\n@Model.register(\"resnet_ft\")\nclass ResNetWrapper(Model):\n    def __init__(self, num_classes: int, feature_extract: bool, use_pretrained: bool):\n        super().__init__()\n        self.model_ft = models.resnet18(pretrained=use_pretrained)\n        self.set_parameter_requires_grad(self.model_ft, feature_extract)\n        num_features = self.model_ft.fc.in_features\n        self.model_ft.fc = nn.Linear(num_features, num_classes)\n        self.loss_fn = nn.CrossEntropyLoss()\n\n    def set_parameter_requires_grad(self, model: models, feature_extracting: bool):\n        if feature_extracting:\n            for param in model.parameters():\n                param.requires_grad = False\n\n    def forward(  # type: ignore\n        self, image: torch.Tensor, label: Optional[torch.Tensor] = None\n    ) -> Dict[str, torch.Tensor]:\n        output = self.model_ft(image)\n        preds = torch.argmax(output, dim=1)\n        if label is None:\n            return {\"preds\": preds}\n        loss = self.loss_fn(output, label)\n        accuracy = (preds == label).float().mean()\n        return {\"loss\": loss, \"accuracy\": accuracy}\n\n\n# Custom data collator for images, that takes in a batch of images and labels and\n# reformats the data so that it is suitable for the model.\n@DataCollator.register(\"image_collator\")\nclass ImageCollator(DataCollator[Dict[str, Any]]):\n    def __call__(self, batch: 
List[Dict[str, Any]]) -> Dict[str, Any]:\n        return {\n            \"image\": torch.cat([item[\"image\"].unsqueeze(0) for item in batch], dim=0),\n            \"label\": torch.tensor([item[\"labels\"] for item in batch]),\n        }\n\n\n# Function that returns an image transformations dict with the appropriate image size.\ndef get_data_transforms(input_size: int):\n    data_transforms = {\n        \"train\": transforms.Compose(\n            [\n                transforms.RandomResizedCrop(input_size),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n            ]\n        ),\n        \"val\": transforms.Compose(\n            [\n                transforms.Resize(input_size),\n                transforms.CenterCrop(input_size),\n                transforms.ToTensor(),\n                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n            ]\n        ),\n    }\n    return data_transforms\n\n\n# loads an image and applies the appropriate transformation\ndef pil_loader(path: str, input_size: int, transform_type: str):\n    with open(path, \"rb\") as f:\n        image = Image.open(f)\n        image = image.convert(\"RGB\")\n        transform = get_data_transforms(input_size=input_size)[transform_type]\n        transformed_image = transform(image)\n        return transformed_image\n\n\n# calls the image loader on every image in a given batch\ndef image_loader(example_batch, input_size: int, transform_type: str):\n    example_batch[\"image\"] = [\n        pil_loader(f, input_size, transform_type) for f in example_batch[\"file\"]\n    ]\n    return example_batch\n\n\n# This step takes in raw image data, attaches the image transforms to it,\n# and splits off a validation set.\n@Step.register(\"transform_data\")\nclass TransformData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(  # type: ignore\n        self, dataset: 
datasets.DatasetDict, val_size: float, input_size: int\n    ) -> datasets.DatasetDict:\n        def image_loader_wrapper(example_batch):\n            return image_loader(example_batch, input_size=input_size, transform_type=\"train\")\n\n        dataset = dataset.with_transform(image_loader_wrapper)\n        train_val = dataset[\"train\"].train_test_split(test_size=val_size)\n        train_val[\"val\"] = train_val.pop(\"test\")\n        return train_val\n\n\n# function to map integer labels to string labels\ndef convert_to_label(int_label: int) -> str:\n    if int_label == 0:\n        return \"cat\"\n    else:\n        return \"dog\"\n\n\n@Step.register(\"prediction\")\nclass Prediction(Step):\n    FORMAT: Format = JsonFormat()\n\n    def run(  # type: ignore\n        self, image_url: str, input_size: int, model: models, device: Optional[str] = \"cpu\"\n    ) -> Dict[str, Any]:\n        # download and cache the image, then apply the validation transform\n        image_path = cached_path(image_url)\n        transformed_image = pil_loader(image_path, input_size, transform_type=\"val\")\n\n        # add a batch dimension and move the image to the target device\n        transformed_image = transformed_image.unsqueeze(0).to(device)\n\n        # pass image through model and get the predicted class index as a plain int\n        prediction = int(model(image=transformed_image, label=None)[\"preds\"][0].item())\n        label = convert_to_label(prediction)\n        return {\"image_url\": image_url, \"local_path\": image_path, \"label\": label}\n"
  },
  {
    "path": "examples/flax/config.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data\": {\n            \"type\": \"datasets::load\",\n            \"path\": \"xsum\",\n        },\n        \"tokenize\": {\n            \"type\": \"tokenize_data\",\n            \"dataset\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\"\n            }\n        },\n        \"train\": {\n            \"type\": \"flax::train\",\n            \"model\": {\n                \"type\": \"transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained\",\n                \"pretrained_model_name_or_path\": \"facebook/bart-base\"\n            },\n            \"dataset\": {\n                \"type\": \"ref\",\n                \"ref\": \"tokenize\"\n            },\n            \"optimizer\": {\n                \"type\" : \"optax::adamw\",\n                \"learning_rate\" : 2e-5\n            },\n            \"train_dataloader\": {\n                \"batch_size\": 16,\n                \"drop_last\": true\n            },\n            \"wrapper\": {\n                \"type\": \"xsum_wrapper\"\n            },\n            \"train_split\": \"train\",\n            \"validation_split\" : \"validation\",\n            \"validate_every\" : 1000,\n            \"validation_dataloader\": {\n                \"batch_size\": 16,\n                \"drop_last\": true\n            },\n            \"train_epoch\": 5,\n            \"checkpoint_every\": 1000,\n            \"log_every\": 1000,\n\n            \"callbacks\" : [\n                //{\"type\" : \"wandb::log_flax\"},\n                {\"type\": \"flax::generate_step\"}\n            ]\n        },\n        \"eval\": {\n            \"type\": \"flax::eval\",\n            \"state\": {\n                \"type\": \"ref\",\n                \"ref\": \"train\"\n            },\n            \"dataset\": {\n                \"type\": \"ref\",\n                \"ref\": \"tokenize\"\n            },\n            \"dataloader\": {\n                \"batch_size\": 16,\n                
\"drop_last\": true\n            },\n            \"wrapper\": {\n                \"type\" : \"xsum_wrapper\"\n            }\n        }\n    }\n}"
  },
  {
    "path": "examples/flax/run.sh",
    "content": "#!/bin/bash\n\ntango run config.jsonnet -d workspace --include-package xsum\n"
  },
  {
    "path": "examples/flax/xsum.py",
    "content": "import logging\nfrom typing import List, Optional\n\nimport jax\nimport jax.numpy as jnp\nimport nltk\nimport numpy as np\nimport optax\nfrom datasets import load_metric\nfrom flax.training.common_utils import onehot\nfrom transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSeq2SeqLM\n\nfrom tango.integrations.flax import FlaxWrapper\nfrom tango.integrations.flax.train_callback import TrainCallback\nfrom tango.step import Step\n\n\"\"\"\nXSum Summarization with facebook/bart-base\n\"\"\"\n\n\n@Step.register(\"tokenize_data\")\nclass PreProcessing(Step):\n    DETERMINISTIC = False\n\n    def run(self, dataset):\n        tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-base\")\n        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(\"facebook/bart-base\")\n        model_module = __import__(model.__module__, fromlist=[\"shift_tokens_right\"])\n        shift_tokens_right_fn = getattr(model_module, \"shift_tokens_right\")\n        config = AutoConfig.from_pretrained(\"facebook/bart-base\")\n\n        MAX_SOURCE_LENGTH = 512\n        MAX_TGT_LENGTH = 64\n\n        def preprocess_function(examples):\n            inputs = examples[\"document\"]\n            targets = examples[\"summary\"]\n            inputs = [inp for inp in inputs]\n            model_inputs = tokenizer(\n                inputs,\n                max_length=MAX_SOURCE_LENGTH,\n                padding=\"max_length\",\n                truncation=True,\n                return_tensors=\"np\",\n            )\n\n            # Setup the tokenizer for targets\n            with tokenizer.as_target_tokenizer():\n                labels = tokenizer(\n                    targets,\n                    max_length=MAX_TGT_LENGTH,\n                    padding=\"max_length\",\n                    truncation=True,\n                    return_tensors=\"np\",\n                )\n\n            model_inputs[\"labels\"] = labels[\"input_ids\"]\n            decoder_input_ids = 
shift_tokens_right_fn(\n                labels[\"input_ids\"], config.pad_token_id, config.decoder_start_token_id\n            )\n            model_inputs[\"decoder_input_ids\"] = np.asarray(decoder_input_ids)\n\n            # We need decoder_attention_mask so we can ignore pad tokens from loss\n            model_inputs[\"decoder_attention_mask\"] = labels[\"attention_mask\"]\n\n            return model_inputs\n\n        column_names = dataset[\"train\"].column_names\n\n        dataset = dataset.map(\n            preprocess_function,\n            batched=True,\n            remove_columns=column_names,\n            desc=\"Running tokenizer on dataset\",\n        )\n\n        return dataset\n\n\n@FlaxWrapper.register(\"xsum_wrapper\")  # type: ignore\nclass TransformerWrapper(FlaxWrapper):\n    def loss_helper(self, logits, labels, batch):\n        label_smoothing_factor = 0\n        padding_mask = batch[\"decoder_attention_mask\"]\n        vocab_size = logits.shape[-1]\n        confidence = 1.0 - label_smoothing_factor\n        low_confidence = (1.0 - confidence) / (vocab_size - 1)\n        normalizing_constant = -(\n            confidence * jnp.log(confidence)\n            + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)\n        )\n        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)\n\n        loss = optax.softmax_cross_entropy(logits, soft_labels)\n        loss = loss - normalizing_constant\n\n        # ignore padded tokens from loss\n        loss = loss * padding_mask\n        loss = loss.sum() / padding_mask.sum()\n        return loss\n\n    def train_loss(self, params, state, batch, dropout_rng, labels):\n        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n        loss = self.loss_helper(logits, labels, batch)\n        return loss\n\n    def val_metrics(self, batch, logits, labels):\n        loss = self.loss_helper(logits, labels, batch)\n        
metrics = {\"loss\": loss}\n        return metrics\n\n    def eval_metrics(self, batch, logits, labels):\n        loss = self.loss_helper(logits, labels, batch)\n        metrics = {\"loss\": loss}\n        return metrics\n\n\n@TrainCallback.register(\"flax::generate_step\")\nclass GenerateCallback(TrainCallback):\n    def __init__(self, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.logger = logging.getLogger(GenerateCallback.__name__)\n\n    def generate_step(self, params, batch):\n        self.model.params = params\n        gen_kwargs = {\"max_length\": 64, \"num_beams\": self.model.config.num_beams}\n        output_ids = self.model.generate(\n            batch[\"input_ids\"], attention_mask=batch[\"attention_mask\"], **gen_kwargs\n        )\n        return output_ids.sequences\n\n    def pre_train_loop(self) -> None:\n        if len(jax.devices()) > 1:\n            self.p_generate_step = jax.pmap(self.generate_step, axis_name=\"batch\")\n\n    def pre_val_loop(self, step: int, val_step: int, state) -> None:\n        self.state = state\n        self.eval_preds: List = []\n        self.eval_labels: List = []\n\n    def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None:\n        labels = val_batch[\"labels\"]\n        if len(jax.devices()) > 1:\n            generated_ids = self.p_generate_step(self.state.params, val_batch)\n        else:\n            generated_ids = self.generate_step(self.state.params, val_batch)\n        self.eval_preds.extend(jax.device_get(generated_ids.reshape(-1, 64)))\n        self.eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))\n\n    def postprocess_text(self, preds, labels):\n        preds = [pred.strip() for pred in preds]\n        labels = [label.strip() for label in labels]\n\n        # rougeLSum expects newline after each sentence\n        preds = [\"\\n\".join(nltk.sent_tokenize(pred)) for pred in preds]\n        labels = 
[\"\\n\".join(nltk.sent_tokenize(label)) for label in labels]\n\n        return preds, labels\n\n    def compute_metrics(self, preds, labels):\n        tokenizer = AutoTokenizer.from_pretrained(\"facebook/bart-base\")\n        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)\n        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n\n        # Some simple post-processing\n        decoded_preds, decoded_labels = self.postprocess_text(decoded_preds, decoded_labels)\n        metric = load_metric(\"rouge\")\n        result = metric.compute(\n            predictions=decoded_preds, references=decoded_labels, use_stemmer=True\n        )\n        # Extract a few results from ROUGE\n        result = {key: value.mid.fmeasure * 100 for key, value in result.items()}\n\n        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]\n        result[\"gen_len\"] = np.mean(prediction_lens)\n        result = {k: round(v, 4) for k, v in result.items()}\n        return result\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]\n    ) -> None:\n        rouge_metrics = self.compute_metrics(self.eval_preds, self.eval_labels)\n        rouge_desc = \" \".join([f\"Eval {key}: {value} |\" for key, value in rouge_metrics.items()])\n        self.logger.info(rouge_desc)\n"
  },
  {
    "path": "examples/train_lm/.gitignore",
    "content": "runs\nrun\n"
  },
  {
    "path": "examples/train_lm/README.md",
    "content": "# Fine-tuning a language model\n\n<!-- start overview -->\n\nThis Tango example showcases how you could train or fine-tune a causal language model like [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2)\nor [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) from [transformers](https://github.com/huggingface/transformers) on WikiText2 or a similar dataset.\nIt's best that you run this experiment on a machine with a GPU and PyTorch [properly installed](https://pytorch.org/get-started/locally/#start-locally), otherwise Tango will fall back to CPU-only and it will be extremely slow.\n\nThis example also depends on [FairScale](https://fairscale.readthedocs.io/en/latest/), which allows you to leverage [`FullyShardedDataParallel`](https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html) (FSDP) and [activation checkpointing](https://fairscale.readthedocs.io/en/latest/api/nn/checkpoint/checkpoint_activations.html) to fine-tune [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B) or a similar-sized model. Just set the constants `fsdp` and `activation_checkpointing` in the config to `true`.\nWithout using CPU offloading you'll need at least 4 x 40GiB A100 GPUs, or a different configuration with a comparable amount of total GPU memory.\n\n<!-- end overview -->\n\nTo get started, just run\n\n```\ntango run config.jsonnet -i tokenize_step.py\n```\n"
  },
  {
    "path": "examples/train_lm/config.jsonnet",
    "content": "##################\n# Model settings #\n##################\n\nlocal pretrained_model = \"gpt2\";\n# With 'fsdp' and 'activation_checkpointing' (see constants below), you should be able to train\n# a 6B model on 4x ~40GB GPUs:\n# local pretrained_model = \"EleutherAI/gpt-j-6B\";\n\n# This doesn't seem to work with gpt2, but works fine with gpt-j.\nlocal load_with_low_cpu_mem_usage = std.startsWith(pretrained_model, \"EleutherAI/gpt-j\");\n\n####################\n# Trainer settings #\n####################\n\n# Trainer settings, adjust to your use-case.\nlocal training_steps = 200;  # total number of optimization steps to train for\nlocal validate_every = 20;  # how often to validate and save checkpoints\n\nlocal devices = 1;  # number of devices to train on (will use GPUs if enough are available, otherwise CPU)\nlocal grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)\n# This is the batch size per GPU, ignoring gradient accumulation:\nlocal batch_size = 8;\n# So the effective batch size is `batch_size * grad_accum * devices`\n\nlocal activation_checkpointing = false;  # use activation/gradient checkpointing (probably needed for GPT-J 6B, but not gpt2)\nlocal amp = false;  # use PyTorch's native automatic mixed precision\nlocal fsdp = false;  # Use FairScale's FullyShardedDataParallel (probably needed for GPT-J 6B, but not gpt2)\nlocal cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.\n\n######################\n# Optimizer settings #\n######################\n\nlocal warmup_steps = 20;\nlocal learning_rate = 0.00005;  # you can probably use a higher LR for a small model like \"gpt2\"\n\n\n# <----- you probably don't need to edit below this line ----> #\n\n\nassert fsdp == true || cpu_offloading == false : \"cpu_offloading only available with fsdp\";\n\n# FullyShardedDataParallel config:\nlocal fsdp_config = if fsdp then {\n   
 reshard_after_forward: true,\n    move_params_to_cpu: cpu_offloading,\n    move_grads_to_cpu: cpu_offloading,\n    mixed_precision: amp,\n} else null;\n\nlocal training_engine = {\n    type: if fsdp then \"fairscale\" else \"torch\",\n    optimizer: {\n        type: \"torch::AdamW\",\n        lr: learning_rate,\n        betas: [0.9, 0.95],\n        eps: 1e-6,\n    },\n    lr_scheduler: {\n        type: \"transformers::linear\",\n        num_warmup_steps: warmup_steps,\n        num_training_steps: training_steps,\n    },\n    amp: amp,\n    [if fsdp then \"fsdp_config\" else null]: fsdp_config,\n};\n\nlocal distributed_dataloader = {\n    batch_size: batch_size,\n    collate_fn: { type: \"transformers::DefaultDataCollator\" },\n    sampler: {\n        type: \"torch::DistributedSampler\",\n        shuffle: true,\n        drop_last: true,\n    },\n};\n\nlocal single_device_dataloader = {\n    shuffle: true,\n    batch_size: batch_size,\n    collate_fn: { type: \"transformers::DefaultDataCollator\" },\n};\n\nlocal dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader;\n\n{\n    steps: {\n        raw_data: {\n            type: \"datasets::load\",\n            path: \"wikitext\",\n            name: \"wikitext-2-raw-v1\",\n        },\n        tokenized_data: {\n            type: \"tokenize_data\",\n            dataset: { type: \"ref\", ref: \"raw_data\" },\n            tokenizer: { pretrained_model_name_or_path: pretrained_model }\n        },\n        trained_model: {\n            type: \"torch::train\",\n            model: {\n                type: \"fairscale::with_wrapped_modules\",\n                model: {\n                    type: \"transformers::AutoModelForCausalLM::from_pretrained\",\n                    pretrained_model_name_or_path: pretrained_model,\n                    low_cpu_mem_usage: load_with_low_cpu_mem_usage,\n                },\n                modules_to_wrap: [\"transformer\\\\.h\\\\.[0-9]+\"],  # tell FairScale to 
wrap the transformer's blocks individually\n                fsdp_config: fsdp_config,\n                activation_checkpointing: activation_checkpointing,\n            },\n            dataset_dict: { type: \"ref\", ref: \"tokenized_data\" },\n            train_dataloader: dataloader,\n            validation_split: \"validation\",\n            grad_accum: grad_accum,\n            train_steps: training_steps,\n            validate_every: validate_every,\n            checkpoint_every: validate_every,\n            log_every: 1,\n            device_count: devices,\n            training_engine: training_engine,\n        },\n        final_metrics: {\n            type: \"torch::eval\",\n            model: { type: \"ref\", ref: \"trained_model\" },\n            dataset_dict: { type: \"ref\", ref: \"tokenized_data\" },\n            dataloader: single_device_dataloader,\n            test_split: \"test\",\n        },\n    }\n}\n"
  },
  {
    "path": "examples/train_lm/test.py",
    "content": "from tango.common import Params\nfrom tango.common.testing import run_experiment\n\n\ndef test_small_experiment():\n    model = \"sshleifer/tiny-gpt2\"\n    dataloader = {\n        \"batch_size\": 2,\n        \"collate_fn\": {\"type\": \"transformers::DefaultDataCollator\"},\n    }\n    steps = 4\n    overrides = {\n        \"steps.tokenized_data.block_size\": 64,\n        # Override the model in the config with the tiny alternative so training is fast.\n        \"steps.tokenized_data.tokenizer.pretrained_model_name_or_path\": model,\n        \"steps.trained_model.model.model.pretrained_model_name_or_path\": model,\n        # Use a small number of training/validation/eval steps.\n        \"steps.trained_model.training_engine.lr_scheduler.num_warmup_steps\": 1,\n        \"steps.trained_model.training_engine.lr_scheduler.num_training_steps\": steps,\n        \"steps.trained_model.train_steps\": steps,\n        \"steps.trained_model.validation_steps\": 2,\n        \"steps.trained_model.validate_every\": steps,\n        \"steps.final_metrics.eval_steps\": 2,\n        \"steps.trained_model.checkpoint_every\": steps,\n        \"steps.trained_model.device_count\": 1,\n        # Override data loaders.\n        \"steps.trained_model.train_dataloader\": dataloader,\n        \"steps.trained_model.validation_dataloader\": dataloader,\n        \"steps.final_metrics.dataloader\": dataloader,\n    }\n\n    # Load the config.\n    config = Params.from_file(\"config.jsonnet\", params_overrides=overrides)\n\n    # Make sure we've overridden the model entirely.\n    flattened = config.as_flat_dict()\n    for key, value in flattened.items():\n        if \"model_name\" in key or (isinstance(value, str) and \"gpt\" in value):\n            assert value == model\n\n    with run_experiment(config, include_package=[\"tokenize_step.py\"]) as run_dir:\n        assert (run_dir / \"trained_model\").is_dir()\n"
  },
  {
    "path": "examples/train_lm/tokenize_step.py",
    "content": "import datasets\n\nfrom tango import Step\nfrom tango.integrations.datasets import DatasetsFormat\nfrom tango.integrations.transformers import Tokenizer\n\n\n# We need a step to tokenize the raw data. The result of this step will be passed\n# directly into the \"torch::train\" step.\n@Step.register(\"tokenize_data\")\nclass TokenizeData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT = DatasetsFormat()\n\n    def run(  # type: ignore[override]\n        self,\n        dataset: datasets.DatasetDict,\n        tokenizer: Tokenizer,\n        block_size: int = 1024,\n        num_workers: int = 1,\n        field_to_tokenize: str = \"text\",\n    ) -> datasets.DatasetDict:\n        def tokenize_function(example):\n            return tokenizer(example[field_to_tokenize])\n\n        dataset = dataset.map(\n            tokenize_function,\n            batched=True,\n            num_proc=num_workers,\n            remove_columns=list(dataset.column_names.values())[0],  # remove all old columns\n            desc=\"Tokenizing dataset\",\n        )\n\n        def group_texts(examples):\n            # Concatenate all texts.\n            concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}  # type: ignore\n            total_length = len(concatenated_examples[list(examples.keys())[0]])\n            # We drop the small remainder, we could add padding if the model supported\n            # it instead of this drop, you can customize this part to your needs.\n            if total_length >= block_size:\n                total_length = (total_length // block_size) * block_size\n            # Split by chunks of max_len.\n            result = {\n                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n                for k, t in concatenated_examples.items()\n            }\n            result[\"labels\"] = result[\"input_ids\"].copy()\n            return result\n\n        dataset = dataset.map(\n            
group_texts,\n            batched=True,\n            num_proc=num_workers,\n            desc=f\"Grouping texts into chunks of {block_size}\",\n        )\n\n        return dataset\n"
  },
  {
    "path": "integration_tests/README.md",
    "content": "# Integration tests\n\nThis is a collection of longer-running end-to-end tests of various parts of the Tango library.\n\nThe easiest way to run any of these integration tests is by triggering the [**Integration tests**](https://github.com/allenai/tango/actions/workflows/integration_tests.yml)\nworkflow on GitHub Actions. Just select the \"Run workflow\" dropdown, then pick the test to run and the Beaker cluster to run it on,\nand finally hit the \"Run workflow\" button.\n\nEach test should have a `run.sh` file in its folder that will run the relevant tango command.\nThis is what the **Integration tests** workflow will call, and you can also use it to run the test manually.\n"
  },
  {
    "path": "integration_tests/fairscale_benchmarks/README.md",
    "content": "# FairScale Benchmarks\n\nThis integration test is for checking the performance of the `FairScaleTrainingEngine` with various configurations.\n\n**When to run it:** It should be run every time there is a major PyTorch or FairScale upgrade.\n\n**Where to run it:** A server with 4 A100 GPUs. Make sure you set your `WANDB_API_KEY` environment variable.\n\n**How to run it:** From the root directory of this repository, run:\n```\nintegration_tests/fairscale_benchmarks/run.sh\n```\n\nBy default, not all configurations are run. If you want to change which configurations are run, open `config.jsonnet`\nand search for \"enabled\". Then toggle this `enabled` field to `true` or `false` for each configuration.\n\n**What to look for:** The training jobs shouldn't fail, for one. After `tango run` completes, check the corresponding Weights & Biases\ndashboard and inspect the results. Compare the various \"fsdp\" training runs with the baseline to ensure you see memory savings.\n"
  },
  {
    "path": "integration_tests/fairscale_benchmarks/config.jsonnet",
    "content": "##################\n# Model settings #\n##################\n\nlocal pretrained_model = \"gpt2\";\n# local pretrained_model = \"EleutherAI/gpt-j-6B\";\n# This doesn't seem to work with gpt2, but works fine with gpt-j-6B.\nlocal load_with_low_cpu_mem_usage = pretrained_model == \"EleutherAI/gpt-j-6B\";\n\n####################\n# Trainer settings #\n####################\n\n# Trainer settings, adjust to your use-case.\nlocal training_steps = 100;  # total number of optimization steps to train for\nlocal validate_every = 20;  # how often to validate and save checkpoints\n\nlocal devices = 4;\nlocal grad_accum = 1;  # number of gradient accumulation steps (changes the effective batch size)\n# This is the batch size per GPU, ignoring gradient accumulation:\nlocal batch_size = 8;\n# So the effective batch size is `batch_size * grad_accum * devices`\n\n######################\n# Optimizer settings #\n######################\n\nlocal warmup_steps = 20;\nlocal learning_rate = if pretrained_model == \"EleutherAI/gpt-j-6B\" then 0.00001 else 0.0001;\n\n\n# <----- you probably don't need to edit below this line ----> #\n\n\nlocal distributed_dataloader = {\n  batch_size: batch_size,\n  collate_fn: { type: \"transformers::DefaultDataCollator\" },\n  sampler: {\n    type: \"torch::DistributedSampler\",\n    shuffle: true,\n    drop_last: true,\n  },\n};\n\nlocal single_device_dataloader = {\n  shuffle: true,\n  batch_size: batch_size,\n  collate_fn: { type: \"transformers::DefaultDataCollator\" },\n};\n\nlocal TrainStep(options) =\n    local training_engine = {\n        type: if options.fsdp_config != null then \"fairscale\" else \"torch\",\n        optimizer: {\n            type: \"torch::AdamW\",\n            lr: learning_rate,\n            betas: [0.9, 0.95],\n            eps: 1e-6,\n        },\n        lr_scheduler: {\n            type: \"transformers::linear\",\n            num_warmup_steps: warmup_steps,\n            num_training_steps: training_steps,\n        
},\n        amp: options.amp,\n        [if options.fsdp_config != null then \"fsdp_config\" else null]: options.fsdp_config,\n    };\n\n    {\n        type: \"torch::train\",\n        model: {\n            type: \"fairscale::with_wrapped_modules\",\n            model: {\n                type: \"transformers::AutoModelForCausalLM::from_pretrained\",\n                pretrained_model_name_or_path: pretrained_model,\n                low_cpu_mem_usage: load_with_low_cpu_mem_usage,\n            },\n            modules_to_wrap: [\"transformer\\\\.h\\\\.[0-9]+\"],  # tell FairScale to wrap the transformer's blocks individually\n            fsdp_config: options.fsdp_config,\n            activation_checkpointing: options.activation_checkpointing,\n        },\n        dataset_dict: { type: \"ref\", ref: \"tokenized_data\" },\n        train_dataloader: distributed_dataloader,\n        validation_split: \"validation\",\n        grad_accum: grad_accum,\n        train_steps: training_steps,\n        validate_every: validate_every,\n        checkpoint_every: validate_every,\n        log_every: 1,\n        device_count: devices,\n        training_engine: training_engine,\n        callbacks: [\n            {\n                type: \"wandb::log\",\n                entity: \"allennlp\",\n                project: \"tango-fairscale-benchmarks\",\n                wandb_config: options + {\n                    effective_batch_size: batch_size * devices * grad_accum,\n                    model: pretrained_model,\n                },\n            },\n        ],\n    };\n\n{\n    steps: {\n        raw_data: {\n            type: \"datasets::load\",\n            path: \"wikitext\",\n            name: \"wikitext-2-raw-v1\",\n        },\n        tokenized_data: {\n            type: \"tokenize_data\",\n            dataset: { type: \"ref\", ref: \"raw_data\" },\n            tokenizer: { pretrained_model_name_or_path: pretrained_model }\n        },\n    } + {\n        [\"trained_model_\" + 
options.name]: TrainStep(options)\n        for options in [\n            # NOTE: With 6B model, baseline and many others will fail with CUDA OOM.\n            # FSDP and activation checkpointing will be required for a 6B model.\n            {\n                name: \"baseline\",\n                enabled: false,\n                amp: false,\n                fsdp_config: null,\n                activation_checkpointing: false,\n            },\n            {\n                name: \"amp\",\n                enabled: false,\n                amp: true,\n                fsdp_config: null,\n                activation_checkpointing: false,\n            },\n            {\n                name: \"checkpointing\",\n                enabled: false,\n                amp: false,\n                fsdp_config: null,\n                activation_checkpointing: true,\n            },\n            {\n                name: \"amp_and_checkpointing\",\n                enabled: false,\n                amp: true,\n                fsdp_config: null,\n                activation_checkpointing: true,\n            },\n            {\n                name: \"fsdp\",\n                enabled: false,\n                amp: false,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: \"fsdp_no_reshard\",\n                enabled: false,\n                amp: false,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: false,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: 
\"amp_and_fsdp\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: \"amp_and_fsdp_no_reshard\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: false,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: \"amp_and_fsdp_mp\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: true,\n                },\n            },\n            {\n                name: \"amp_and_fsdp_mp_no_reshard\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: false,\n                fsdp_config: {\n                    reshard_after_forward: false,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: true,\n                },\n            },\n            {\n                name: \"checkpointing_and_fsdp\",\n                enabled: false,\n                amp: false,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n  
                  move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: \"amp_and_checkpointing_and_fsdp\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: false,\n                },\n            },\n            {\n                name: \"amp_and_checkpointing_and_fsdp_mp\",\n                enabled: true,\n                amp: true,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: true,\n                },\n            },\n            {\n                name: \"checkpointing_and_fsdp_mp\",\n                enabled: false,\n                amp: false,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: false,\n                    move_grads_to_cpu: false,\n                    mixed_precision: true,\n                },\n            },\n            {  # This configuration currently does not work. 
Tracking https://github.com/facebookresearch/fairscale/issues/918\n                name: \"amp_and_checkpointing_and_fsdp_mp_with_partial_offloading\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: true,\n                    move_grads_to_cpu: false,\n                    mixed_precision: true,\n                },\n            },\n            {\n                name: \"amp_and_checkpointing_and_fsdp_mp_with_full_offloading\",\n                enabled: false,\n                amp: true,\n                activation_checkpointing: true,\n                fsdp_config: {\n                    reshard_after_forward: true,\n                    move_params_to_cpu: true,\n                    move_grads_to_cpu: true,\n                    mixed_precision: true,\n                },\n            },\n        ] if options.enabled\n    }\n}\n"
  },
  {
    "path": "integration_tests/fairscale_benchmarks/run.sh",
    "content": "#!/bin/sh\n\ntango run integration_tests/fairscale_benchmarks/config.jsonnet -i examples/train_lm/tokenize_step.py\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools\", \"wheel\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"ai2-tango\"\ndynamic = [\"version\"]\nreadme = \"README.md\"\ndescription = \"A library for choreographing your machine learning research.\"\nclassifiers=[\n  \"Intended Audience :: Science/Research\",\n  \"Development Status :: 3 - Alpha\",\n  \"License :: OSI Approved :: Apache Software License\",\n  \"Programming Language :: Python :: 3\",\n  \"Topic :: Scientific/Engineering :: Artificial Intelligence\",\n]\nauthors = [\n    {name = \"Allen Institute for Artificial Intelligence\", email = \"contact@allenai.org\"}\n]\nlicense = {file = \"LICENSE\"}\nrequires-python = \">=3.8.1\"\ndependencies = [\n  \"cached-path>=1.0,<2.0\",\n  \"rjsonnet>=0.5.0\",\n  \"GitPython>=3.0,<4.0\",\n  \"PyYAML>=5.4.1,<7.0\",\n  \"dill\",\n  \"base58\",\n  \"xxhash\",\n  \"filelock>=3.4,<4.0\",\n  \"click>=8.0,<8.1.4\",\n  \"click-help-colors>=0.9.1,<0.10\",\n  \"rich>=12.3,<14.0\",\n  \"tqdm>=4.62,<5.0\",\n  \"more-itertools>=8.0,<11.0\",\n  \"sqlitedict\",\n  \"glob2>=0.7\",\n  \"petname>=2.6,<3.0\",\n  \"pytz\"\n]\n\n[project.optional-dependencies]\ndev = [\n  \"ruff\",\n  \"mypy==1.2.0\",\n  \"types-PyYAML\",\n  \"types-setuptools\",\n  \"types-pytz\",\n  \"types-retry\",\n  \"black==23.3.0\",\n  \"isort==5.12.0\",\n  \"pytest\",\n  \"pytest-sphinx\",\n  \"flaky\",\n  \"twine>=1.11.0\",\n  \"setuptools\",\n  \"wheel\",\n  \"build\",\n  \"Sphinx==5.3.0\",\n  \"furo==2023.3.27\",\n  \"myst-parser==1.0.0\",\n  \"sphinx-copybutton==0.5.2\",\n  \"sphinx-autobuild==2021.3.14\",\n  \"sphinx-autodoc-typehints<=1.23.0\",\n  \"packaging\"\n]\nexamples = [\n  \"torchmetrics>=0.7.0\"\n]\ntorch = [\n  \"torch>=1.9,<2.1\",\n  \"numpy\",\n]\ntransformers = [\n  \"torch>=1.9,<2.1\",\n  \"numpy\",\n  \"datasets>=1.12,<3.0\",\n  \"transformers>=4.12.3\",\n  \"sentencepiece==0.1.98\",\n  \"sacremoses\"\n]\ndatasets = [\n  \"datasets>=1.12,<3.0\"\n]\nfairscale = [\n 
 \"torch>=1.9,<2.1\",\n  \"numpy\",\n  \"fairscale>=0.4.6,<0.5\"\n]\nflax = [\n  \"datasets>=1.12,<3.0\",\n  \"jax\",\n  \"jaxlib\",\n  \"flax\",\n  \"optax\",\n  \"tensorflow-cpu>=2.9.1\"\n]\nwandb = [\n  \"wandb>=0.16\",\n  \"retry\"\n]\nbeaker = [\n  \"beaker-py>=1.14.0,<2.0\"\n]\ngs = [\n  \"google-cloud-storage>=2.6.0\",\n  \"google-cloud-datastore>=2.12.0\"\n]\nall = [\n  \"ai2-tango[examples,torch,transformers,datasets,fairscale,flax,wandb,beaker,gs]\"\n]\n\n[project.scripts]\ntango = \"tango.__main__:main\"\n\n[project.urls]\nhomepage = \"https://github.com/allenai/tango\"\nrepository = \"https://github.com/allenai/tango\"\n\n[tool.setuptools.packages.find]\nexclude = [\n    \"*.tests\",\n    \"*.tests.*\",\n    \"tests.*\",\n    \"tests\",\n    \"test_fixtures\",\n    \"test_fixtures.*\",\n    \"docs*\",\n    \"scripts*\",\n    \"examples*\"\n]\n\n[tool.setuptools.package-data]\ntango = [\"py.typed\"]\n\"tango.integrations.beaker\" = [\"*.sh\"]\n\n[tool.setuptools.dynamic]\nversion = {attr = \"tango.version.VERSION\"}\n\n[tool.black]\nline-length = 100\ninclude = '\\.pyi?$'\nexclude = '''\n(\n      __pycache__\n    | \\.git\n    | \\.mypy_cache\n    | \\.pytest_cache\n    | \\.vscode\n    | \\.venv\n    | \\bdist\\b\n    | \\bdoc\\b\n)\n'''\n\n[tool.isort]\nprofile = \"black\"\nmulti_line_output = 3\n\n[tool.ruff]\nline-length = 115\nselect = [\"E\"]\nexclude = [\n  \".venv\",\n  \".git\",\n  \"__pycache__\",\n  \".mypy_cache\",\n  \"docs/build\",\n  \"dist\"\n]\n\n[tool.ruff.per-file-ignores]\n\"__init__.py\" = [\"F401\"]\n\"*/**/**/__init__.py\" = [\"F401\",\"E501\"]\n\n[tool.mypy]\nignore_missing_imports = true\nno_site_packages = false\nallow_redefinition = true\ncheck_untyped_defs = true\n\n[[tool.mypy.overrides]]\nmodule = \"tests.*\"\nstrict_optional = false\ndisable_error_code = [\n  \"var-annotated\",\n  \"no-redef\",\n  \"dict-item\"\n]\nallow_redefinition = true\n\n[tool.pytest.ini_options]\ntestpaths = \"tests/\"\npython_classes = [\n  
\"Test*\",\n  \"*Test\"\n]\nlog_format = \"%(asctime)s - %(levelname)s - %(name)s - %(message)s\"\nlog_level = \"DEBUG\"\nmarkers = [\n  \"gpu: marks tests that need GPUs\"\n]\nfilterwarnings = [\n  'ignore:.*Consider increasing the value of the `num_workers` argument.*:UserWarning:pytorch_lightning\\.trainer\\.data_loading',\n  'ignore:.*you defined a validation_step but have no val_dataloader.*:UserWarning:pytorch_lightning\\.trainer\\.configuration_validator',\n  'ignore::UserWarning:tango\\.*',\n  'ignore::DeprecationWarning:pkg_resources',\n  'ignore::DeprecationWarning:google\\.rpc'\n]\ndoctest_optionflags = \"NORMALIZE_WHITESPACE\"\n"
  },
  {
    "path": "scripts/entrypoint.sh",
    "content": "#!/bin/bash\n\n# Exit script if any commands fail.\nset -e\nset -o pipefail\n\n# Check that the environment variable has been set correctly\nif [ -z \"$COMMIT_SHA\" ]; then\n  echo >&2 'error: missing COMMIT_SHA environment variable'\n  exit 1\nfi\n\n# Upgrade pip\n/opt/conda/bin/pip install --upgrade pip\n\n# Clone and install tango.\ngit clone https://github.com/allenai/tango.git\ncd tango\ngit checkout --quiet \"$COMMIT_SHA\"\n/opt/conda/bin/pip install --no-cache-dir '.[dev,all]'\n\n# Create directory for results.\nmkdir -p /results\n\n# Execute the arguments to this script as commands themselves, piping output into a log file.\nexec \"$@\" 2>&1 | tee /results/out.log\n"
  },
  {
    "path": "scripts/hash_extras.py",
    "content": "\"\"\"\nUsed in CI to create a unique ID for any set of install extras.\n\"\"\"\n\nimport sys\n\n\ndef main():\n    extras = sys.argv[1]\n    print(\"-\".join(sorted(extras.split(\",\"))))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/prepare_changelog.py",
    "content": "from datetime import datetime\nfrom pathlib import Path\n\nfrom tango.version import VERSION\n\n\ndef main():\n    changelog = Path(\"CHANGELOG.md\")\n\n    with changelog.open() as f:\n        lines = f.readlines()\n\n    insert_index: int\n    for i in range(len(lines)):\n        line = lines[i]\n        if line.startswith(\"## Unreleased\"):\n            insert_index = i + 1\n        elif line.startswith(f\"## [v{VERSION}]\"):\n            print(\"CHANGELOG already up-to-date\")\n            return\n        elif line.startswith(\"## [v\"):\n            break\n    else:\n        raise RuntimeError(\"Couldn't find 'Unreleased' section\")\n\n    lines.insert(insert_index, \"\\n\")\n    lines.insert(\n        insert_index + 1,\n        f\"## [v{VERSION}](https://github.com/allenai/tango/releases/tag/v{VERSION}) - \"\n        f\"{datetime.now().strftime('%Y-%m-%d')}\\n\",\n    )\n\n    with changelog.open(\"w\") as f:\n        f.writelines(lines)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/prepare_citation_cff.py",
    "content": "from datetime import datetime\nfrom pathlib import Path\n\nfrom tango.version import VERSION\n\n\ndef main():\n    citation = Path(\"CITATION.cff\")\n\n    with citation.open() as f:\n        lines = f.readlines()\n\n    for i in range(len(lines)):\n        line = lines[i]\n        if line.startswith(\"version:\"):\n            lines[i] = f'version: \"{VERSION}\"\\n'\n        elif line.startswith(\"date-released:\"):\n            lines[i] = f'date-released: \"{datetime.now().strftime(\"%Y-%m-%d\")}\"\\n'\n\n    with citation.open(\"w\") as f:\n        f.writelines(lines)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/release.sh",
    "content": "#!/bin/bash\n\nset -e\n\nTAG=$(python -c 'from tango.version import VERSION; print(\"v\" + VERSION)')\n\nread -p \"Creating new release for $TAG. Do you want to continue? [y/N] \" prompt\n\nif [[ $prompt == \"y\" || $prompt == \"Y\" || $prompt == \"yes\" || $prompt == \"Yes\" ]]; then\n    python scripts/prepare_changelog.py\n    python scripts/prepare_citation_cff.py\n    git add -A\n    git commit -m \"Prepare for release $TAG\" || true && git push\n    echo \"Creating new git tag $TAG\"\n    git tag \"$TAG\" -m \"$TAG\"\n    git push --tags\nelse\n    echo \"Cancelled\"\n    exit 1\nfi\n"
  },
  {
    "path": "scripts/release_notes.py",
    "content": "# encoding: utf-8\n\n\"\"\"\nPrepares markdown release notes for GitHub releases.\n\"\"\"\n\nimport os\nfrom typing import List, Optional\n\nimport packaging.version\n\nTAG = os.environ[\"TAG\"]\n\nADDED_HEADER = \"### Added 🎉\"\nCHANGED_HEADER = \"### Changed ⚠️\"\nFIXED_HEADER = \"### Fixed ✅\"\nREMOVED_HEADER = \"### Removed 👋\"\n\n\ndef get_change_log_notes() -> str:\n    in_current_section = False\n    current_section_notes: List[str] = []\n    with open(\"CHANGELOG.md\") as changelog:\n        for line in changelog:\n            if line.startswith(\"## \"):\n                if line.startswith(\"## Unreleased\"):\n                    continue\n                if line.startswith(f\"## [{TAG}]\"):\n                    in_current_section = True\n                    continue\n                break\n            if in_current_section:\n                if line.startswith(\"### Added\"):\n                    line = ADDED_HEADER + \"\\n\"\n                elif line.startswith(\"### Changed\"):\n                    line = CHANGED_HEADER + \"\\n\"\n                elif line.startswith(\"### Fixed\"):\n                    line = FIXED_HEADER + \"\\n\"\n                elif line.startswith(\"### Removed\"):\n                    line = REMOVED_HEADER + \"\\n\"\n                current_section_notes.append(line)\n    assert current_section_notes\n    return \"## What's new\\n\\n\" + \"\".join(current_section_notes).strip() + \"\\n\"\n\n\ndef get_commit_history() -> str:\n    new_version = packaging.version.parse(TAG)\n\n    os.popen(\"git fetch --tags\")\n\n    # Get all tags sorted by version, latest first.\n    all_tags = os.popen(\"git tag -l --sort=-version:refname 'v*'\").read().split(\"\\n\")\n\n    # Out of `all_tags`, find the latest previous version so that we can collect all\n    # commits between that version and the new version we're about to publish.\n    # Note that we ignore pre-releases unless the new version is also a pre-release.\n    
last_tag: Optional[str] = None\n    for tag in all_tags:\n        if not tag.strip():  # could be blank line\n            continue\n        version = packaging.version.parse(tag)\n        if new_version.pre is None and version.pre is not None:\n            continue\n        if version < new_version:\n            last_tag = tag\n            break\n    if last_tag is not None:\n        commits = os.popen(f\"git log {last_tag}..{TAG}^ --oneline --first-parent\").read()\n    else:\n        commits = os.popen(\"git log --oneline --first-parent\").read()\n    return \"## Commits\\n\\n\" + commits\n\n\ndef main():\n    print(get_change_log_notes())\n    print(get_commit_history())\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tango/__init__.py",
    "content": "\"\"\"\nA Python library for choreographing your machine learning research.\n\"\"\"\n\n__all__ = [\n    \"cleanup_cli\",\n    \"DillFormat\",\n    \"DillFormatIterator\",\n    \"execute_step_graph\",\n    \"Executor\",\n    \"Format\",\n    \"initialize_cli\",\n    \"JsonFormat\",\n    \"JsonFormatIterator\",\n    \"load_settings\",\n    \"prepare_executor\",\n    \"prepare_workspace\",\n    \"Run\",\n    \"RunInfo\",\n    \"RunSort\",\n    \"SqliteDictFormat\",\n    \"Step\",\n    \"step\",\n    \"StepCache\",\n    \"StepGraph\",\n    \"StepInfo\",\n    \"StepInfoSort\",\n    \"StepResources\",\n    \"StepState\",\n    \"tango_cli\",\n    \"Workspace\",\n]\n\nfrom .cli import (\n    cleanup_cli,\n    execute_step_graph,\n    initialize_cli,\n    load_settings,\n    prepare_executor,\n    prepare_workspace,\n    tango_cli,\n)\nfrom .executor import Executor\nfrom .format import (\n    DillFormat,\n    DillFormatIterator,\n    Format,\n    JsonFormat,\n    JsonFormatIterator,\n    SqliteDictFormat,\n)\nfrom .step import Step, StepResources, step\nfrom .step_cache import StepCache\nfrom .step_graph import StepGraph\nfrom .step_info import StepInfo, StepState\nfrom .workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace\n"
  },
  {
    "path": "tango/__main__.py",
    "content": "\"\"\"\nThe Tango CLI is the recommended tool to run experiments with.\nIt also comes with several other useful commands.\n\nYou can see the list of all available commands by running:\n\n.. code-block::\n\n    $ tango --help\n\n.. testcode::\n    :hide:\n\n    import subprocess\n    output = subprocess.run(\"tango --help\".split(\" \"), capture_output=True)\n    output.check_returncode()\n    print(output.stdout.decode().replace(\"\\\\n\\\\n\", \"\\\\n\").strip())\n\n.. testoutput::\n\n    Usage: tango [OPTIONS] COMMAND [ARGS]...\n\n    Options:\n      --version                       Show the version and exit.\n      --settings FILE                 Path to a global tango.yml settings file.\n      --log-level [debug|info|warning|error]\n                                      Set the global log level.\n      --file-friendly-logging         Outputs progress bar status on separate lines and slows refresh rate.\n      --start-method [fork|spawn|forkserver]\n                                      Set the multiprocessing start method.\n      --help                          Show this message and exit.\n\n    Commands:\n      info      Get info about the current tango installation.\n      run       Run a tango experiment.\n      settings  Commands for initializing and updating global settings.\n\nTo see all of the available arguments and options for a particular command, run\n\n.. code-block::\n\n    $ tango [COMMAND] --help\n\nFor example,\n\n.. 
code-block::\n\n    $ tango run --help\n\n``tango run``\n-------------\n\nThe ``run`` command is used to execute a tango experiment from an experiment configuration file.\nSee the `Configuration files </overview.html#configuration-files>`_ section in the overview\nfor a quick introduction to the format.\n\n``tango info``\n--------------\n\nThe ``info`` command just prints out some useful information about the current tango installation,\nsuch as which integrations are available.\n\n``tango settings``\n------------------\n\nThe ``settings`` group of commands can be used to initialize a :class:`~tango.settings.TangoGlobalSettings`\nfile or update fields in it.\n\n\"\"\"\nimport os\nfrom pathlib import Path\nfrom typing import Dict, List, NamedTuple, Optional, Sequence, Union\n\nimport click\nfrom click_help_colors import HelpColorsCommand, HelpColorsGroup\n\nfrom tango.cli import (\n    cleanup_cli,\n    execute_step_graph,\n    initialize_cli,\n    load_settings,\n    prepare_executor,\n    prepare_workspace,\n)\nfrom tango.common.exceptions import CliRunError, IntegrationMissingError\nfrom tango.common.logging import cli_logger, initialize_logging\nfrom tango.common.params import Params\nfrom tango.common.util import (\n    find_integrations,\n    import_extra_module,\n    import_module_and_submodules,\n)\nfrom tango.settings import TangoGlobalSettings\nfrom tango.step_graph import StepGraph\nfrom tango.version import VERSION\nfrom tango.workspace import Workspace\n\n_CLICK_GROUP_DEFAULTS = {\n    \"cls\": HelpColorsGroup,\n    \"help_options_color\": \"green\",\n    \"help_headers_color\": \"yellow\",\n    \"context_settings\": {\"max_content_width\": 115},\n}\n\n_CLICK_COMMAND_DEFAULTS = {\n    \"cls\": HelpColorsCommand,\n    \"help_options_color\": \"green\",\n    \"help_headers_color\": \"yellow\",\n    \"context_settings\": {\"max_content_width\": 115},\n}\n\n\nclass SettingsObject(NamedTuple):\n    settings: TangoGlobalSettings\n    called_by_executor: 
bool\n\n\n@click.group(name=None, **_CLICK_GROUP_DEFAULTS)\n@click.version_option(version=VERSION)\n@click.option(\n    \"--settings\",\n    type=click.Path(exists=True, dir_okay=False, resolve_path=True),\n    help=\"Path to a global tango.yml settings file.\",\n)\n@click.option(\n    \"--log-level\",\n    help=\"Set the global log level.\",\n    type=click.Choice([\"debug\", \"info\", \"warning\", \"error\"], case_sensitive=False),\n    show_choices=True,\n)\n@click.option(\n    \"--file-friendly-logging\",\n    is_flag=True,\n    help=\"Outputs progress bar status on separate lines and slows refresh rate.\",\n)\n@click.option(\n    \"--start-method\",\n    help=\"Set the multiprocessing start method.\",\n    type=click.Choice([\"fork\", \"spawn\", \"forkserver\"], case_sensitive=True),\n    show_choices=True,\n)\n@click.option(\n    \"--called-by-executor\",\n    is_flag=True,\n    hidden=True,\n)\n@click.pass_context\ndef main(\n    ctx,\n    settings: Optional[str] = None,\n    log_level: Optional[str] = None,\n    file_friendly_logging: bool = False,\n    start_method: Optional[str] = None,\n    called_by_executor: bool = False,\n):\n    settings: TangoGlobalSettings = load_settings(settings)\n\n    if start_method is not None:\n        settings.multiprocessing_start_method = start_method\n\n    if log_level is not None:\n        settings.log_level = log_level\n\n    if file_friendly_logging:\n        settings.file_friendly_logging = file_friendly_logging\n\n    ctx.obj = SettingsObject(settings, called_by_executor)\n\n    initialize_cli(settings=settings, called_by_executor=called_by_executor)\n\n\n@main.result_callback()\ndef cleanup(*args, **kwargs):\n    cleanup_cli()\n\n\n@main.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"experiment\",\n    type=click.Path(exists=True, dir_okay=False, resolve_path=True),\n)\n@click.option(\n    \"-w\",\n    \"--workspace\",\n    type=click.Path(file_okay=False),\n    help=\"\"\"A workspace path or URL. 
If not specified, the workspace from any global tango\n    settings file will be used, if found, otherwise an ephemeral MemoryWorkspace.\"\"\",\n    default=None,\n)\n@click.option(\n    \"-d\",\n    \"--workspace-dir\",\n    type=click.Path(file_okay=False),\n    default=None,\n    hidden=True,\n)\n@click.option(\n    \"-o\",\n    \"--overrides\",\n    type=str,\n    help=\"\"\"A JSON(NET) string used to override fields in the experiment config.\n    Use dot syntax to specify nested fields.\"\"\",\n)\n@click.option(\n    \"-i\",\n    \"--include-package\",\n    type=str,\n    help=\"Python packages or modules to import for tango components.\",\n    multiple=True,\n)\n@click.option(\n    \"-j\",\n    \"--parallelism\",\n    type=int,\n    help=\"\"\"The maximum number of steps to run in parallel (for executors that support this).\n    The exact behavior depends on the executor. If you're using the default executors,\n    a value of 0 (or left unspecified) means each step is run in the main process using the default executor,\n    otherwise the multicore executor is used.\"\"\",\n)\n@click.option(\n    \"-s\",\n    \"--step-name\",\n    help=\"Execute a particular step (and its dependencies) in the experiment.\",\n    multiple=True,\n)\n@click.option(\n    \"-n\",\n    \"--name\",\n    type=str,\n    help=\"\"\"Specify the name for this run.\"\"\",\n)\n@click.option(\n    \"-D\",\n    \"--ext-var\",\n    type=str,\n    help=\"\"\"JSONNET external variables to use when loading the experiment config.\n    For example, --ext-var 'pretrained_model=gpt2'.\"\"\",\n    multiple=True,\n)\n@click.pass_obj\ndef run(\n    obj: SettingsObject,\n    experiment: str,\n    workspace: Optional[str] = None,\n    workspace_dir: Optional[Union[str, os.PathLike]] = None,\n    overrides: Optional[str] = None,\n    include_package: Optional[Sequence[str]] = None,\n    parallelism: Optional[int] = None,\n    step_name: Optional[Sequence[str]] = None,\n    name: Optional[str] = None,\n    
ext_var: Optional[Sequence[str]] = None,\n):\n    \"\"\"\n    Run a tango experiment.\n\n    EXPERIMENT is the path to the experiment's JSON/Jsonnet/YAML configuration file.\n    \"\"\"\n    if workspace_dir is not None:\n        import warnings\n\n        warnings.warn(\n            \"The -d/--workspace-dir option is deprecated. Please use -w/--workspace instead.\",\n            DeprecationWarning,\n        )\n\n        if workspace is not None:\n            raise click.ClickException(\n                \"-w/--workspace is mutually exclusive with -d/--workspace-dir\"\n            )\n\n        workspace = \"local://\" + str(workspace_dir)\n\n    _run(\n        obj.settings,\n        experiment,\n        workspace_url=workspace,\n        overrides=overrides,\n        include_package=include_package,\n        parallelism=parallelism,\n        step_names=step_name,\n        name=name,\n        called_by_executor=obj.called_by_executor,\n        ext_var=ext_var,\n    )\n\n\n@main.command(hidden=True)\n@click.argument(\n    \"experiment\",\n    type=click.Path(exists=True, dir_okay=False, resolve_path=True),\n)\n@click.argument(\n    \"step_name\",\n    type=str,\n)\n@click.argument(\n    \"workspace_url\",\n    type=str,\n)\n@click.option(\n    \"-i\",\n    \"--include-package\",\n    type=str,\n    help=\"Python packages or modules to import for tango components.\",\n    multiple=True,\n)\n@click.option(\n    \"--log-level\",\n    help=\"Set the global log level.\",\n    type=click.Choice([\"debug\", \"info\", \"warning\", \"error\"], case_sensitive=False),\n    show_choices=True,\n)\ndef beaker_executor_run(\n    experiment: str,\n    step_name: str,\n    workspace_url: str,\n    include_package: Optional[Sequence[str]] = None,\n    log_level: str = \"debug\",\n):\n    \"\"\"\n    This command is only used internally by the BeakerExecutor.\n    \"\"\"\n    from tango.executor import Executor\n\n    if include_package:\n        for package_name in include_package:\n            
import_extra_module(package_name)\n\n    # Load step graph and step.\n    step_graph = StepGraph.from_file(experiment)\n    step = step_graph[step_name]\n\n    # Initialize workspace and executor.\n    # NOTE: We use the default executor here because we're just running the step\n    # locally in the main process.\n    workspace = Workspace.from_url(workspace_url)\n    executor = Executor(workspace=workspace, include_package=include_package)\n\n    # Initialize logging.\n    initialize_logging(log_level=log_level, enable_cli_logs=True, file_friendly_logging=True)\n\n    # Run step.\n    executor.execute_step(step)\n\n\n@main.command(**_CLICK_COMMAND_DEFAULTS)\n@click.pass_obj\ndef info(obj: SettingsObject):\n    \"\"\"\n    Get info about the current tango installation.\n    \"\"\"\n    import platform\n\n    cli_logger.info(\"Tango version %s (python %s)\", VERSION, platform.python_version())\n    cli_logger.info(\"\")\n\n    # Show info about settings.\n    if obj.settings.path is not None:\n        cli_logger.info(\"[underline]Settings:[/]\")\n        cli_logger.info(\"[green] \\N{check mark} Loaded from %s[/]\", obj.settings.path)\n        if obj.settings.include_package:\n            cli_logger.info(\"   Included packages:\")\n            for package in obj.settings.include_package:\n                is_found = True\n                try:\n                    import_module_and_submodules(package)\n                except (ModuleNotFoundError, ImportError):\n                    is_found = False\n                if is_found:\n                    cli_logger.info(\"   [green]\\N{check mark} %s[/]\", package)\n                else:\n                    cli_logger.info(\"   [red]\\N{ballot x} %s (not found)[/]\", package)\n        cli_logger.info(\"\")\n\n    # Show info about integrations.\n    cli_logger.info(\"[underline]Integrations:[/]\")\n    for integration in find_integrations():\n        name = integration.split(\".\")[-1]\n        is_installed = True\n        
try:\n            import_module_and_submodules(integration, recursive=False)\n        except (IntegrationMissingError, ModuleNotFoundError, ImportError):\n            is_installed = False\n        if is_installed:\n            cli_logger.info(\" [green]\\N{check mark} %s[/]\", name)\n        else:\n            cli_logger.info(\" [yellow]\\N{ballot x} %s (not installed)[/]\", name)\n\n\n@main.group(**_CLICK_GROUP_DEFAULTS)\n@click.pass_obj\ndef settings(ctx):\n    \"\"\"\n    Commands for initializing and updating global settings.\n    \"\"\"\n\n\n@settings.command(**_CLICK_COMMAND_DEFAULTS)\n@click.option(\n    \"-p\",\n    \"--path\",\n    type=click.Path(exists=False, dir_okay=False, resolve_path=True),\n    default=None,\n    help=\"\"\"The path to write the settings to.\"\"\",\n)\n@click.option(\n    \"-f\",\n    \"--force\",\n    is_flag=True,\n    help=\"\"\"Force overwrite the file if it exists.\"\"\",\n)\n@click.pass_obj\ndef init(obj: SettingsObject, path: Optional[str] = None, force: bool = False):\n    \"\"\"\n    Initialize the settings file.\n    \"\"\"\n    path_to_write = Path(path or TangoGlobalSettings._DEFAULT_LOCATION)\n    if path_to_write.is_file() and not force:\n        raise click.ClickException(\"Settings file already exists! Use -f/--force to overwrite it.\")\n    obj.settings.to_file(path_to_write)\n    cli_logger.info(\n        \"[green]\\N{check mark} Settings file written to [bold]%s[/bold][/green]\", path_to_write\n    )\n\n\n@settings.group(name=\"set\", **_CLICK_GROUP_DEFAULTS)\n@click.pass_obj\ndef set_setting(obj: SettingsObject):\n    \"\"\"\n    Set a value in the settings file.\n    \"\"\"\n    if obj.settings.path is None:\n        raise click.ClickException(\n            \"Settings file not found! 
Did you forget to call 'tango settings init'?\"\n        )\n\n\n@set_setting.result_callback()\ndef save_settings(settings: TangoGlobalSettings):\n    settings.save()\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"workspace\",\n    type=str,\n)\n@click.option(\n    \"--validate/--no-validate\",\n    type=bool,\n    help=\"Validate that the workspace can be initialized.\",\n    default=True,\n)\n@click.pass_obj\ndef workspace(obj: SettingsObject, workspace: str, validate: bool = True) -> TangoGlobalSettings:\n    \"\"\"\n    Set the default workspace path or URL.\n    \"\"\"\n    from urllib.parse import urlparse\n\n    if not urlparse(workspace).scheme:\n        obj.settings.workspace = {\"type\": \"local\", \"dir\": str(Path(workspace).resolve())}\n    else:\n        obj.settings.workspace = {\"type\": \"from_url\", \"url\": workspace}\n\n    if validate:\n        for package_name in obj.settings.include_package or []:\n            import_extra_module(package_name)\n\n        Workspace.from_params(obj.settings.workspace.copy())\n\n    return obj.settings\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"packages\",\n    type=str,\n    nargs=-1,\n)\n@click.option(\n    \"-a\",\n    \"--append\",\n    is_flag=True,\n    help=\"Appends packages instead of overwriting.\",\n)\n@click.option(\n    \"--validate/--no-validate\",\n    type=bool,\n    help=\"Validate that the packages can be imported.\",\n    default=True,\n)\n@click.pass_obj\ndef include_package(\n    obj: SettingsObject,\n    packages: List[str],\n    append: bool = False,\n    validate: bool = True,\n) -> TangoGlobalSettings:\n    \"\"\"\n    Set or add modules to automatically import on 'tango run'.\n    \"\"\"\n    new_include: List[str]\n    if append:\n        new_include = obj.settings.include_package or []\n    else:\n        new_include = []\n    for package in packages:\n        if package not in new_include:\n            
new_include.append(package)\n    obj.settings.include_package = new_include\n    if validate:\n        for package in obj.settings.include_package:\n            try:\n                import_module_and_submodules(package)\n            except (ModuleNotFoundError, ImportError):\n                raise click.ClickException(f\"Failed to import '{package}'\")\n    return obj.settings\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"level\",\n    type=click.Choice([\"debug\", \"info\", \"warning\", \"error\"], case_sensitive=False),\n)\n@click.pass_obj\ndef log_level(obj: SettingsObject, level: str) -> TangoGlobalSettings:\n    \"\"\"\n    Set the log level.\n    \"\"\"\n    obj.settings.log_level = level.lower()\n    return obj.settings\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"value\",\n    type=bool,\n)\n@click.pass_obj\ndef file_friendly_logging(obj: SettingsObject, value: bool) -> TangoGlobalSettings:\n    \"\"\"\n    Toggle file friendly logging mode.\n    \"\"\"\n    obj.settings.file_friendly_logging = value\n    return obj.settings\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"start_method\",\n    type=click.Choice([\"fork\", \"spawn\", \"forkserver\"], case_sensitive=True),\n)\n@click.pass_obj\ndef multiprocessing_start_method(obj: SettingsObject, start_method: str) -> TangoGlobalSettings:\n    \"\"\"\n    Set the Python multiprocessing start method.\n    \"\"\"\n    obj.settings.multiprocessing_start_method = start_method\n    return obj.settings\n\n\n@set_setting.command(**_CLICK_COMMAND_DEFAULTS)\n@click.argument(\n    \"key\",\n    type=str,\n)\n@click.argument(\n    \"value\",\n    type=str,\n)\n@click.pass_obj\ndef env(obj: SettingsObject, key: str, value: str) -> TangoGlobalSettings:\n    \"\"\"\n    Add or update an environment variable.\n    \"\"\"\n    from tango.common.aliases import EnvVarNames\n\n    # These environment variables should not be set this 
way since they'll be ignored.\n    blocked_env_variable_names = EnvVarNames.values()\n\n    if key in blocked_env_variable_names:\n        raise click.ClickException(\n            f\"Cannot add environment variable '{key}' to settings. \"\n            f\"Please set the corresponding settings field instead.\"\n        )\n\n    if obj.settings.environment is None:\n        obj.settings.environment = {}\n    obj.settings.environment[key] = value\n    return obj.settings\n\n\ndef _run(\n    settings: TangoGlobalSettings,\n    experiment: str,\n    workspace_url: Optional[str] = None,\n    overrides: Optional[str] = None,\n    include_package: Optional[Sequence[str]] = None,\n    step_names: Optional[Sequence[str]] = None,\n    parallelism: Optional[int] = None,\n    multicore: Optional[bool] = None,\n    name: Optional[str] = None,\n    called_by_executor: bool = False,\n    ext_var: Optional[Sequence[str]] = None,\n) -> str:\n    # Read params.\n    ext_vars: Dict[str, str] = {}\n    for var in ext_var or []:\n        try:\n            # Split on the first '=' only, so that values may themselves contain '='.\n            key, value = var.split(\"=\", 1)\n        except ValueError:\n            raise CliRunError(f\"Invalid --ext-var '{var}'\")\n        ext_vars[key] = value\n    params = Params.from_file(experiment, params_overrides=overrides or \"\", ext_vars=ext_vars)\n\n    # Import included packages to find registered components.\n    # NOTE: The Executor imports these as well because it's meant to be used\n    # directly, but we also need to import here in case the user is using a\n    # custom Executor, StepCache, or Workspace.\n    include_package: List[str] = list(include_package or [])\n    include_package += params.pop(\"include_package\", [])\n    include_package += settings.include_package or []\n    for package_name in include_package:\n        import_extra_module(package_name)\n\n    # Initialize step graph.\n    step_graph: StepGraph = StepGraph.from_params(params.pop(\"steps\"))\n    params.assert_empty(\"'tango run'\")\n\n    if 
step_names:\n        for step_name in step_names:\n            assert step_name in step_graph, (\n                f\"You want to run a step called '{step_name}', but it cannot be found in the experiment config. \"\n                f\"The config contains: {list(step_graph.keys())}.\"\n            )\n        step_graph = step_graph.sub_graph(*step_names)\n\n    # Execute step graph in workspace\n\n    workspace = prepare_workspace(settings=settings, workspace_url=workspace_url)\n\n    executor = prepare_executor(\n        settings=settings,\n        workspace=workspace,\n        include_package=include_package,\n        parallelism=parallelism,\n        multicore=multicore,\n        called_by_executor=called_by_executor,\n    )\n\n    run_name = execute_step_graph(\n        step_graph=step_graph,\n        workspace=workspace,\n        executor=executor,\n        name=name,\n        called_by_executor=called_by_executor,\n        step_names=step_names,\n    )\n\n    return run_name\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "tango/cli.py",
    "content": "import logging\nimport multiprocessing as mp\nimport os\nimport sys\nimport warnings\nfrom contextlib import contextmanager, nullcontext\nfrom typing import TYPE_CHECKING, Optional, Sequence, Union\n\nfrom tango.common.exceptions import CliRunError\nfrom tango.common.logging import (\n    cli_logger,\n    initialize_logging,\n    initialize_prefix_logging,\n    teardown_logging,\n)\nfrom tango.common.params import Params\nfrom tango.executor import Executor\nfrom tango.settings import TangoGlobalSettings\nfrom tango.step_graph import StepGraph\nfrom tango.workspace import Workspace\n\nif TYPE_CHECKING:\n    from tango.executor import ExecutorOutput\n    from tango.workspace import Run\n\n\nlogger = logging.getLogger(__name__)\n\n\ndef load_settings(settings: Union[str, Params, dict, None] = None) -> TangoGlobalSettings:\n    return (\n        TangoGlobalSettings.from_file(settings)\n        if isinstance(settings, str)\n        else TangoGlobalSettings.from_params(settings)\n        if isinstance(settings, (Params, dict))\n        else TangoGlobalSettings.default()\n    )\n\n\n@contextmanager\ndef tango_cli(settings: Union[TangoGlobalSettings, str, Params, dict, None] = None):\n    if not isinstance(settings, TangoGlobalSettings):\n        settings = load_settings(settings)\n\n    try:\n        initialize_cli(settings=settings, called_by_executor=False)\n        yield\n    finally:\n        cleanup_cli()\n\n\ndef initialize_cli(\n    settings: Optional[TangoGlobalSettings] = None,\n    called_by_executor: bool = False,\n):\n    if settings is None:\n        settings = TangoGlobalSettings.default()\n\n    if not sys.warnoptions:\n        warnings.simplefilter(\"default\", category=DeprecationWarning)\n\n    if settings.environment:\n        from tango.common.aliases import EnvVarNames\n\n        # These environment variables should not be set this way since they'll be ignored.\n        blocked_env_variable_names = EnvVarNames.values()\n\n        for 
key, value in settings.environment.items():\n            if key not in blocked_env_variable_names:\n                os.environ[key] = value\n            else:\n                warnings.warn(\n                    f\"Ignoring environment variable '{key}' from settings file. \"\n                    f\"Please use the corresponding settings field instead.\",\n                    UserWarning,\n                )\n\n    mp.set_start_method(settings.multiprocessing_start_method)\n\n    if not called_by_executor:\n        initialize_logging(\n            log_level=settings.log_level,\n            file_friendly_logging=settings.file_friendly_logging,\n            enable_cli_logs=True,\n        )\n\n\ndef cleanup_cli():\n    teardown_logging()\n\n\ndef prepare_workspace(\n    settings: Optional[TangoGlobalSettings] = None,\n    workspace_url: Optional[str] = None,\n) -> Workspace:\n    from tango.workspaces import default_workspace\n\n    if settings is None:\n        settings = TangoGlobalSettings.default()\n\n    workspace: Workspace\n    if workspace_url is not None:\n        workspace = Workspace.from_url(workspace_url)\n    elif settings.workspace is not None:\n        workspace = Workspace.from_params(settings.workspace)\n    else:\n        workspace = default_workspace\n\n    return workspace\n\n\ndef prepare_executor(\n    workspace: Workspace,\n    settings: Optional[TangoGlobalSettings] = None,\n    include_package: Optional[Sequence[str]] = None,\n    parallelism: Optional[int] = None,\n    multicore: Optional[bool] = None,\n    called_by_executor: bool = False,\n) -> Executor:\n    from tango.executors import MulticoreExecutor\n    from tango.workspaces import MemoryWorkspace\n\n    if settings is None:\n        settings = TangoGlobalSettings.default()\n\n    executor: Executor\n    if not called_by_executor and settings.executor is not None:\n        if multicore is not None:\n            logger.warning(\n                \"Ignoring argument 'multicore' since 
executor is defined in %s\",\n                settings.path or \"setting\",\n            )\n        executor = Executor.from_params(\n            settings.executor,\n            workspace=workspace,\n            include_package=include_package,\n            **(dict(parallelism=parallelism) if parallelism is not None else {}),  # type: ignore\n        )\n    else:\n        # Determine if we can use the multicore executor.\n        if multicore is None:\n            if isinstance(workspace, MemoryWorkspace):\n                # Memory workspace does not work with multiple cores.\n                multicore = False\n            elif \"pydevd\" in sys.modules:\n                # Pydevd doesn't reliably follow child processes, so we disable multicore under the debugger.\n                logger.warning(\"Debugger detected, disabling multicore.\")\n                multicore = False\n            elif parallelism is None or parallelism == 0:\n                multicore = False\n            else:\n                multicore = True\n\n        if multicore:\n            executor = MulticoreExecutor(\n                workspace=workspace, include_package=include_package, parallelism=parallelism\n            )\n        else:\n            executor = Executor(workspace=workspace, include_package=include_package)\n\n    return executor\n\n\ndef execute_step_graph(\n    step_graph: StepGraph,\n    workspace: Optional[Workspace] = None,\n    executor: Optional[Executor] = None,\n    name: Optional[str] = None,\n    called_by_executor: bool = False,\n    step_names: Optional[Sequence[str]] = None,\n) -> str:\n    if workspace is None:\n        workspace = prepare_workspace()\n        executor = prepare_executor(workspace=workspace)\n    elif executor is None:\n        executor = prepare_executor(workspace=workspace)\n\n    # Register run.\n    run: \"Run\"\n    if called_by_executor and name is not None:\n        try:\n            run = workspace.registered_run(name)\n        except 
KeyError:\n            raise RuntimeError(\n                \"The CLI was called by `MulticoreExecutor.execute_step_graph`, but \"\n                f\"'{name}' is not already registered as a run. This should never happen!\"\n            )\n    else:\n        run = workspace.register_run((step for step in step_graph.values()), name)\n\n    if called_by_executor:\n        assert step_names is not None and len(step_names) == 1\n        from tango.common.aliases import EnvVarNames\n\n        # We set this environment variable so that any steps that contain multiprocessing\n        # and call `initialize_worker_logging` also log the messages with the `step_name` prefix.\n        os.environ[EnvVarNames.LOGGING_PREFIX.value] = f\"step {step_names[0]}\"\n        initialize_prefix_logging(prefix=f\"step {step_names[0]}\", main_process=False)\n\n    # Capture logs to file.\n    with workspace.capture_logs_for_run(run.name) if not called_by_executor else nullcontext():\n        if not called_by_executor:\n            cli_logger.info(\"[green]Starting new run [bold]%s[/][/]\", run.name)\n\n        executor_output: ExecutorOutput = executor.execute_step_graph(step_graph, run_name=run.name)\n\n        if executor_output.failed:\n            cli_logger.error(\"[red]\\N{ballot x} Run [bold]%s[/] finished with errors[/]\", run.name)\n        elif not called_by_executor:\n            cli_logger.info(\"[green]\\N{check mark} Finished run [bold]%s[/][/]\", run.name)\n\n        if executor_output is not None:\n            if not called_by_executor:\n                executor_output.display()\n            if executor_output.failed:\n                raise CliRunError\n\n    return run.name\n"
  },
  {
    "path": "tango/common/__init__.py",
    "content": "from .aliases import PathOrStr\nfrom .dataset_dict import DatasetDict, DatasetDictBase, IterableDatasetDict\nfrom .det_hash import det_hash\nfrom .from_params import FromParams\nfrom .lazy import Lazy\nfrom .params import Params\nfrom .registrable import Registrable, RegistrableFunction, make_registrable\nfrom .tqdm import Tqdm\nfrom .util import filename_is_safe, threaded_generator\n\n__all__ = [\n    \"PathOrStr\",\n    \"DatasetDictBase\",\n    \"DatasetDict\",\n    \"IterableDatasetDict\",\n    \"det_hash\",\n    \"Params\",\n    \"FromParams\",\n    \"Registrable\",\n    \"RegistrableFunction\",\n    \"make_registrable\",\n    \"Lazy\",\n    \"Tqdm\",\n    \"filename_is_safe\",\n    \"threaded_generator\",\n]\n"
  },
  {
    "path": "tango/common/aliases.py",
    "content": "from enum import Enum, unique\nfrom os import PathLike\nfrom typing import Set, Union\n\nPathOrStr = Union[str, PathLike]\n\n\n@unique\nclass EnvVarNames(Enum):\n    FILE_FRIENDLY_LOGGING = \"FILE_FRIENDLY_LOGGING\"\n    LOG_LEVEL = \"TANGO_LOG_LEVEL\"\n    CLI_LOGGER_ENABLED = \"TANGO_CLI_LOGGER_ENABLED\"\n    LOGGING_HOST = \"TANGO_LOGGING_HOST\"\n    LOGGING_PORT = \"TANGO_LOGGING_PORT\"\n    LOGGING_PREFIX = \"TANGO_LOGGING_PREFIX\"\n    CONSOLE_WIDTH = \"TANGO_CONSOLE_WIDTH\"\n\n    @classmethod\n    def values(cls) -> Set[str]:\n        return set(e.value for e in cls)\n"
  },
  {
    "path": "tango/common/dataset_dict.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Any, Generic, Iterable, Iterator, Mapping, Sequence, TypeVar\n\nT = TypeVar(\"T\")\nS = TypeVar(\"S\")\n\n\n@dataclass\nclass DatasetDictBase(Generic[S], Mapping[str, S]):\n    \"\"\"\n    The base class for :class:`DatasetDict` and :class:`IterableDatasetDict`.\n    \"\"\"\n\n    splits: Mapping[str, S]\n    \"\"\"\n    A mapping of dataset split names to splits.\n    \"\"\"\n\n    metadata: Mapping[str, Any] = field(default_factory=dict)\n    \"\"\"\n    Metadata can contain anything you need.\n    \"\"\"\n\n    def __getitem__(self, split: str) -> S:\n        \"\"\"\n        Get a split in :attr:`splits`.\n        \"\"\"\n        return self.splits[split]\n\n    def __contains__(self, split: str) -> bool:  # type: ignore[override]\n        \"\"\"\n        Checks if :attr:`splits` contains the given split.\n        \"\"\"\n        return split in self.splits\n\n    def __iter__(self) -> Iterator[str]:\n        \"\"\"\n        Returns an iterator over the keys in :attr:`splits`.\n        \"\"\"\n        return iter(self.splits.keys())\n\n    def __len__(self) -> int:\n        \"\"\"\n        Returns the number of splits in :attr:`splits`.\n        \"\"\"\n        return len(self.splits)\n\n    def keys(self):\n        \"\"\"\n        Returns the split names in :attr:`splits`.\n        \"\"\"\n        return self.splits.keys()\n\n\n@dataclass\nclass DatasetDict(DatasetDictBase[Sequence[T]], Generic[T]):\n    \"\"\"\n    A generic :class:`~collections.abc.Mapping` class of split names (:class:`str`) to datasets\n    (``Sequence[T]``).\n    \"\"\"\n\n\n@dataclass\nclass IterableDatasetDict(DatasetDictBase[Iterable[T]], Generic[T]):\n    \"\"\"\n    An \"iterable\" version of :class:`DatasetDict`, where the dataset splits have\n    type ``Iterable[T]`` instead of ``Sequence[T]``. This is useful for streaming datasets.\n    \"\"\"\n"
  },
  {
    "path": "tango/common/det_hash.py",
    "content": "import collections\nimport hashlib\nimport io\nfrom abc import abstractmethod\nfrom typing import Any, MutableMapping, Optional, Type\n\nimport base58\nimport dill\n\nndarray: Optional[Type]\ntry:\n    from numpy import ndarray\nexcept ModuleNotFoundError:\n    ndarray = None\n\nTorchTensor: Optional[Type]\ntry:\n    from torch import Tensor as TorchTensor\nexcept ModuleNotFoundError:\n    TorchTensor = None\n\n\nclass CustomDetHash:\n    \"\"\"\n    By default, :func:`det_hash()` pickles an object, and returns the hash of the pickled\n    representation. Sometimes you want to take control over what goes into\n    that hash. In that case, derive from this class and implement :meth:`det_hash_object()`.\n    :func:`det_hash()` will pickle the result of this method instead of the object itself.\n\n    If you return ``None``, :func:`det_hash()` falls back to the original behavior and pickles\n    the object.\n    \"\"\"\n\n    @abstractmethod\n    def det_hash_object(self) -> Any:\n        \"\"\"\n        Return an object to use for deterministic hashing instead of ``self``.\n        \"\"\"\n        raise NotImplementedError()\n\n\nclass DetHashFromInitParams(CustomDetHash):\n    \"\"\"\n    Add this class as a mixin base class to make sure your class's det_hash is derived\n    exclusively from the parameters passed to ``__init__()``.\n    \"\"\"\n\n    _det_hash_object: Any\n\n    def __new__(cls, *args, **kwargs):\n        super_new = super(DetHashFromInitParams, cls).__new__\n        if super().__new__ is object.__new__ and cls.__init__ is not object.__init__:\n            instance = super_new(cls)\n        else:\n            instance = super_new(cls, *args, **kwargs)\n        instance._det_hash_object = (args, kwargs)\n        return instance\n\n    def det_hash_object(self) -> Any:\n        \"\"\"Returns a copy of the parameters that were passed to the class instance's ``__init__()`` method.\"\"\"\n        return self._det_hash_object\n\n\nclass 
DetHashWithVersion(CustomDetHash):\n    \"\"\"\n    Add this class as a mixin base class to make sure your class's det_hash can be modified\n    by altering a static ``VERSION`` member of your class.\n\n    Let's say you are working on training a model. Whenever you change code that's part of your experiment,\n    you have to change the :attr:`~tango.step.Step.VERSION` of the step that's running that code to tell\n    Tango that the step has changed and should be re-run. But if\n    you are training your model using Tango's built-in :class:`~tango.integrations.torch.TorchTrainStep`,\n    how do you change the version of the step? The answer is, leave the version of the step alone, and\n    instead add a :attr:`VERSION` to your model by deriving from this class:\n\n    .. code-block:: Python\n\n        class MyModel(DetHashWithVersion):\n            VERSION = \"001\"\n\n            def __init__(self, ...):\n                ...\n    \"\"\"\n\n    VERSION: Optional[str] = None\n\n    def det_hash_object(self) -> Any:\n        \"\"\"\n        Returns a tuple of :attr:`~tango.common.det_hash.DetHashWithVersion.VERSION` and this instance itself.\n        \"\"\"\n        if self.VERSION is not None:\n            return self.VERSION, self\n        else:\n            return None  # When you return `None` from here, it falls back to just hashing the object itself.\n\n\n_PICKLE_PROTOCOL = 4\n\n\nclass _DetHashPickler(dill.Pickler):\n    def __init__(self, buffer: io.BytesIO):\n        super().__init__(buffer, protocol=_PICKLE_PROTOCOL)\n\n        # We keep track of how deeply we are nesting the pickling of an object.\n        # If a class returns `self` as part of `det_hash_object()`, it causes an\n        # infinite recursion, because we try to pickle the `det_hash_object()`, which\n        # contains `self`, which returns a `det_hash_object()`, etc.\n        # So we keep track of how many times recursively we are trying to pickle the\n        # same object. 
We only call `det_hash_object()` the first time. We assume that\n        # if `det_hash_object()` returns `self` in any way, we want the second time\n        # to just pickle the object as normal. `DetHashWithVersion` takes advantage\n        # of this ability.\n        self.recursively_pickled_ids: MutableMapping[int, int] = collections.Counter()\n\n    def save(self, obj, save_persistent_id=True):\n        self.recursively_pickled_ids[id(obj)] += 1\n        super().save(obj, save_persistent_id)\n        self.recursively_pickled_ids[id(obj)] -= 1\n\n    def persistent_id(self, obj: Any) -> Any:\n        if isinstance(obj, CustomDetHash) and self.recursively_pickled_ids[id(obj)] <= 1:\n            det_hash_object = obj.det_hash_object()\n            if det_hash_object is not None:\n                return obj.__class__.__module__, obj.__class__.__qualname__, det_hash_object\n            else:\n                return None\n        elif isinstance(obj, type):\n            return obj.__module__, obj.__qualname__\n        elif callable(obj):\n            if hasattr(obj, \"__module__\") and hasattr(obj, \"__qualname__\"):\n                return obj.__module__, obj.__qualname__\n            else:\n                return None\n        elif ndarray is not None and isinstance(obj, ndarray):\n            # It's unclear why numpy arrays don't pickle in a consistent way.\n            return obj.dumps()\n        elif TorchTensor is not None and isinstance(obj, TorchTensor):\n            # It's unclear why torch tensors don't pickle in a consistent way.\n            import torch\n\n            with io.BytesIO() as buffer:\n                torch.save(obj, buffer, pickle_protocol=_PICKLE_PROTOCOL)\n                return buffer.getvalue()\n        else:\n            return None\n\n\ndef det_hash(o: Any) -> str:\n    \"\"\"\n    Returns a deterministic hash code of arbitrary Python objects.\n\n    If you want to override how we calculate the deterministic hash, derive from the\n    
:class:`CustomDetHash` class and implement :meth:`CustomDetHash.det_hash_object()`.\n    \"\"\"\n    m = hashlib.blake2b()\n    with io.BytesIO() as buffer:\n        pickler = _DetHashPickler(buffer)\n        pickler.dump(o)\n        m.update(buffer.getbuffer())\n        return base58.b58encode(m.digest()).decode()\n"
  },
  {
    "path": "tango/common/exceptions.py",
    "content": "from typing import TYPE_CHECKING, Any, Optional, Set, Tuple, Union\n\nif TYPE_CHECKING:\n    from tango.step import Step\n    from tango.step_info import StepInfo, StepState\n\n\nclass TangoError(Exception):\n    \"\"\"\n    Base class for Tango exceptions.\n    \"\"\"\n\n\nclass ConfigurationError(TangoError):\n    \"\"\"\n    The exception raised when a Tango object fails to initialize from a config\n    that's misconfigured (e.g. missing properties, invalid properties, unknown properties).\n    \"\"\"\n\n    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:\n        return type(self), (self.message,)\n\n    def __init__(self, message: str):\n        super().__init__()\n        self.message = message\n\n    def __str__(self):\n        return self.message\n\n\nclass RegistryKeyError(ConfigurationError):\n    \"\"\"\n    A configuration error that is raised when attempting to get a class by a registered name\n    that doesn't exist in the registry.\n    \"\"\"\n\n\nclass CancellationError(TangoError):\n    \"\"\"\n    Base class for errors raised due to manual cancellation of a run or step.\n    \"\"\"\n\n\nclass SigTermReceived(CancellationError):\n    \"\"\"\n    Raised when a SIGTERM is caught.\n    \"\"\"\n\n\nclass StepCancelled(CancellationError):\n    pass\n\n\nclass RunCancelled(CancellationError):\n    pass\n\n\nclass CliRunError(TangoError):\n    \"\"\"\n    Raised when `tango run` command fails.\n    \"\"\"\n\n\nclass IntegrationMissingError(TangoError):\n    \"\"\"\n    Raised when an integration can't be used due to missing dependencies.\n    \"\"\"\n\n    def __init__(self, integration: str, dependencies: Optional[Set[str]] = None):\n        self.integration = integration\n        self.dependencies = dependencies or {integration}\n        msg = (\n            f\"'{self.integration}' integration can't be used due to \"\n            f\"missing dependencies ({', '.join(self.dependencies)})\"\n        )\n        
super().__init__(msg)\n\n\nclass StepStateError(TangoError):\n    \"\"\"\n    Raised when a step is in an unexpected state.\n    \"\"\"\n\n    def __init__(\n        self,\n        step: Union[\"Step\", \"StepInfo\"],\n        step_state: \"StepState\",\n        context: Optional[str] = None,\n    ):\n        self.step_state = step_state\n        self.step_id = step.unique_id\n        msg = f\"Step '{self.step_id}' is in unexpected state '{self.step_state.value}'\"\n        if context is not None:\n            msg = msg + \" \" + context\n        super().__init__(msg)\n\n\nclass DirtyRepoError(TangoError):\n    \"\"\"\n    Raised when a repository is in a dirty state.\n    \"\"\"\n\n\nclass ExecutorError(TangoError):\n    \"\"\"\n    A base class for executor-specific errors.\n    \"\"\"\n"
  },
  {
    "path": "tango/common/file_lock.py",
    "content": "import os\nimport warnings\nfrom typing import Optional\n\nfrom filelock import AcquireReturnProxy\nfrom filelock import FileLock as _FileLock\n\nfrom .aliases import PathOrStr\n\n\nclass FileLock(_FileLock):  # type: ignore[valid-type,misc]\n    \"\"\"\n    This is just a subclass of the `FileLock` class from the `filelock` library, except that\n    it adds an additional argument to the `__init__` method: `read_only_ok`.\n\n    By default this flag is `False`, in which case an exception will be raised when a lock\n    can't be acquired due to lack of write permissions.\n    But if this flag is set to `True`, a warning will be emitted instead of an error when\n    the lock already exists but the lock can't be acquired because write access is blocked.\n    \"\"\"\n\n    def __init__(self, lock_file: PathOrStr, timeout=-1, read_only_ok: bool = False) -> None:\n        super().__init__(str(lock_file), timeout=timeout)\n        self._read_only_ok = read_only_ok\n\n    def acquire(  # type: ignore[override]\n        self,\n        timeout: Optional[float] = None,\n        poll_interval: float = 0.05,\n    ) -> AcquireReturnProxy:\n        try:\n            return super().acquire(timeout=timeout, poll_interval=poll_interval)\n        except OSError as err:\n            # OSError could be a lot of different things, but what we're looking\n            # for in particular are permission errors, such as:\n            #  - errno 1  - EPERM  - \"Operation not permitted\"\n            #  - errno 13 - EACCES - \"Permission denied\"\n            #  - errno 30 - EROFS  - \"Read-only file system\"\n            if err.errno not in (1, 13, 30):\n                raise\n\n            if os.path.isfile(self._lock_file) and self._read_only_ok:  # type: ignore\n                warnings.warn(\n                    f\"Lacking permissions required to obtain lock '{self._lock_file}'. 
\"  # type: ignore\n                    \"Race conditions are possible if other processes are writing to the same resource.\",\n                    UserWarning,\n                )\n                return AcquireReturnProxy(self)\n            else:\n                raise\n\n    def acquire_with_updates(self, desc: Optional[str] = None) -> AcquireReturnProxy:\n        \"\"\"\n        Same as :meth:`acquire()`, except that when the lock cannot be immediately acquired,\n        it will keep trying and print status updates as it goes.\n        \"\"\"\n        try:\n            return self.acquire(timeout=0.1)\n        except TimeoutError:\n            pass\n\n        from .tqdm import Tqdm\n\n        if desc is None:\n            desc = f\"acquiring lock at {self._lock_file}\"  # type: ignore\n\n        progress = Tqdm.tqdm(desc=desc, bar_format=\"{desc} [{elapsed}]\")\n        while True:\n            progress.update()\n            try:\n                return self.acquire(timeout=1)\n            except TimeoutError:\n                continue\n"
  },
  {
    "path": "tango/common/from_params.py",
    "content": "import collections.abc\nimport inspect\nimport logging\nfrom copy import deepcopy\nfrom pathlib import Path\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Mapping,\n    Optional,\n    Set,\n    Tuple,\n    Type,\n    TypeVar,\n    Union,\n    cast,\n    get_type_hints,\n)\n\nfrom tango.common.det_hash import DetHashWithVersion\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.lazy import Lazy\nfrom tango.common.params import Params\n\ntry:\n    # For PEP 604 support (python >= 3.10)\n    from types import UnionType  # type: ignore[attr-defined]\nexcept ImportError:\n\n    class UnionType:  # type: ignore\n        pass\n\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\", bound=\"FromParams\")\n\n# If a function parameter has no default value specified,\n# this is what the inspect module returns.\n_NO_DEFAULT = inspect.Parameter.empty\n\n\ndef takes_arg(obj, arg: str) -> bool:\n    \"\"\"\n    Checks whether the provided obj takes a certain arg.\n    If it's a class, we're really checking whether its constructor does.\n    If it's a function or method, we're checking the object itself.\n    Otherwise, we raise an error.\n    \"\"\"\n    if inspect.isclass(obj):\n        signature = inspect.signature(obj.__init__)\n    elif inspect.ismethod(obj) or inspect.isfunction(obj):\n        signature = inspect.signature(obj)\n    else:\n        raise ConfigurationError(f\"object {obj} is not callable\")\n    return arg in signature.parameters\n\n\ndef takes_kwargs(obj) -> bool:\n    \"\"\"\n    Checks whether a provided object accepts arbitrary keyword arguments (``**kwargs``).\n    As in ``takes_arg``, we inspect the ``__init__`` method if the object is a class,\n    or the object itself if it is a function or method.\n    Otherwise, we raise an error.\n    \"\"\"\n    if inspect.isclass(obj):\n        signature = inspect.signature(obj.__init__)\n    elif inspect.ismethod(obj) or inspect.isfunction(obj):\n        signature = 
inspect.signature(obj)\n    else:\n        raise ConfigurationError(f\"object {obj} is not callable\")\n    return any(\n        p.kind == inspect.Parameter.VAR_KEYWORD  # type: ignore\n        for p in signature.parameters.values()\n    )\n\n\ndef is_base_registrable(cls) -> bool:\n    \"\"\"\n    Checks whether this is a class that directly inherits from Registrable, or is a subclass of such\n    a class.\n    \"\"\"\n    from tango.common.registrable import (\n        Registrable,  # import here to avoid circular imports\n    )\n\n    if not issubclass(cls, Registrable):\n        return False\n    method_resolution_order = inspect.getmro(cls)[1:]\n    for base_class in method_resolution_order:\n        if issubclass(base_class, Registrable) and base_class is not Registrable:\n            return False\n    return True\n\n\ndef remove_optional(annotation: type):\n    \"\"\"\n    Optional[X] annotations are actually represented as Union[X, NoneType].\n    For our purposes, the \"Optional\" part is not interesting, so here we\n    throw it away.\n    \"\"\"\n    origin = getattr(annotation, \"__origin__\", None)\n    args = getattr(annotation, \"__args__\", ())\n\n    if origin == Union:\n        return Union[tuple([arg for arg in args if arg != type(None)])]  # noqa: E721\n    else:\n        return annotation\n\n\ndef infer_constructor_params(\n    cls: Type[T], constructor: Optional[Union[Callable[..., T], Callable[[T], None]]] = None\n) -> Dict[str, inspect.Parameter]:\n    if constructor is None:\n        constructor = cls.__init__\n    return infer_method_params(cls, constructor)\n\n\ninfer_params = infer_constructor_params  # Legacy name\n\n\ndef infer_method_params(\n    cls: Type[T], method: Callable, infer_kwargs: bool = True\n) -> Dict[str, inspect.Parameter]:\n    signature = inspect.signature(method)\n    parameters = dict(signature.parameters)\n\n    has_kwargs = False\n    var_positional_key = None\n    for param_name in list(parameters.keys()):\n      
  # Ignore special private parameters.\n        # This is necessary to make `FromParams` work with Pydantic, for example.\n        if param_name.startswith(\"__\"):\n            del parameters[param_name]\n            continue\n\n        param = parameters[param_name]\n        if param.kind == param.VAR_KEYWORD:\n            has_kwargs = True\n        elif param.kind == param.VAR_POSITIONAL:\n            var_positional_key = param.name\n        if isinstance(param.annotation, str):\n            # For Python < 3.10, if the module where this class was defined used\n            # `from __future__ import annotations`, the annotation will be a str,\n            # so we need to resolve it using `get_type_hints` from the typing module.\n            # See https://www.python.org/dev/peps/pep-0563/ for more info.\n            try:\n                parameters[param_name] = param.replace(\n                    annotation=get_type_hints(method)[param_name]\n                )\n            except TypeError as e:\n                if \"'type' object is not subscriptable\" in str(e):\n                    # This can happen when someone uses a type hint like `dict[str, str]`\n                    # instead of `Dict[str, str]`.\n                    err_msg = (\n                        f\"Failed to parse the type annotation `{param.annotation}` \"\n                        f\"from `{cls.__qualname__}.{method.__name__}()`.\"\n                    )\n\n                    if \"[\" in param.annotation:\n                        # Check if there is an equivalent generic in the `typing` module.\n                        import typing\n\n                        type_, *_ = param.annotation.split(\"[\", 1)\n                        for possible_typing_equivalent in {type_, type_.title()}:\n                            if hasattr(typing, possible_typing_equivalent):\n                                err_msg += (\n                                    f\" Try using `{possible_typing_equivalent}` \"\n        
                            \"from the `typing` module instead.\"\n                                )\n                                break\n\n                    new_e = TypeError(err_msg)\n                    new_e.__cause__ = e\n                    raise new_e\n                else:\n                    raise\n\n    if var_positional_key:\n        del parameters[var_positional_key]\n\n    if not has_kwargs or not infer_kwargs:\n        return parameters\n\n    # \"mro\" is \"method resolution order\". The first one is the current class, the next is the\n    # first superclass, and so on. We take the first superclass we find that inherits from\n    # FromParams.\n    super_class = None\n    # We have to be a little careful here because in some cases we might not have an\n    # actual class. Instead we might just have a function that returns a class instance.\n    if hasattr(cls, \"mro\"):\n        for super_class_candidate in cls.mro()[1:]:\n            if issubclass(super_class_candidate, FromParams):\n                super_class = super_class_candidate\n                break\n    if super_class:\n        super_parameters = infer_params(super_class)\n    else:\n        super_parameters = {}\n\n    return {**super_parameters, **parameters}  # Subclass parameters overwrite superclass ones\n\n\ndef create_kwargs(\n    constructor: Callable[..., T],\n    cls: Type[T],\n    params: Params,\n    extras: Optional[Dict[str, Any]] = None,\n) -> Dict[str, Any]:\n    \"\"\"\n    Given some class, a ``Params`` object, and potentially other keyword arguments,\n    create a dict of keyword args suitable for passing to the class's constructor.\n\n    The function does this by finding the class's constructor, matching the constructor\n    arguments to entries in the ``params`` object, and instantiating values for the parameters\n    using the type annotation and possibly a from_params method.\n\n    Any values that are provided in the 
``extras`` will just be used as is.\n    For instance, you might provide an existing ``Vocabulary`` this way.\n    \"\"\"\n    # Get the signature of the constructor.\n\n    kwargs: Dict[str, Any] = {}\n\n    parameters = infer_params(cls, constructor)\n    accepts_kwargs = False\n\n    # Iterate over all the constructor parameters and their annotations.\n    for param_name, param in parameters.items():\n        # Skip \"self\". You're not *required* to call the first parameter \"self\",\n        # so in theory this logic is fragile, but if you don't call the self parameter\n        # \"self\" you kind of deserve what happens.\n        if param_name == \"self\":\n            continue\n\n        if param.kind == param.VAR_KEYWORD:\n            # When a class takes **kwargs, we do two things: first, we assume that the **kwargs are\n            # getting passed to the super class, so we inspect super class constructors to get\n            # allowed arguments (that happens in `infer_params` above).  Second, we store the fact\n            # that the method allows extra keys; if we get extra parameters, instead of crashing,\n            # we'll just pass them as-is to the constructor, and hope that you know what you're\n            # doing.\n            accepts_kwargs = True\n            continue\n\n        # If the annotation is a compound type like typing.Dict[str, int],\n        # it will have an __origin__ field indicating `typing.Dict`\n        # and an __args__ field indicating `(str, int)`. 
We capture both.\n        annotation = remove_optional(param.annotation)\n\n        explicitly_set = param_name in params\n        constructed_arg = pop_and_construct_arg(\n            cls.__name__, param_name, annotation, param.default, params, extras or {}\n        )\n\n        # If the param wasn't explicitly set in `params` and we just ended up constructing\n        # the default value for the parameter, we can just omit it.\n        # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up\n        # with multiple values for a single parameter (e.g., the default value gives you lazy=False\n        # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes\n        # lazy=True - the superclass sees both lazy=True and lazy=False in its constructor).\n        if explicitly_set or constructed_arg is not param.default:\n            kwargs[param_name] = constructed_arg\n\n    if accepts_kwargs:\n        for key in list(params):\n            kwargs[key] = params.pop(key, keep_as_dict=True)\n        if extras:\n            for key, value in extras.items():\n                kwargs[key] = value\n    params.assert_empty(cls.__name__)\n    return kwargs\n\n\ndef create_extras(cls: Type[T], extras: Dict[str, Any]) -> Dict[str, Any]:\n    \"\"\"\n    Given a dictionary of extra arguments, returns a dictionary of\n    kwargs that actually are a part of the signature of the cls.from_params\n    (or cls) method.\n    \"\"\"\n    subextras: Dict[str, Any] = {}\n    if hasattr(cls, \"from_params\"):\n        from_params_method = cls.from_params  # type: ignore\n    else:\n        # In some rare cases, we get a registered subclass that does _not_ have a\n        # from_params method (this happens with Activations, for instance, where we\n        # register pytorch modules directly).  This is a bit of a hack to make those work,\n        # instead of adding a `from_params` method for them somehow. 
Then the extras\n        # in the class constructor are what we are looking for, to pass on.\n        from_params_method = cls\n    if takes_kwargs(from_params_method):\n        # If annotation.params accepts **kwargs, we need to pass them all along.\n        # For example, `BasicTextFieldEmbedder.from_params` requires a Vocabulary\n        # object, but `TextFieldEmbedder.from_params` does not.\n        subextras = extras\n    else:\n        # Otherwise, only supply the ones that are actual args; any additional ones\n        # will cause a TypeError.\n        subextras = {k: v for k, v in extras.items() if takes_arg(from_params_method, k)}\n    return subextras\n\n\ndef pop_and_construct_arg(\n    class_name: str,\n    argument_name: str,\n    annotation: Type,\n    default: Any,\n    params: Params,\n    extras: Dict[str, Any],\n) -> Any:\n    \"\"\"\n    Does the work of actually constructing an individual argument for\n    [``create_kwargs``](./#create_kwargs).\n\n    Here we're in the inner loop of iterating over the parameters to a particular constructor,\n    trying to construct just one of them.  The information we get for that parameter is its name,\n    its type annotation, and its default value; we also get the full set of ``Params`` for\n    constructing the object (which we may mutate), and any ``extras`` that the constructor might\n    need.\n\n    We take the type annotation and default value here separately, instead of using an\n    ``inspect.Parameter`` object directly, so that we can handle ``Union`` types using recursion on\n    this method, trying the different annotation types in the union in turn.\n    \"\"\"\n    # We used `argument_name` as the method argument to avoid conflicts with 'name' being a key in\n    # `extras`, which isn't _that_ unlikely.  Now that we are inside the method, we can switch back\n    # to using `name`.\n    name = argument_name\n\n    # Some constructors expect extra non-parameter items, e.g. 
vocab: Vocabulary.\n    # We check the provided `extras` for these and just use them if they exist.\n    if name in extras:\n        if name not in params:\n            return extras[name]\n        else:\n            logger.warning(\n                f\"Parameter {name} for class {class_name} was found in both \"\n                \"**extras and in params. Using the specification found in params, \"\n                \"but you probably put a key in a config file that you didn't need, \"\n                \"and if it is different from what we get from **extras, you might \"\n                \"get unexpected behavior.\"\n            )\n\n    try:\n        popped_params = params.pop(name, default) if default != _NO_DEFAULT else params.pop(name)\n    except ConfigurationError:\n        raise ConfigurationError(f'Missing key \"{name}\" for {class_name}')\n\n    if popped_params is None:\n        return None\n\n    return construct_arg(class_name, name, popped_params, annotation, default)\n\n\ndef _params_contain_step(o: Any) -> bool:\n    from tango.step import Step\n\n    if isinstance(o, Step):\n        return True\n    elif isinstance(o, str):\n        return False  # Confusingly, str is an Iterable of itself, resulting in infinite recursion.\n    elif isinstance(o, Params):\n        return _params_contain_step(o.as_dict(quiet=True))\n    elif isinstance(o, dict):\n        if set(o.keys()) == {\"type\", \"ref\"} and o[\"type\"] == \"ref\":\n            return True\n        else:\n            return _params_contain_step(o.values())\n    elif isinstance(o, Iterable):\n        return any(_params_contain_step(p) for p in o)\n    else:\n        return False\n\n\ndef construct_arg(\n    class_name: str,\n    argument_name: str,\n    popped_params: Params,\n    annotation: Type,\n    default: Any,\n    try_from_step: bool = True,\n) -> Any:\n    \"\"\"\n    The first two parameters here are only used for logging if we encounter an error.\n    \"\"\"\n    # If we have the 
default, we're already done :)\n    if popped_params is default:\n        return popped_params\n\n    from tango.step import FunctionalStep, Step, StepIndexer, WithUnresolvedSteps\n\n    origin = getattr(annotation, \"__origin__\", None)\n    args = getattr(annotation, \"__args__\", [])\n\n    # Try to guess if `popped_params` might be a step, come from a step, or contain a step.\n    could_be_step = (\n        try_from_step\n        and (\n            origin == Step\n            or isinstance(popped_params, Step)\n            or _params_contain_step(popped_params)\n            or (isinstance(popped_params, (dict, Params)) and popped_params.get(\"type\") == \"ref\")\n        )\n        and not (class_name == \"StepInfo\" and argument_name == \"config\")\n    )\n    if could_be_step:\n        # If we think it might be a step, we try parsing as a step _first_.\n        # Parsing as a non-step always succeeds, because it will fall back to returning a dict.\n        # So we can't try parsing as a non-step first.\n        backup_params = deepcopy(popped_params)\n        try:\n            return construct_arg(\n                class_name,\n                argument_name,\n                popped_params,\n                Step[annotation],  # type: ignore\n                default,\n                try_from_step=False,\n            )\n        except (ValueError, TypeError, ConfigurationError, AttributeError, IndexError):\n            popped_params = backup_params\n\n    # The parameter is optional if its default value is not the \"no default\" sentinel.\n    optional = default != _NO_DEFAULT\n\n    if (inspect.isclass(annotation) and issubclass(annotation, FromParams)) or (\n        inspect.isclass(origin) and issubclass(origin, FromParams)\n    ):\n        if origin is None and isinstance(popped_params, annotation):\n            return popped_params\n        elif popped_params is not None:\n            # In some cases we allow a string instead of a param dict, so\n           
 # we need to handle that case separately.\n            if isinstance(popped_params, str):\n                if origin != Step:\n                    # We don't allow single strings to be upgraded to steps.\n                    # Since we try everything as a step first, upgrading strings to\n                    # steps automatically would cause confusion every time a step\n                    # name conflicts with any string anywhere in a config.\n                    popped_params = Params({\"type\": popped_params})\n            elif isinstance(popped_params, dict):\n                popped_params = Params(popped_params)\n            elif not isinstance(popped_params, (Params, Step)):\n                raise TypeError(\n                    f\"Expected a `Params` object, found `{popped_params}` instead while constructing \"\n                    f\"parameter '{argument_name}' for `{class_name}`\"\n                )\n\n            result: Union[FromParams, WithUnresolvedSteps]\n            if isinstance(popped_params, Step):\n                result = popped_params\n            else:\n                if origin != Step and _params_contain_step(popped_params):\n                    result = WithUnresolvedSteps(annotation.from_params, popped_params)\n                else:\n                    result = annotation.from_params(popped_params)\n\n            if isinstance(result, Step):\n                expected_return_type = args[0] if args else None\n                if isinstance(result, FunctionalStep):\n                    return_type = inspect.signature(result.WRAPPED_FUNC).return_annotation\n                else:\n                    return_type = inspect.signature(result.run).return_annotation\n                if return_type == inspect.Signature.empty:\n                    logger.warning(\n                        \"Step %s has no return type annotation. 
Those are really helpful when \"\n                        \"debugging, so we recommend them highly.\",\n                        result.__class__.__name__,\n                    )\n                else:\n                    try:\n                        if expected_return_type is not None and not issubclass(\n                            return_type, expected_return_type\n                        ):\n                            raise ConfigurationError(\n                                f\"Step {result.name} returns {return_type}, but \"\n                                f\"we expected {expected_return_type}.\"\n                            )\n                    except TypeError:\n                        pass\n\n            return result\n        elif not optional:\n            # Not optional and not supplied, that's an error!\n            raise ConfigurationError(f\"expected key {argument_name} for {class_name}\")\n        else:\n            return default\n\n    # For StepIndexer, we just return it as-is and hope for the best.\n    # At worst, the user will get an error at runtime if they are trying to index a step\n    # result that can't be indexed.\n    # TODO (epwalsh): we could check the return type of the wrapped step here\n    # and make sure that:\n    #  1. It's an index-able object,\n    #  2. 
The item in the index-able object matches `annotation`.\n    #\n    # But that's complex and might have false negatives.\n    elif type(popped_params) == StepIndexer:\n        return popped_params\n\n    # If the parameter type is a Python primitive, just pop it off\n    # using the correct casting pop_xyz operation.\n    elif annotation in {int, bool}:\n        if type(popped_params) in {int, bool}:\n            return annotation(popped_params)\n        else:\n            raise TypeError(\n                f\"Expected {argument_name} to be {annotation.__name__}, \"\n                f\"found {popped_params} ({type(popped_params)}).\"\n            )\n    elif annotation == str:\n        # Strings are special because we allow casting from Path to str.\n        if isinstance(popped_params, str) or isinstance(popped_params, Path):\n            return str(popped_params)  # type: ignore\n        else:\n            raise TypeError(\n                f\"Expected {argument_name} to be a string, found {popped_params} ({type(popped_params)})\"\n            )\n    elif annotation == float:\n        # Floats are special because in Python, you can put an int wherever you can put a float.\n        # https://mypy.readthedocs.io/en/stable/duck_type_compatibility.html\n        if type(popped_params) in {int, float}:\n            return popped_params\n        else:\n            raise TypeError(f\"Expected {argument_name} to be numeric.\")\n    elif annotation == Path:\n        if isinstance(popped_params, (str, Path)):\n            return Path(popped_params)\n        else:\n            raise TypeError(\n                f\"Expected {argument_name} to be a str or Path, found {popped_params} ({type(popped_params)})\"\n            )\n\n    # This is special logic for handling types like Dict[str, TokenIndexer],\n    # List[TokenIndexer], Tuple[TokenIndexer, Tokenizer], and Set[TokenIndexer],\n    # which it creates by instantiating each value from_params and returning the resulting 
structure.\n    elif origin in {collections.abc.Mapping, Mapping, Dict, dict} and len(args) == 2:\n        value_cls = annotation.__args__[-1]\n        value_dict = {}\n        if not isinstance(popped_params, Mapping):\n            raise TypeError(\n                f\"Expected {argument_name} to be a Mapping (probably a dict or a Params object) \"\n                f\"found {popped_params} ({type(popped_params)}).\"\n            )\n\n        for key, value_params in popped_params.items():\n            value_dict[key] = construct_arg(\n                str(value_cls),\n                argument_name + \".\" + key,\n                value_params,\n                value_cls,\n                _NO_DEFAULT,\n            )\n\n        return value_dict\n\n    elif origin in (Tuple, tuple):\n        value_list = []\n\n        value_types = list(annotation.__args__)\n        if value_types[-1] == Ellipsis:\n            # Variable length tuples, e.g. 'Tuple[int, ...]', we set value_types to '[int] * len(popped_params)'.\n            value_types = value_types[:-1] + [value_types[-2]] * (\n                len(popped_params) - len(annotation.__args__) + 1\n            )\n\n        for i, (value_cls, value_params) in enumerate(zip(value_types, popped_params)):\n            value = construct_arg(\n                str(value_cls),\n                argument_name + f\".{i}\",\n                value_params,\n                value_cls,\n                _NO_DEFAULT,\n            )\n            value_list.append(value)\n\n        return tuple(value_list)\n\n    elif origin in (Set, set) and len(args) == 1:\n        value_cls = annotation.__args__[0]\n\n        value_set = set()\n\n        for i, value_params in enumerate(popped_params):\n            value = construct_arg(\n                str(value_cls),\n                argument_name + f\".{i}\",\n                value_params,\n                value_cls,\n                _NO_DEFAULT,\n            )\n            value_set.add(value)\n\n      
  return value_set\n\n    elif origin == Union or isinstance(annotation, UnionType):\n        # Storing this so we can recover it later if we need to.\n        backup_params = deepcopy(popped_params)\n\n        # We'll try each of the given types in the union sequentially, returning the first one that\n        # succeeds.\n        error_chain: Optional[Exception] = None\n        for arg_annotation in args:\n            try:\n                return construct_arg(\n                    str(arg_annotation),\n                    argument_name,\n                    popped_params,\n                    arg_annotation,\n                    default,\n                )\n            except (ValueError, TypeError, ConfigurationError, AttributeError) as e:\n                # Our attempt to construct the argument may have modified popped_params, so we\n                # restore it here.\n                popped_params = deepcopy(backup_params)\n                e.args = (f\"While constructing an argument of type {arg_annotation}\",) + e.args\n                e.__cause__ = error_chain\n                error_chain = e\n\n        # If none of them succeeded, we crash.\n        config_error = ConfigurationError(\n            f\"Failed to construct argument {argument_name} with type {annotation}.\"\n        )\n        config_error.__cause__ = error_chain\n        raise config_error\n    elif origin == Lazy:\n        value_cls = args[0]\n        return Lazy(value_cls, params=deepcopy(popped_params))  # type: ignore\n\n    # For any other kind of iterable, we will just assume that a list is good enough, and treat\n    # it the same as List. 
This condition needs to be at the end, so we don't catch other kinds\n    # of Iterables with this branch.\n    elif origin in {collections.abc.Iterable, Iterable, List, list} and len(args) == 1:\n        value_cls = annotation.__args__[0]\n\n        value_list = []\n\n        for i, value_params in enumerate(popped_params):\n            value = construct_arg(\n                str(value_cls),\n                argument_name + f\".{i}\",\n                value_params,\n                value_cls,\n                _NO_DEFAULT,\n            )\n            value_list.append(value)\n\n        return value_list\n\n    elif (inspect.isclass(annotation) or inspect.isclass(origin)) and isinstance(\n        popped_params, Params\n    ):\n        # Constructing arbitrary classes from params\n        arbitrary_class = origin or annotation\n        constructor_to_inspect = arbitrary_class.__init__\n        constructor_to_call = arbitrary_class\n        params_contain_step = _params_contain_step(popped_params)\n        kwargs = create_kwargs(constructor_to_inspect, arbitrary_class, popped_params)\n        from tango.step import WithUnresolvedSteps\n\n        if origin != Step and params_contain_step:\n            return WithUnresolvedSteps(constructor_to_call, *[], **kwargs)\n        else:\n            return constructor_to_call(**kwargs)  # type: ignore\n\n    else:\n        # Pass it on as is and hope for the best.   ¯\\_(ツ)_/¯\n        if isinstance(popped_params, Params):\n            return popped_params.as_dict()\n        return popped_params\n\n\nclass FromParams(DetHashWithVersion):\n    \"\"\"\n    Mixin to give a :meth:`from_params` method to classes. 
We create a distinct base class for this\n    because sometimes we want non :class:`~tango.common.Registrable`\n    classes to be instantiatable ``from_params``.\n    \"\"\"\n\n    @classmethod\n    def from_params(\n        cls: Type[T],\n        params_: Union[Params, dict, str],\n        constructor_to_call: Optional[Callable[..., T]] = None,\n        constructor_to_inspect: Optional[Union[Callable[..., T], Callable[[T], None]]] = None,\n        **extras,\n    ) -> T:\n        \"\"\"\n        This is the automatic implementation of ``from_params``. Any class that subclasses\n        from ``FromParams`` (or :class:`~tango.common.Registrable`,\n        which itself subclasses ``FromParams``) gets this\n        implementation for free.  If you want your class to be instantiated from params in the\n        \"obvious\" way -- pop off parameters and hand them to your constructor with the same names --\n        this provides that functionality.\n\n        If you need more complex logic in your ``from_params`` method, you'll have to implement\n        your own method that overrides this one.\n\n        The ``constructor_to_call`` and ``constructor_to_inspect`` arguments deal with a bit of\n        redirection that we do.  We allow you to register particular ``@classmethods`` on a class as\n        the constructor to use for a registered name.  This lets you, e.g., have a single\n        ``Vocabulary`` class that can be constructed in two different ways, with different names\n        registered to each constructor.  In order to handle this, we need to know not just the class\n        we're trying to construct (``cls``), but also what method we should inspect to find its\n        arguments (``constructor_to_inspect``), and what method to call when we're done constructing\n        arguments (``constructor_to_call``).  
These two methods are the same when you've used a\n        ``@classmethod`` as your constructor, but they are ``different`` when you use the default\n        constructor (because you inspect ``__init__``, but call ``cls()``).\n        \"\"\"\n\n        from tango.common.registrable import (\n            Registrable,  # import here to avoid circular imports\n        )\n\n        params = params_\n\n        logger.debug(\n            f\"instantiating class {cls} from params {getattr(params, 'params', params)} \"\n            f\"and extras {set(extras.keys())}\"\n        )\n\n        if params is None:\n            return None\n\n        if isinstance(params, str):\n            params = Params({\"type\": params})\n\n        if not isinstance(params, Params):\n            if isinstance(params, dict):\n                params = Params(params)\n            else:\n                raise ConfigurationError(\n                    \"from_params was passed a `params` object that was not a `Params`. This probably \"\n                    \"indicates malformed parameters in a configuration file, where something that \"\n                    \"should have been a dictionary was actually a list, or something else. 
\"\n                    f\"This happened when constructing an object of type {cls}.\"\n                )\n\n        if issubclass(cls, Registrable) and not constructor_to_call:\n            # We know `cls` inherits from Registrable, so we'll use a cast to make mypy happy.\n            as_registrable = cast(Type[Registrable], cls)\n\n            if \"type\" in params and params[\"type\"] not in as_registrable.list_available():\n                as_registrable.search_modules(params[\"type\"])\n\n            # Resolve the subclass and constructor.\n            if is_base_registrable(cls) or \"type\" in params:\n                default_to_first_choice = as_registrable.default_implementation is not None\n                choice = params.pop_choice(\n                    \"type\",\n                    choices=as_registrable.list_available(),\n                    default_to_first_choice=default_to_first_choice,\n                )\n                # We allow users to register methods and functions, not just classes.\n                # So we have to handle both here.\n                subclass_or_factory_func, constructor_name = as_registrable.resolve_class_name(\n                    choice\n                )\n                if inspect.isclass(subclass_or_factory_func):\n                    # We have an actual class.\n                    subclass = subclass_or_factory_func\n                    if constructor_name is not None:\n                        constructor_to_inspect = cast(\n                            Callable[..., T], getattr(subclass, constructor_name)\n                        )\n                        constructor_to_call = constructor_to_inspect\n                    else:\n                        constructor_to_inspect = subclass.__init__\n                        constructor_to_call = subclass\n                else:\n                    # We have a function that returns an instance of the class.\n                    factory_func = cast(Callable[..., T], 
subclass_or_factory_func)\n                    return_type = inspect.signature(factory_func).return_annotation\n                    if return_type == inspect.Signature.empty:\n                        subclass = cls\n                    else:\n                        subclass = return_type\n                    constructor_to_inspect = factory_func\n                    constructor_to_call = factory_func\n            else:\n                # Must be trying to instantiate the given class directly.\n                subclass = cls\n                constructor_to_inspect = cls.__init__\n                constructor_to_call = cast(Callable[..., T], cls)\n\n            if hasattr(subclass, \"from_params\"):\n                # We want to call subclass.from_params.\n                extras = create_extras(subclass, extras)\n                # mypy can't follow the typing redirection that we do, so we explicitly cast here.\n                retyped_subclass = cast(Type[T], subclass)\n                return retyped_subclass.from_params(\n                    params,\n                    constructor_to_call=constructor_to_call,\n                    constructor_to_inspect=constructor_to_inspect,\n                    **extras,\n                )\n            else:\n                # In some rare cases, we get a registered subclass that does _not_ have a\n                # from_params method (this happens with Activations, for instance, where we\n                # register pytorch modules directly).  This is a bit of a hack to make those work,\n                # instead of adding a `from_params` method for them somehow.  
We just trust that\n                # you've done the right thing in passing your parameters, and nothing else needs to\n                # be recursively constructed.\n                kwargs = create_kwargs(constructor_to_inspect, subclass, params, extras)  # type: ignore\n                return constructor_to_call(**kwargs)  # type: ignore\n        else:\n            # This is not a base class, so convert our params and extras into a dict of kwargs.\n\n            # See the docstring for an explanation of what's going on here.\n            if not constructor_to_inspect:\n                constructor_to_inspect = cls.__init__\n            if not constructor_to_call:\n                constructor_to_call = cls\n\n            if constructor_to_inspect == object.__init__:\n                # This class does not have an explicit constructor, so don't give it any kwargs.\n                # Without this logic, create_kwargs will look at object.__init__ and see that\n                # it takes *args and **kwargs and look for those.\n                kwargs: Dict[str, Any] = {}  # type: ignore[no-redef]\n                params.assert_empty(cls.__name__)\n            else:\n                # This class has a constructor, so create kwargs for it.\n                constructor_to_inspect = cast(Callable[..., T], constructor_to_inspect)\n                kwargs = create_kwargs(constructor_to_inspect, cls, params, extras)\n\n            return constructor_to_call(**kwargs)  # type: ignore\n\n    def to_params(self) -> Params:\n        \"\"\"\n        Returns a ``Params`` object that can be used with ``.from_params()`` to recreate an\n        object just like it.\n\n        This relies on ``_to_params()``. 
If you need this in your custom ``FromParams`` class,\n        override ``_to_params()``, not this method.\n        \"\"\"\n\n        def replace_object_with_params(o: Any) -> Any:\n            if isinstance(o, FromParams):\n                return o.to_params().as_dict(quiet=True)\n            elif isinstance(o, (list, tuple, set)):\n                return [replace_object_with_params(i) for i in o]\n            elif isinstance(o, dict):\n                return {key: replace_object_with_params(value) for key, value in o.items()}\n            elif isinstance(o, Path):\n                return str(o)\n            elif o is None or isinstance(o, (str, float, int, bool)):\n                return o\n            else:\n                raise NotImplementedError(\n                    f\"Unexpected type encountered in to_params(): {o} ({type(o)})\\n\"\n                    \"You may need to implement a custom '_to_params()'.\"\n                )\n\n        return Params(replace_object_with_params(self._to_params()))\n\n    def _to_params(self) -> Dict[str, Any]:\n        \"\"\"\n        Returns a dictionary of parameters that, when turned into a ``Params`` object and\n        then fed to ``.from_params()``, will recreate this object.\n\n        You don't need to implement this all the time. Tango will let you know if you\n        need it.\n        \"\"\"\n        try:\n            return self.__dict__\n        except AttributeError:\n            raise NotImplementedError(\n                f\"{self.__class__.__name__}._to_params() needs to be implemented\"\n            )\n"
  },
  {
    "path": "tango/common/lazy.py",
    "content": "import copy\nimport inspect\nfrom typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar, Union, cast\n\nfrom .det_hash import CustomDetHash, DetHashWithVersion\nfrom .params import Params\n\nT = TypeVar(\"T\")\n\n\nclass Lazy(Generic[T], CustomDetHash):\n    \"\"\"\n    This class is for use when constructing objects using :class:`~tango.common.FromParams`,\n    when an argument to a constructor has a `sequential dependency` with another argument to the same\n    constructor.\n\n    For example, in a ``Trainer`` class you might want to take a ``Model`` and an ``Optimizer`` as arguments,\n    but the ``Optimizer`` needs to be constructed using the parameters from the ``Model``. You can give\n    the type annotation ``Lazy[Optimizer]`` to the optimizer argument, then inside the constructor\n    call ``optimizer.construct(parameters=model.parameters)``.\n\n    This is only recommended for use when you have registered a ``@classmethod`` as the constructor\n    for your class, instead of using ``__init__``.  
Having a ``Lazy[]`` type annotation on an argument\n    to an ``__init__`` method makes your class completely dependent on being constructed using the\n    ``FromParams`` pipeline, which is not a good idea.\n\n    The actual implementation here is incredibly simple; the logic that handles the lazy\n    construction is actually found in ``FromParams``, where we have a special case for a ``Lazy`` type\n    annotation.\n\n    Examples\n    --------\n\n    ::\n\n        @classmethod\n        def my_constructor(\n            cls,\n            some_object: Lazy[MyObject],\n            optional_object: Lazy[MyObject] = None,\n            # or:\n            #  optional_object: Optional[Lazy[MyObject]] = None,\n            optional_object_with_default: Optional[Lazy[MyObject]] = Lazy(MyObjectDefault),\n            required_object_with_default: Lazy[MyObject] = Lazy(MyObjectDefault),\n        ) -> MyClass:\n            obj1 = some_object.construct()\n            obj2 = None if optional_object is None else optional_object.construct()\n            obj3 = None if optional_object_with_default is None else optional_object_with_default.construct()\n            obj4 = required_object_with_default.construct()\n\n    \"\"\"\n\n    def __init__(\n        self,\n        constructor: Union[Type[T], Callable[..., T]],\n        params: Optional[Params] = None,\n        constructor_extras: Optional[Dict[str, Any]] = None,\n        **kwargs,\n    ) -> None:\n        self._constructor = constructor\n        self._params = params or Params({})\n        self._constructor_extras = constructor_extras or {}\n        self._constructor_extras.update(kwargs)\n\n    @property\n    def constructor(self) -> Callable[..., T]:\n        from tango.common.from_params import FromParams\n\n        if inspect.isclass(self._constructor) and issubclass(self._constructor, FromParams):\n\n            def constructor_to_use(**kwargs):\n                return self._constructor.from_params(  # type: 
ignore[union-attr]\n                    copy.deepcopy(self._params),\n                    **kwargs,\n                )\n\n            return constructor_to_use\n        else:\n            return self._constructor\n\n    def construct(self, **kwargs) -> T:\n        \"\"\"\n        Call the constructor to create an instance of ``T``.\n        \"\"\"\n        # If there are duplicate keys between self._constructor_extras and kwargs,\n        # this will overwrite the ones in self._constructor_extras with what's in kwargs.\n        constructor_kwargs = {**self._constructor_extras, **kwargs}\n        return self.constructor(**constructor_kwargs)\n\n    def det_hash_object(self) -> Any:\n        from tango.common.from_params import FromParams\n\n        class_to_construct: Union[Type[T], Callable[..., T]] = self._constructor\n        if isinstance(class_to_construct, type) and issubclass(class_to_construct, FromParams):\n            params = copy.deepcopy(self._params)\n            if params is None:\n                params = Params({})\n            elif isinstance(params, str):\n                params = Params({\"type\": params})\n            elif isinstance(params, dict):\n                params = Params(params)\n            elif not isinstance(params, Params):\n                return None\n\n            from tango.common import Registrable\n\n            if issubclass(class_to_construct, Registrable):\n                as_registrable = cast(Type[Registrable], class_to_construct)\n\n                if \"type\" in params and params[\"type\"] not in as_registrable.list_available():\n                    as_registrable.search_modules(params[\"type\"])\n\n                # Resolve the subclass and constructor.\n                from .from_params import is_base_registrable\n\n                if is_base_registrable(class_to_construct) or \"type\" in params:\n                    default_to_first_choice = as_registrable.default_implementation is not None\n                    
choice = params.pop_choice(\n                        \"type\",\n                        choices=as_registrable.list_available(),\n                        default_to_first_choice=default_to_first_choice,\n                    )\n                    subclass_or_factory_func, _ = as_registrable.resolve_class_name(choice)\n                    if inspect.isclass(subclass_or_factory_func):\n                        class_to_construct = subclass_or_factory_func\n                    else:\n                        # We have a function that returns an instance of the class.\n                        factory_func = cast(Callable[..., T], subclass_or_factory_func)\n                        return_type = inspect.signature(factory_func).return_annotation\n                        if return_type != inspect.Signature.empty:\n                            class_to_construct = return_type\n\n        if isinstance(class_to_construct, type) and issubclass(\n            class_to_construct, DetHashWithVersion\n        ):\n            return class_to_construct.VERSION, self\n        else:\n            return self\n"
  },
  {
    "path": "tango/common/logging.py",
    "content": "\"\"\"\nTango makes heavy use of the :mod:`logging` module from the standard library to convey information to users.\nWhen you're writing your own :class:`~tango.step.Step` implementations we encourage you to also use standard\nPython logging as opposed to :func:`print` or other functions that write directly to ``stdout`` or ``stderr``.\nThis is easy enough since each :class:`~tango.step.Step` class already comes with its own logger:\n:attr:`Step.logger <tango.step.Step.logger>`.\n\nWhen using the `Tango CLI <./commands.html>`_ you can set the log level in several different ways:\n\n1. Through a Tango `global settings <./commands.html#global-settings>`_ file.\n2. With the environment variable ``TANGO_LOG_LEVEL``.\n3. Or with the ``--log-level`` command-line option.\n\nIn some cases (like when running on `Beaker <https://beaker.org>`_) you may also want\nto enable `\"file friendly logging\" <#tango.common.logging.FILE_FRIENDLY_LOGGING>`_.\n\nConfiguring logging in your own CLI\n-----------------------------------\n\nIf you're writing your own CLI that uses tango, you can utilize the :func:`initialize_logging()`\nfunction to easily configure logging properly.\n\nFor example,\n\n.. testcode::\n\n    from tango.common.logging import initialize_logging, teardown_logging\n\n    initialize_logging(log_level=\"info\")\n\n    logger = logging.getLogger()\n    logger.info(\"Running script!\")\n\n    teardown_logging()\n\n.. testoutput::\n    :options: +ELLIPSIS\n\n    [...] INFO     Running script! ...\n\nIf you want to have logs written to a file, you can use the :func:`file_handler` context manager.\n\nLogging from worker processes or threads\n----------------------------------------\n\nIf you have steps or other functions that spawn workers, and you want to enable logging within\nthose workers, you can call the :func:`initialize_worker_logging()` function to configure\nlogging within each worker. 
This assumes that you've called :func:`initialize_logging()` from the\nmain process (the tango CLI does this for you).\n\nFor example,\n\n.. testcode::\n\n    import logging\n    import multiprocessing as mp\n\n    from tango import Step\n    from tango.common.logging import initialize_worker_logging\n\n    @Step.register(\"multiprocessing_step_example\")\n    class MultiprocessingStep(Step):\n        def run(self, num_proc: int = 2) -> bool:  # type: ignore\n            workers = []\n            for i in range(num_proc):\n                worker = mp.Process(target=_worker_function, args=(i,))\n                workers.append(worker)\n                worker.start()\n            for worker in workers:\n                worker.join()\n            return True\n\n\n    def _worker_function(worker_id: int):\n        initialize_worker_logging(worker_rank=worker_id)\n        logger = logging.getLogger(MultiprocessingStep.__name__)\n        logger.info(\"Hello from worker %d!\", worker_id)\n\n\"\"\"\n\nimport logging\nimport logging.handlers\nimport os\nimport pickle\nimport socketserver\nimport struct\nimport sys\nimport threading\nfrom contextlib import contextmanager\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Callable, ClassVar, ContextManager, Generator, List, Optional, Union\n\nimport rich\nfrom rich.console import Console, ConsoleRenderable, Group\nfrom rich.highlighter import NullHighlighter\nfrom rich.padding import Padding\nfrom rich.syntax import Syntax\nfrom rich.table import Table\nfrom rich.text import Text\n\nfrom .aliases import EnvVarNames, PathOrStr\nfrom .exceptions import CancellationError, CliRunError, SigTermReceived\nfrom .util import _parse_bool, _parse_optional_int\n\nFILE_FRIENDLY_LOGGING: bool = _parse_bool(\n    os.environ.get(EnvVarNames.FILE_FRIENDLY_LOGGING.value, False)\n)\n\"\"\"\nIf this flag is set to ``True``, we remove special styling characters from log messages,\nadd newlines to 
:class:`~tango.common.tqdm.Tqdm` output even on an interactive terminal, and we slow\ndown :class:`~tango.common.tqdm.Tqdm`'s output to only once every 10 seconds.\n\n.. attention::\n    Unfortunately this won't affect ``tqdm`` output from other libraries that don't use\n    Tango's :class:`~tango.common.tqdm.Tqdm` wrapper.\n\nBy default, it is set to ``False``. It can be changed by setting the corresponding environment\nvariable (``FILE_FRIENDLY_LOGGING``) or field in a :class:`~tango.__main__.TangoGlobalSettings`\nfile (``file_friendly_logging``) to \"true\" or \"false\",\nor from the command line with the ``--file-friendly-logging`` flag.\nFor example,\n\n.. code-block::\n\n    $ tango --file-friendly-logging run ...\n\n\"\"\"\n\nTANGO_LOG_LEVEL: Optional[str] = os.environ.get(EnvVarNames.LOG_LEVEL.value, None)\n\"\"\"\nThe log level to use globally. The value can be set from the corresponding environment variable\n(``TANGO_LOG_LEVEL``) or field in a :class:`~tango.__main__.TangoGlobalSettings` file (``log_level``),\nor from the command line with the ``--log-level`` option.\nPossible values are \"debug\", \"info\", \"warning\", or \"error\" (not case sensitive).\nFor example,\n\n.. code-block::\n\n    $ tango --log-level info run ...\n\n.. note::\n    This does not affect the :data:`~tango.common.logging.cli_logger`\n    or logs from :class:`~tango.common.Tqdm` progress bars.\n\n\"\"\"\n\nTANGO_CONSOLE_WIDTH: Optional[int] = _parse_optional_int(\n    os.environ.get(EnvVarNames.CONSOLE_WIDTH.value, None)\n)\n\n# Click logger disabled by default in case nobody calls initialize_logging().\nTANGO_CLI_LOGGER_ENABLED: bool = _parse_bool(\n    os.environ.get(EnvVarNames.CLI_LOGGER_ENABLED.value, False)\n)\n\n# Keep track of exceptions logged so we don't log duplicates from our custom excepthook.\n_EXCEPTIONS_LOGGED: List[BaseException] = []\n\n\nclass LevelFilter(logging.Filter):\n    \"\"\"\n    Filters out everything above `max_level`. 
This is meant to be used\n    with a stdout handler when a stderr handler is also configured. That way WARNING or ERROR\n    messages aren't duplicated.\n    \"\"\"\n\n    def __init__(self, max_level: int, min_level: Optional[int] = None, name=\"\"):\n        self.max_level = max_level\n        self.min_level = min_level\n        super().__init__(name)\n\n    def filter(self, record):\n        if self.min_level is not None:\n            return self.min_level <= record.levelno <= self.max_level\n        else:\n            return record.levelno <= self.max_level\n\n\nclass CliFilter(logging.Filter):\n    def __init__(self, filter_out: bool):\n        self.filter_out = filter_out\n\n    def filter(self, record):\n        if self.filter_out:\n            return record.name != \"tango.__main__\"\n        else:\n            return record.name == \"tango.__main__\"\n\n\nclass WorkerLogFilter(logging.Filter):\n    def __init__(self, rank=-1):\n        super().__init__()\n        self._rank = rank\n\n    def filter(self, record):\n        if self._rank != -1:\n            record.msg = f\"[rank {self._rank}] {record.msg}\"\n        return True\n\n\nclass PrefixLogFilter(logging.Filter):\n    def __init__(self, prefix):\n        super().__init__()\n        self._prefix = prefix\n\n    def filter(self, record):\n        if not isinstance(record.msg, str):\n            return True\n        if record.name == \"tango.__main__\":\n            from rich.markup import escape\n\n            record.msg = escape(f\"[{self._prefix}] \") + record.msg\n        else:\n            record.msg = f\"[{self._prefix}] {record.msg}\"\n        return True\n\n\nclass LogRecordStreamHandler(socketserver.StreamRequestHandler):\n    \"\"\"Handler for a streaming logging request.\n\n    This basically logs the record using whatever logging policy is\n    configured locally.\n\n    Taken from\n    `the logging cookbook <https://docs.python.org/3.8/howto/logging-cookbook.html>`_.\n    \"\"\"\n\n    def 
handle(self):\n        \"\"\"\n        Handle multiple requests - each expected to be a 4-byte length,\n        followed by the LogRecord in pickle format. Logs the record\n        according to whatever policy is configured locally.\n        \"\"\"\n        while True:\n            chunk = self.connection.recv(4)\n            if len(chunk) < 4:\n                break\n            slen = struct.unpack(\">L\", chunk)[0]\n            chunk = self.connection.recv(slen)\n            while len(chunk) < slen:\n                chunk = chunk + self.connection.recv(slen - len(chunk))\n            obj = self.unPickle(chunk)\n            record = logging.makeLogRecord(obj)\n            self.handleLogRecord(record)\n\n    def unPickle(self, data):\n        return pickle.loads(data)\n\n    def handleLogRecord(self, record):\n        name = record.name\n        logger = logging.getLogger(name)\n        # N.B. EVERY record gets logged. This is because Logger.handle\n        # is normally called AFTER logger-level filtering. 
If you want\n        # to do filtering, do it at the client end to save wasting\n        # cycles and network bandwidth!\n        logger.handle(record)\n\n\nclass LogRecordSocketReceiver(socketserver.ThreadingTCPServer):\n    \"\"\"\n    Simple TCP socket-based logging receiver.\n\n    Taken from\n    `the logging cookbook <https://docs.python.org/3.8/howto/logging-cookbook.html>`_.\n    \"\"\"\n\n    allow_reuse_address = True\n\n    def __init__(self, host: str, port: int = 0):\n        super().__init__((host, port), LogRecordStreamHandler)\n        self.abort = False\n        self.timeout = 0.2\n\n    def serve_until_stopped(self):\n        import select\n\n        while not self.abort:\n            rd, _, _ = select.select([self.socket.fileno()], [], [], self.timeout)\n            if rd:\n                self.handle_request()\n\n\n_LOGGING_PREFIX: str = os.environ.get(EnvVarNames.LOGGING_PREFIX.value, \"\")\n_LOGGING_HOST: str = os.environ.get(EnvVarNames.LOGGING_HOST.value, \"localhost\")\n_LOGGING_PORT: Optional[int] = _parse_optional_int(\n    os.environ.get(EnvVarNames.LOGGING_PORT.value, None)\n)\n_LOGGING_SERVER: Optional[LogRecordSocketReceiver] = None\n_LOGGING_SERVER_THREAD: Optional[threading.Thread] = None\n\n\nclass RichHandler(logging.Handler):\n    \"\"\"\n    Adapted from\n    https://github.com/Textualize/rich/blob/master/rich/logging.py\n    \"\"\"\n\n    KEYWORDS: ClassVar[Optional[List[str]]] = [\n        \"GET\",\n        \"POST\",\n        \"HEAD\",\n        \"PUT\",\n        \"DELETE\",\n        \"OPTIONS\",\n        \"TRACE\",\n        \"PATCH\",\n    ]\n\n    def __init__(\n        self,\n        level: Union[int, str] = logging.NOTSET,\n        console: Optional[Console] = None,\n        *,\n        markup: bool = False,\n        log_time_format: Union[str, Callable[[datetime], str]] = \"[%x %X]\",\n        keywords: Optional[List[str]] = None,\n        show_time: bool = True,\n        show_level: bool = True,\n        show_path: bool = 
True,\n    ) -> None:\n        super().__init__(level=level)\n        self.console = console or rich.get_console()\n        self.highlighter = NullHighlighter()\n        self.time_format = log_time_format\n        self.markup = markup\n        self.keywords = keywords or self.KEYWORDS\n        self.show_time = show_time\n        self.show_level = show_level\n        self.show_path = show_path\n\n    def emit(self, record: logging.LogRecord) -> None:\n        if isinstance(record.msg, (Syntax, Table)):\n            self.console.print(Padding(record.msg, (1, 0, 1, 1)))\n        elif hasattr(record.msg, \"__rich__\") or hasattr(record.msg, \"__rich_console__\"):\n            self.console.print(record.msg)\n        else:\n            message = self.format(record)\n            message_renderable = self.render_message(record, message)\n            log_renderable = self.render(record=record, message_renderable=message_renderable)\n            try:\n                self.console.print(log_renderable)\n            except Exception:\n                self.handleError(record)\n\n    def render_message(self, record: logging.LogRecord, message: str) -> ConsoleRenderable:\n        use_markup = getattr(record, \"markup\", self.markup)\n        message_text = Text.from_markup(message) if use_markup else Text(message)\n        if self.show_path and record.exc_info is None:\n            message_text.end = \" \"\n\n        highlighter = getattr(record, \"highlighter\", self.highlighter)\n        if highlighter:\n            message_text = highlighter(message_text)\n\n        if self.keywords is None:\n            self.keywords = self.KEYWORDS\n\n        if self.keywords:\n            message_text.highlight_words(self.keywords, \"logging.keyword\")\n\n        return message_text\n\n    def get_time_text(self, record: logging.LogRecord) -> Text:\n        log_time = datetime.fromtimestamp(record.created)\n        time_str: str\n        if callable(self.time_format):\n            time_str 
= self.time_format(log_time)\n        else:\n            time_str = log_time.strftime(self.time_format)\n        return Text(time_str, style=\"log.time\", end=\" \")\n\n    def get_level_text(self, record: logging.LogRecord) -> Text:\n        level_name = record.levelname\n        level_text = Text.styled(level_name.ljust(8), f\"logging.level.{level_name.lower()}\")\n        level_text.style = \"log.level\"\n        level_text.end = \" \"\n        return level_text\n\n    def get_path_text(self, record: logging.LogRecord, length_so_far: int) -> Text:\n        path = Path(record.pathname)\n        for package_root in sys.path:\n            try:\n                path = path.relative_to(Path(package_root))\n                break\n            except ValueError:\n                continue\n        text = f\"{path}:{record.lineno}\"\n        length_after_wrap = length_so_far % self.console.width\n        return Text(\n            text.rjust(self.console.width - length_after_wrap - 3),\n            style=\"log.path\",\n        )\n\n    def render(\n        self,\n        *,\n        record: logging.LogRecord,\n        message_renderable: ConsoleRenderable,\n    ) -> ConsoleRenderable:\n        components: List[ConsoleRenderable] = []\n        if self.show_time:\n            components.append(self.get_time_text(record))\n        if self.show_level:\n            components.append(self.get_level_text(record))\n        components.append(message_renderable)\n        if self.show_path and record.exc_info is None:\n            try:\n                length_so_far = sum(len(x) for x in components)  # type: ignore\n            except TypeError:\n                pass\n            else:\n                components.append(self.get_path_text(record, length_so_far))\n\n        return Group(*components)\n\n\ndef get_handler(\n    level: int,\n    stderr: bool = False,\n    enable_markup: bool = False,\n    show_time: bool = True,\n    show_level: bool = True,\n    show_path: bool = 
True,\n) -> logging.Handler:\n    console = Console(\n        color_system=\"auto\" if not FILE_FRIENDLY_LOGGING else None,\n        stderr=stderr,\n        width=TANGO_CONSOLE_WIDTH,\n        soft_wrap=True,\n    )\n    if TANGO_CONSOLE_WIDTH is None and not console.is_terminal:\n        console.width = 160\n    handler = RichHandler(\n        level=level,\n        console=console,\n        markup=enable_markup,\n        show_time=show_time,\n        show_level=show_level,\n        show_path=show_path,\n    )\n    return handler\n\n\ncli_logger = logging.getLogger(\"tango.__main__\")\n\"\"\"\nA logger that emits messages directly to stdout/stderr using\n`rich <https://github.com/Textualize/rich>`_'s\n:class:`~rich.console.Console` class.\n\nThis provides a convenient way for command-line apps to log pretty, styled messages\nusing the `markup style <https://rich.readthedocs.io/en/latest/markup.html>`_ provided by `rich`.\n\"\"\"\n\ncli_logger.propagate = False\ncli_logger.disabled = TANGO_CLI_LOGGER_ENABLED\n\n\ndef excepthook(exctype, value, traceback):\n    \"\"\"\n    Used to patch `sys.excepthook` in order to log exceptions.\n    \"\"\"\n    log_exc_info(exctype, value, traceback)\n\n\ndef log_exception(exc: Optional[BaseException] = None, logger: Optional[logging.Logger] = None):\n    if exc is None:\n        et, ev, tb = sys.exc_info()\n        log_exc_info(et, ev, tb, logger=logger)\n    else:\n        log_exc_info(exc.__class__, exc, exc.__traceback__, logger=logger)\n\n\ndef log_exc_info(exctype, value, traceback, logger: Optional[logging.Logger] = None):\n    global _EXCEPTIONS_LOGGED\n    if value not in _EXCEPTIONS_LOGGED:\n        _EXCEPTIONS_LOGGED.append(value)\n        logger = logger or logging.getLogger()\n        if isinstance(value, CliRunError):\n            msg = str(value)\n            if msg:\n                cli_logger.error(msg)\n        elif isinstance(value, (KeyboardInterrupt, CancellationError)):\n            logger.error(\"%s: %s\", 
exctype.__name__, value)\n        else:\n            logger.error(\n                \"Uncaught exception\",\n                exc_info=(exctype, value, traceback),\n                extra={\"highlighter\": rich.highlighter.ReprHighlighter()},\n            )\n\n\ndef initialize_logging(\n    *,\n    log_level: Optional[str] = None,\n    enable_cli_logs: Optional[bool] = None,\n    file_friendly_logging: Optional[bool] = None,\n):\n    \"\"\"\n    Initialize logging, which includes setting the global log level, format, and configuring\n    handlers.\n\n    .. tip::\n        This should be called as early in your script as possible.\n\n    .. tip::\n        You should also call :func:`teardown_logging()` at the end of your script.\n\n    .. tip::\n        For worker threads/processes, use :func:`initialize_worker_logging()` instead.\n\n    :param log_level:\n        Can be one of \"debug\", \"info\", \"warning\", \"error\". Defaults to the value\n        of :data:`TANGO_LOG_LEVEL`, if set, or \"warning\".\n    :param enable_cli_logs:\n        Set to ``True`` to enable messages from the :data:`cli_logger`.\n    :param file_friendly_logging:\n        Enable or disable file friendly logging. 
Defaults to the value of :data:`FILE_FRIENDLY_LOGGING`.\n\n    \"\"\"\n    import multiprocessing as mp\n\n    is_main_process: bool\n    if hasattr(mp, \"parent_process\"):  # python 3.8 or greater\n        is_main_process = mp.parent_process() is None  # type: ignore\n    else:\n        is_main_process = mp.current_process().name == \"MainProcess\"\n\n    _initialize_logging(\n        log_level=log_level,\n        enable_cli_logs=enable_cli_logs,\n        file_friendly_logging=file_friendly_logging,\n        main_process=is_main_process,\n    )\n\n\ndef initialize_worker_logging(worker_rank: Optional[int] = None):\n    \"\"\"\n    Initialize logging in a worker thread/process.\n\n    :param worker_rank:\n        The rank/ID of the worker.\n\n    \"\"\"\n    if worker_rank is not None:\n        if worker_rank != -1:\n            prefix = f\"rank {worker_rank}\"\n        else:\n            prefix = None\n    else:\n        prefix = None\n    return initialize_prefix_logging(prefix=prefix, main_process=False)\n\n\ndef initialize_prefix_logging(\n    *, log_level: Optional[str] = None, prefix: Optional[str] = None, main_process: bool = False\n):\n    \"\"\"\n    Initialize logging with a prefix.\n\n    :param log_level:\n        Can be one of \"debug\", \"info\", \"warning\", \"error\". 
Defaults to the value\n        of :data:`TANGO_LOG_LEVEL`, if set, or \"warning\".\n    :param prefix:\n        The string prefix to add to the log message.\n    :param main_process:\n        Whether this is the main process (``True``) or a worker process (``False``).\n    \"\"\"\n    return _initialize_logging(log_level=log_level, prefix=prefix, main_process=main_process)\n\n\ndef _initialize_logging(\n    *,\n    log_level: Optional[str] = None,\n    enable_cli_logs: Optional[bool] = None,\n    file_friendly_logging: Optional[bool] = None,\n    prefix: Optional[str] = None,\n    main_process: bool = True,\n):\n    global FILE_FRIENDLY_LOGGING, TANGO_LOG_LEVEL, TANGO_CLI_LOGGER_ENABLED\n    global _LOGGING_HOST, _LOGGING_PORT, _LOGGING_SERVER, _LOGGING_SERVER_THREAD, _LOGGING_PREFIX\n\n    if log_level is None:\n        log_level = TANGO_LOG_LEVEL\n    if log_level is None:\n        log_level = \"warning\"\n    if file_friendly_logging is None:\n        file_friendly_logging = FILE_FRIENDLY_LOGGING\n    if enable_cli_logs is None:\n        enable_cli_logs = TANGO_CLI_LOGGER_ENABLED\n    if prefix:\n        prefix = _LOGGING_PREFIX + \" \" + prefix if _LOGGING_PREFIX else prefix\n    else:\n        prefix = _LOGGING_PREFIX\n\n    level = logging._nameToLevel[log_level.upper()]\n\n    # Update global flags and corresponding environment variables, if necessary,\n    # so that child processes can read the environment variables to determine the right\n    # settings.\n    TANGO_LOG_LEVEL = log_level\n    os.environ[EnvVarNames.LOG_LEVEL.value] = log_level\n    if file_friendly_logging is not None:\n        FILE_FRIENDLY_LOGGING = file_friendly_logging\n        os.environ[EnvVarNames.FILE_FRIENDLY_LOGGING.value] = str(file_friendly_logging).lower()\n    if enable_cli_logs is not None:\n        TANGO_CLI_LOGGER_ENABLED = enable_cli_logs\n        os.environ[EnvVarNames.CLI_LOGGER_ENABLED.value] = str(enable_cli_logs).lower()\n\n    from .tqdm import logger as tqdm_logger\n\n    # Handle special cases for 
specific loggers:\n    # These loggers emit too many messages, so we tell them to be quiet unless they have something\n    # important to say.\n    for loud_logger in {\"filelock\", \"sqlitedict\"}:\n        logging.getLogger(loud_logger).setLevel(max(level, logging.WARNING))\n    # We always want to see all CLI messages if we're running from the command line, and none otherwise.\n    cli_logger.setLevel(logging.DEBUG)\n    cli_logger.disabled = not enable_cli_logs\n    # We also want to enable the tqdm logger so that the progress bar lines end up in the log file.\n    tqdm_logger.setLevel(logging.DEBUG)\n\n    root_logger = logging.getLogger()\n    root_logger.setLevel(level)\n    root_logger.handlers.clear()\n\n    if main_process:\n        # Create stdout and stderr handlers so that we can route DEBUG and INFO\n        # messages to stdout, and WARNING and ERROR messages to stderr.\n        stdout_handler = get_handler(level)\n        stdout_handler.addFilter(LevelFilter(logging.INFO))\n        stderr_handler = get_handler(max(level, logging.WARNING), stderr=True)\n        stderr_handler.addFilter(LevelFilter(logging.CRITICAL, min_level=logging.WARNING))\n        root_logger.addHandler(stdout_handler)\n        root_logger.addHandler(stderr_handler)\n\n        # Configure cli_logger so that if log level <= INFO, it will behave\n        # like a regular logger, otherwise it prints directly to stdout.\n        cli_logger.handlers.clear()\n        if enable_cli_logs:\n            for handler_level in (logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR):\n                cli_handler = get_handler(\n                    handler_level,\n                    stderr=handler_level >= logging.WARNING,\n                    enable_markup=True,\n                    show_time=level <= handler_level,\n                    show_level=(level <= handler_level) or handler_level >= logging.WARNING,\n                    show_path=level <= handler_level,\n                )\n     
           cli_handler.addFilter(LevelFilter(handler_level))\n                cli_logger.addHandler(cli_handler)\n\n        # Add prefix.\n        if prefix:\n            for logger in (root_logger, cli_logger, tqdm_logger):\n                for handler in logger.handlers:\n                    handler.addFilter(PrefixLogFilter(prefix))\n\n        # Main process: set formatter and handlers, initialize logging socket and server.\n        # Set up logging socket to emit log records from worker processes/threads.\n        # Inspired by:\n        # https://docs.python.org/3.8/howto/logging-cookbook.html#sending-and-receiving-logging-events-across-a-network\n        _LOGGING_SERVER = LogRecordSocketReceiver(_LOGGING_HOST, 0)\n        _LOGGING_PORT = _LOGGING_SERVER.server_address[1]\n        os.environ[EnvVarNames.LOGGING_PORT.value] = str(_LOGGING_PORT)\n        _LOGGING_SERVER_THREAD = threading.Thread(\n            target=_LOGGING_SERVER.serve_until_stopped, daemon=True\n        )\n        _LOGGING_SERVER_THREAD.start()\n    else:\n        # Child process: set handler and level, no need to set formatting since only raw log records\n        # will be sent to the logging socket.\n        if _LOGGING_PORT is None:\n            raise ValueError(\n                \"missing logging socket configuration, \"\n                \"did you forget to call 'initialize_logging()' from the main process?\"\n            )\n        socket_handler = logging.handlers.SocketHandler(_LOGGING_HOST, _LOGGING_PORT)\n        if prefix:\n            socket_handler.addFilter(PrefixLogFilter(prefix))\n\n        for logger in (root_logger, cli_logger, tqdm_logger):\n            logger.handlers.clear()\n            logger.addHandler(socket_handler)\n\n    # Write uncaught exceptions to the logs.\n    sys.excepthook = excepthook\n\n    # Ensure warnings issued by the 'warnings' module will be redirected to the logging system.\n    logging.captureWarnings(True)\n\n\ndef teardown_logging():\n    
\"\"\"\n    Cleanup any logging fixtures created from :func:`initialize_logging()`. Should\n    be called at the end of your script.\n    \"\"\"\n    global _LOGGING_HOST, _LOGGING_PORT, _LOGGING_SERVER, _LOGGING_SERVER_THREAD\n\n    if _LOGGING_SERVER is not None:\n        _LOGGING_SERVER.abort = True\n\n    if _LOGGING_SERVER_THREAD is not None:\n        _LOGGING_SERVER_THREAD.join()\n        _LOGGING_SERVER_THREAD = None\n\n    if _LOGGING_SERVER is not None:\n        _LOGGING_SERVER = None\n\n    sys.excepthook = sys.__excepthook__  # type: ignore[assignment]\n\n\n@contextmanager\ndef insert_handlers(*handlers: logging.Handler) -> Generator[None, None, None]:\n    \"\"\"\n    A context manager that can be used to route logs to a specific handler temporarily.\n    \"\"\"\n    global _EXCEPTIONS_LOGGED\n\n    root_logger = logging.getLogger()\n\n    from .tqdm import logger as tqdm_logger\n\n    for logger in (root_logger, cli_logger, tqdm_logger):\n        for handler in handlers:\n            logger.addHandler(handler)\n\n    try:\n        yield None\n    except BaseException as e:\n        if not isinstance(\n            e, (CliRunError, KeyboardInterrupt, SigTermReceived)\n        ):  # don't need tracebacks for these\n            log_exception(e)\n            _EXCEPTIONS_LOGGED.append(e)\n        raise\n    finally:\n        for logger in (root_logger, cli_logger, tqdm_logger):\n            for handler in handlers:\n                logger.removeHandler(handler)\n\n\ndef file_handler(filepath: PathOrStr) -> ContextManager[None]:\n    \"\"\"\n    A context manager that can be used to route logs to a file by adding a\n    :class:`logging.FileHandler` to the root logger's handlers.\n\n    For example,\n\n    .. 
code-block::\n\n        from tango.common.logging import initialize_logging, file_handler, teardown_logging\n\n        initialize_logging(log_level=\"info\")\n\n        logger = logging.getLogger()\n        logger.info(\"Hi!\")\n\n        with file_handler(\"log.out\"):\n            logger.info(\"This message should also go into 'log.out'\")\n\n        teardown_logging()\n\n    \"\"\"\n    log_file = open(filepath, \"w\")\n    handlers: List[logging.Handler] = []\n    console = Console(\n        color_system=None,\n        file=log_file,\n        force_terminal=False,\n        width=TANGO_CONSOLE_WIDTH or 160,\n    )\n    for is_cli_handler in (True, False):\n        handler = RichHandler(\n            console=console,\n            markup=is_cli_handler,\n        )\n        handler.addFilter(CliFilter(filter_out=not is_cli_handler))\n        handlers.append(handler)\n    return insert_handlers(*handlers)\n"
  },
  {
    "path": "tango/common/params.py",
    "content": "import copy\nimport json\nimport logging\nimport os\nimport zlib\nfrom collections import OrderedDict\nfrom collections.abc import MutableMapping\nfrom itertools import chain\nfrom pathlib import Path\nfrom typing import Any, Dict, Iterable, List, Optional, Set, TypeVar, Union\n\nimport yaml\nfrom rjsonnet import evaluate_file, evaluate_snippet\n\nfrom .aliases import PathOrStr\nfrom .exceptions import ConfigurationError\nfrom .util import could_be_class_name\n\nlogger = logging.getLogger(__name__)\n\n\ndef infer_and_cast(value: Any):\n    \"\"\"\n    In some cases we'll be feeding params dicts to functions we don't own;\n    for example, PyTorch optimizers. In that case we can't use ``pop_int``\n    or similar to force casts (which means you can't specify ``int`` parameters\n    using environment variables). This function takes something that looks JSON-like\n    and recursively casts things that look like (bool, int, float) to (bool, int, float).\n    \"\"\"\n\n    if isinstance(value, (int, float, bool)):\n        # Already one of our desired types, so leave as is.\n        return value\n    elif isinstance(value, list):\n        # Recursively call on each list element.\n        return [infer_and_cast(item) for item in value]\n    elif isinstance(value, dict):\n        # Recursively call on each dict value.\n        return {key: infer_and_cast(item) for key, item in value.items()}\n    elif isinstance(value, str):\n        # If it looks like a bool, make it a bool.\n        if value.lower() == \"true\":\n            return True\n        elif value.lower() == \"false\":\n            return False\n        else:\n            # See if it could be an int.\n            try:\n                return int(value)\n            except ValueError:\n                pass\n            # See if it could be a float.\n            try:\n                return float(value)\n            except ValueError:\n                # Just return it as a string.\n                
return value\n    else:\n        raise ValueError(f\"cannot infer type of {value}\")\n\n\ndef _is_encodable(value: str) -> bool:\n    \"\"\"\n    We need to filter out environment variables that can't\n    be unicode-encoded to avoid a \"surrogates not allowed\"\n    error in jsonnet.\n    \"\"\"\n    # Idiomatically you'd like to not check the != b\"\"\n    # but mypy doesn't like that.\n    return (value == \"\") or (value.encode(\"utf-8\", \"ignore\") != b\"\")\n\n\ndef _environment_variables() -> Dict[str, str]:\n    \"\"\"\n    Wraps ``os.environ`` to filter out non-encodable values.\n    \"\"\"\n    return {key: value for key, value in os.environ.items() if _is_encodable(value)}\n\n\nT = TypeVar(\"T\", dict, list)\n\n\ndef with_overrides(original: T, overrides_dict: Dict[str, Any], prefix: str = \"\") -> T:\n    merged: T\n    keys: Union[Iterable[str], Iterable[int]]\n    if isinstance(original, list):\n        merged = [None] * len(original)\n        keys = range(len(original))\n    elif isinstance(original, dict):\n        merged = {}\n        keys = chain(\n            original.keys(), (k for k in overrides_dict if \".\" not in k and k not in original)\n        )\n    else:\n        if prefix:\n            raise ValueError(\n                f\"overrides for '{prefix[:-1]}.*' expected list or dict in original, \"\n                f\"found {type(original)} instead\"\n            )\n        else:\n            raise ValueError(f\"expected list or dict, found {type(original)} instead\")\n\n    used_override_keys: Set[str] = set()\n    for key in keys:\n        if str(key) in overrides_dict:\n            merged[key] = copy.deepcopy(overrides_dict[str(key)])\n            used_override_keys.add(str(key))\n        else:\n            overrides_subdict = {}\n            for o_key in overrides_dict:\n                if o_key.startswith(f\"{key}.\"):\n                    overrides_subdict[o_key[len(f\"{key}.\") :]] = overrides_dict[o_key]\n                    
used_override_keys.add(o_key)\n            if overrides_subdict:\n                merged[key] = with_overrides(\n                    original[key], overrides_subdict, prefix=prefix + f\"{key}.\"\n                )\n            else:\n                merged[key] = copy.deepcopy(original[key])\n\n    unused_override_keys = [prefix + key for key in set(overrides_dict.keys()) - used_override_keys]\n    if unused_override_keys:\n        raise ValueError(f\"overrides dict contains unused keys: {unused_override_keys}\")\n\n    return merged\n\n\ndef parse_overrides(\n    serialized_overrides: str, ext_vars: Optional[Dict[str, Any]] = None\n) -> Dict[str, Any]:\n    if serialized_overrides:\n        ext_vars = {**_environment_variables(), **(ext_vars or {})}\n\n        return json.loads(evaluate_snippet(\"\", serialized_overrides, ext_vars=ext_vars))\n    else:\n        return {}\n\n\ndef _is_dict_free(obj: Any) -> bool:\n    \"\"\"\n    Returns False if obj is a dict, or if it's a list containing (at any depth) a dict.\n    \"\"\"\n    if isinstance(obj, dict):\n        return False\n    elif isinstance(obj, list):\n        return all(_is_dict_free(item) for item in obj)\n    else:\n        return True\n\n\ndef pop_choice(\n    params: Dict[str, Any],\n    key: str,\n    choices: List[Any],\n    default_to_first_choice: bool = False,\n    history: str = \"?.\",\n    allow_class_names: bool = True,\n) -> Any:\n    \"\"\"\n    Performs the same function as ``Params.pop_choice``, but is required in order to deal with\n    places where the Params object is not welcome, such as inside Keras layers.  See the docstring\n    of that method for more detail on how this function works.\n\n    This method adds a ``history`` parameter, on the off chance that you know it, so that we can\n    reproduce ``Params.pop_choice`` exactly.  
We default to using \"?.\" if you don't know the\n    history, so you'll have to fix that in the log if you want to actually recover the logged\n    parameters.\n    \"\"\"\n    value = Params(params, history).pop_choice(\n        key, choices, default_to_first_choice, allow_class_names=allow_class_names\n    )\n    return value\n\n\ndef _replace_none(params: Any) -> Any:\n    if isinstance(params, str) and params == \"None\":\n        return None\n    elif isinstance(params, (dict, Params)):\n        if isinstance(params, Params):\n            params = params.as_dict(quiet=True)\n        for key, value in params.items():\n            params[key] = _replace_none(value)\n        return params\n    elif isinstance(params, list):\n        return [_replace_none(value) for value in params]\n    return params\n\n\ndef remove_keys_from_params(params: \"Params\", keys: List[str] = [\"pretrained_file\", \"initializer\"]):\n    if isinstance(params, Params):  # The model could possibly be a string, for example.\n        param_keys = params.keys()\n        for key in keys:\n            if key in param_keys:\n                del params[key]\n        for value in params.values():\n            if isinstance(value, Params):\n                remove_keys_from_params(value, keys)\n            elif isinstance(value, list):\n                for item in value:\n                    if isinstance(item, Params):\n                        remove_keys_from_params(item, keys)\n\n\nclass Params(MutableMapping):\n    \"\"\"\n    A :class:`~collections.abc.MutableMapping` that represents a parameter dictionary with a history,\n    and contains other functionality around parameter passing and validation for AI2 Tango.\n\n    There are currently two benefits of a ``Params`` object over a plain dictionary for parameter\n    passing:\n\n    1. 
We handle a few kinds of parameter validation, including making sure that parameters\n       representing discrete choices actually have acceptable values, and making sure no extra\n       parameters are passed.\n    2. We log all parameter reads, including default values.  This gives a more complete\n       specification of the actual parameters used than is given in a JSON file, because\n       those may not specify what default values were used, whereas this will log them.\n\n    .. important::\n        The convention for using a ``Params`` object in Tango is that you will consume the parameters\n        as you read them, so that there are none left when you've read everything you expect.  This\n        lets us easily validate that you didn't pass in any ``extra`` parameters, just by making sure\n        that the parameter dictionary is empty.  You should do this when you're done handling\n        parameters, by calling :meth:`Params.assert_empty()`.\n    \"\"\"\n\n    # This allows us to check for the presence of \"None\" as a default argument,\n    # which we require because we make a distinction between passing a value of \"None\"\n    # and passing no value to the default parameter of \"pop\".\n    DEFAULT = object()\n\n    def __init__(self, params: \"MutableMapping[str, Any]\", history: str = \"\") -> None:\n        if isinstance(params, Params):\n            self.params: MutableMapping = params.params\n        else:\n            self.params = _replace_none(params)\n        self.history = history\n\n    def pop(self, key: str, default: Any = DEFAULT, keep_as_dict: bool = False) -> Any:\n        \"\"\"\n        Performs the functionality associated with ``dict.pop(key)``, along with checking for\n        returned dictionaries, replacing them with ``Params`` objects with an updated history\n        (unless ``keep_as_dict`` is ``True``, in which case we leave them as dictionaries).\n\n        If ``key`` is not present in the dictionary, and no default was specified, we 
raise a\n        :class:`~tango.common.exceptions.ConfigurationError`, instead of the typical ``KeyError``.\n        \"\"\"\n        if default is self.DEFAULT:\n            try:\n                value = self.params.pop(key)\n            except KeyError:\n                msg = f'key \"{key}\" is required'\n                if self.history:\n                    msg += f' at location \"{self.history}\"'\n                raise ConfigurationError(msg)\n        else:\n            value = self.params.pop(key, default)\n\n        logger.debug(f\"{self.history}{key} = {value}\")\n        if keep_as_dict or _is_dict_free(value):\n            return value\n        else:\n            return self._check_is_dict(key, value)\n\n    def pop_int(self, key: str, default: Any = DEFAULT) -> Optional[int]:\n        \"\"\"\n        Performs a pop and coerces to an int.\n        \"\"\"\n        value = self.pop(key, default)\n        if value is None:\n            return None\n        else:\n            return int(value)\n\n    def pop_float(self, key: str, default: Any = DEFAULT) -> Optional[float]:\n        \"\"\"\n        Performs a pop and coerces to a float.\n        \"\"\"\n        value = self.pop(key, default)\n        if value is None:\n            return None\n        else:\n            return float(value)\n\n    def pop_bool(self, key: str, default: Any = DEFAULT) -> Optional[bool]:\n        \"\"\"\n        Performs a pop and coerces to a bool.\n        \"\"\"\n        value = self.pop(key, default)\n        if value is None:\n            return None\n        elif isinstance(value, bool):\n            return value\n        elif value == \"true\":\n            return True\n        elif value == \"false\":\n            return False\n        else:\n            # Use an f-string so that non-string values don't trigger a TypeError here.\n            raise ValueError(f\"Cannot convert variable to bool: {value}\")\n\n    def get(self, key: str, default: Any = DEFAULT):\n        \"\"\"\n        Performs the functionality associated with ``dict.get(key)`` but also checks 
for returned\n        dicts and returns a ``Params`` object in their place with an updated history.\n        \"\"\"\n        default = None if default is self.DEFAULT else default\n        value = self.params.get(key, default)\n        return self._check_is_dict(key, value)\n\n    def pop_choice(\n        self,\n        key: str,\n        choices: List[Any],\n        default_to_first_choice: bool = False,\n        allow_class_names: bool = True,\n    ) -> Any:\n        \"\"\"\n        Gets the value of ``key`` in the ``params`` dictionary, ensuring that the value is one of\n        the given choices. Note that this ``pops`` the key from params, modifying the dictionary,\n        consistent with how parameters are processed in this codebase.\n\n        :param key:\n            The key whose value to pop from the param dictionary.\n\n        :param choices:\n            A list of valid options for values corresponding to ``key``.  For example, if you're\n            specifying the type of encoder to use for some part of your model, the choices might be\n            the list of encoder classes we know about and can instantiate.  If the value we find in\n            the param dictionary is not in ``choices``, we raise a\n            :class:`~tango.common.exceptions.ConfigurationError`, because\n            the user specified an invalid value in their parameter file.\n\n        :param default_to_first_choice:\n            If this is ``True``, we allow the ``key`` to not be present in the parameter\n            dictionary.  If the key is not present, we will return the first\n            choice in the ``choices`` list as the value.  
If this is ``False``, we raise a\n            :class:`~tango.common.exceptions.ConfigurationError`, because\n            specifying the ``key`` is required (e.g., you ``have`` to\n            specify your model class when running an experiment, but you can feel free to use\n            default settings for encoders if you want).\n\n        :param allow_class_names:\n            If this is ``True``, then we allow unknown choices that look like fully-qualified class names.\n            This is to allow e.g. specifying a model type as ``my_library.my_model.MyModel``\n            and importing it on the fly. Our check for \"looks like\" is extremely lenient\n            and consists of checking that the value contains a '.'.\n        \"\"\"\n        default = choices[0] if default_to_first_choice else self.DEFAULT\n        value = self.pop(key, default)\n        ok_because_class_name = allow_class_names and could_be_class_name(value)\n        if value not in choices and not ok_because_class_name:\n            key_str = self.history + key\n            message = (\n                f\"'{value}' not in acceptable choices for {key_str}: {choices}. \"\n                \"You should either use the --include-package flag to make sure the correct module \"\n                \"is loaded, or use a fully qualified class name in your config file like \"\n                \"\"\"{\"model\": \"my_module.models.MyModel\"} to have it imported automatically.\"\"\"\n            )\n            raise ConfigurationError(message)\n        return value\n\n    def as_dict(self, quiet: bool = False, infer_type_and_cast: bool = False):\n        \"\"\"\n        Sometimes we need to just represent the parameters as a dict, for instance when we pass\n        them to PyTorch code.\n\n        :param quiet:\n            If ``True``, skip logging the parameters before returning them as a dict.\n\n        :param infer_type_and_cast:\n            If ``True``, we infer types and cast (e.g. 
things that look like floats to floats).\n        \"\"\"\n        if infer_type_and_cast:\n            params_as_dict = infer_and_cast(self.params)\n        else:\n            params_as_dict = self.params\n\n        if quiet:\n            return params_as_dict\n\n        def log_recursively(parameters, history):\n            for key, value in parameters.items():\n                if isinstance(value, dict):\n                    new_local_history = history + key + \".\"\n                    log_recursively(value, new_local_history)\n                else:\n                    logger.debug(f\"{history}{key} = {value}\")\n\n        log_recursively(self.params, self.history)\n        return params_as_dict\n\n    def as_flat_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Returns the parameters as a flat dictionary from keys to values.\n        Nested structure is collapsed with periods.\n        \"\"\"\n        flat_params = {}\n\n        def recurse(parameters, path):\n            for key, value in parameters.items():\n                newpath = path + [key]\n                if isinstance(value, dict):\n                    recurse(value, newpath)\n                else:\n                    flat_params[\".\".join(newpath)] = value\n\n        recurse(self.params, [])\n        return flat_params\n\n    def duplicate(self) -> \"Params\":\n        \"\"\"\n        Uses ``copy.deepcopy()`` to create a duplicate (but fully distinct)\n        copy of these Params.\n        \"\"\"\n        return copy.deepcopy(self)\n\n    def assert_empty(self, name: str):\n        \"\"\"\n        Raises a :class:`~tango.common.exceptions.ConfigurationError` if ``self.params`` is not empty.\n        We take ``name`` as an argument so that the error message gives some idea of where an error\n        happened, if there was one. 
For example, ``name`` could be the name of the ``calling`` class\n        that got extra parameters (if there are any).\n        \"\"\"\n        if self.params:\n            raise ConfigurationError(\"Extra parameters passed to {}: {}\".format(name, self.params))\n\n    def __getitem__(self, key):\n        if key in self.params:\n            return self._check_is_dict(key, self.params[key])\n        else:\n            raise KeyError(str(key))\n\n    def __setitem__(self, key, value):\n        self.params[key] = value\n\n    def __delitem__(self, key):\n        del self.params[key]\n\n    def __iter__(self):\n        return iter(self.params)\n\n    def __len__(self):\n        return len(self.params)\n\n    def _check_is_dict(self, new_history, value):\n        if isinstance(value, dict):\n            new_history = self.history + new_history + \".\"\n            return Params(value, history=new_history)\n        if isinstance(value, list):\n            value = [self._check_is_dict(f\"{new_history}.{i}\", v) for i, v in enumerate(value)]\n        return value\n\n    @classmethod\n    def from_file(\n        cls,\n        params_file: PathOrStr,\n        params_overrides: Union[str, Dict[str, Any]] = \"\",\n        ext_vars: Optional[dict] = None,\n    ) -> \"Params\":\n        \"\"\"\n        Load a ``Params`` object from a configuration file.\n\n        :param params_file:\n            The path to the configuration file to load. Can be JSON, Jsonnet, or YAML.\n\n        :param params_overrides:\n            A dict of overrides that can be applied to final object.\n            e.g. ``{\"model.embedding_dim\": 10}`` will change the value of \"embedding_dim\"\n            within the \"model\" object of the config to 10. 
If you wanted to override the entire\n            \"model\" object of the config, you could do ``{\"model\": {\"type\": \"other_type\", ...}}``.\n\n        :param ext_vars:\n            Our config files are Jsonnet, which allows specifying external variables\n            for later substitution. Typically we substitute these using environment\n            variables; however, you can also specify them here, in which case they\n            take priority over environment variables.\n            e.g. ``{\"HOME_DIR\": \"/Users/allennlp/home\"}``\n        \"\"\"\n        if ext_vars is None:\n            ext_vars = {}\n\n        # redirect to cache, if necessary\n        from cached_path import cached_path\n\n        params_file: Path = Path(cached_path(params_file))\n        if not params_file.is_file():\n            raise FileNotFoundError(params_file)\n\n        file_dict: Dict[str, Any]\n        if params_file.suffix in {\".yml\", \".yaml\"}:\n            with open(params_file) as f:\n                file_dict = yaml.safe_load(f)\n        else:\n            # Fall back to JSON/Jsonnet.\n            ext_vars = {**_environment_variables(), **ext_vars}\n            json_str = evaluate_file(params_file.name, str(params_file.parent), ext_vars=ext_vars)\n            file_dict = json.loads(json_str)\n\n        if isinstance(params_overrides, dict):\n            params_overrides = json.dumps(params_overrides)\n        overrides_dict = parse_overrides(params_overrides, ext_vars=ext_vars)\n\n        if overrides_dict:\n            param_dict = with_overrides(file_dict, overrides_dict)\n        else:\n            param_dict = file_dict\n\n        return cls(param_dict)\n\n    def to_file(\n        self, params_file: PathOrStr, preference_orders: Optional[List[List[str]]] = None\n    ) -> None:\n        \"\"\"\n        Write the params to file.\n        \"\"\"\n        with open(params_file, \"w\") as handle:\n            json.dump(self.as_ordered_dict(preference_orders), handle, 
indent=4)\n\n    def as_ordered_dict(self, preference_orders: Optional[List[List[str]]] = None) -> OrderedDict:\n        \"\"\"\n        Returns an ``OrderedDict`` of ``Params`` from list of partial order preferences.\n\n        :param preference_orders:\n            ``preference_orders`` is list of partial preference orders. [\"A\", \"B\", \"C\"] means\n            \"A\" > \"B\" > \"C\". For multiple preference_orders first will be considered first.\n            Keys not found, will have last but alphabetical preference. Default Preferences:\n            ``[[\"dataset_reader\", \"iterator\", \"model\", \"train_data_path\", \"validation_data_path\",\n            \"test_data_path\", \"trainer\", \"vocabulary\"], [\"type\"]]``\n        \"\"\"\n        params_dict = self.as_dict(quiet=True)\n        if not preference_orders:\n            preference_orders = []\n            preference_orders.append([\"type\"])\n\n        def order_func(key):\n            # Makes a tuple to use for ordering.  The tuple is an index into each of the `preference_orders`,\n            # followed by the key itself.  
This gives us integer sorting if you have a key in one of the\n            # `preference_orders`, followed by alphabetical ordering if not.\n            order_tuple = [\n                order.index(key) if key in order else len(order) for order in preference_orders  # type: ignore\n            ]\n            return order_tuple + [key]\n\n        def order_dict(dictionary, order_func):\n            # Recursively orders dictionary according to scoring order_func\n            result = OrderedDict()\n            for key, val in sorted(dictionary.items(), key=lambda item: order_func(item[0])):\n                result[key] = order_dict(val, order_func) if isinstance(val, dict) else val\n            return result\n\n        return order_dict(params_dict, order_func)\n\n    def get_hash(self) -> str:\n        \"\"\"\n        Returns a hash code representing the current state of this ``Params`` object.  We don't\n        want to implement ``__hash__`` because that has deeper python implications (and this is a\n        mutable object), but this will give you a representation of the current state.\n        We use ``zlib.adler32`` instead of Python's builtin ``hash`` because the random seed for the\n        latter is reset on each new program invocation, as discussed here:\n        https://stackoverflow.com/questions/27954892/deterministic-hashing-in-python-3.\n        \"\"\"\n        dumped = json.dumps(self.params, sort_keys=True)\n        hashed = zlib.adler32(dumped.encode())\n        return str(hashed)\n\n    def __str__(self) -> str:\n        return f\"{self.history}Params({self.params})\"\n"
  },
  {
    "path": "tango/common/registrable.py",
    "content": "\"\"\"\n:class:`Registrable` is a \"mixin\" for endowing\nany base class with a named registry for its subclasses and a decorator\nfor registering them.\n\"\"\"\n\nimport importlib\nimport logging\nfrom collections import defaultdict\nfrom typing import (\n    Callable,\n    ClassVar,\n    DefaultDict,\n    Dict,\n    List,\n    Optional,\n    Set,\n    Tuple,\n    Type,\n    TypeVar,\n    cast,\n)\n\nfrom .exceptions import ConfigurationError, IntegrationMissingError, RegistryKeyError\nfrom .from_params import FromParams\nfrom .util import (\n    could_be_class_name,\n    find_integrations,\n    find_submodules,\n    import_module_and_submodules,\n)\n\nlogger = logging.getLogger(__name__)\n\n_T = TypeVar(\"_T\")\n_RegistrableT = TypeVar(\"_RegistrableT\", bound=\"Registrable\")\n\n_SubclassRegistry = Dict[str, Tuple[type, Optional[str]]]\n\n\nclass Registrable(FromParams):\n    \"\"\"\n    Any class that inherits from ``Registrable`` gains access to a named registry for its\n    subclasses. To register them, just decorate them with the classmethod\n    ``@BaseClass.register(name)``.\n\n    After which you can call ``BaseClass.list_available()`` to get the keys for the\n    registered subclasses, and ``BaseClass.by_name(name)`` to get the corresponding subclass.\n    Note that the registry stores the subclasses themselves; not class instances.\n    In most cases you would then call :meth:`~tango.common.from_params.FromParams.from_params()`\n    on the returned subclass.\n\n    You can specify a default by setting ``BaseClass.default_implementation``.\n    If it is set, it will be the first element of :meth:`list_available()`.\n\n    Note that if you use this class to implement a new ``Registrable`` abstract class,\n    you must ensure that all subclasses of the abstract class are loaded when the module is\n    loaded, because the subclasses register themselves in their respective files. 
You can\n    achieve this by having the abstract class and all subclasses in the ``__init__.py`` of the\n    module in which they reside (as this causes any import of either the abstract class or\n    a subclass to load all other subclasses and the abstract class).\n    \"\"\"\n\n    _registry: ClassVar[DefaultDict[type, _SubclassRegistry]] = defaultdict(dict)\n\n    default_implementation: Optional[str] = None\n\n    @classmethod\n    def register(\n        cls, name: str, constructor: Optional[str] = None, exist_ok: bool = False\n    ) -> Callable[[Type[_T]], Type[_T]]:\n        \"\"\"\n        Register a class under a particular name.\n\n        :param name:\n            The name to register the class under.\n        :param constructor:\n            The name of the method to use on the class to construct the object.  If this is given,\n            we will use this method (which must be a ``@classmethod``) instead of the default\n            constructor.\n        :param exist_ok:\n            If True, overwrites any existing models registered under ``name``. 
Else,\n            throws an error if a model is already registered under ``name``.\n\n        Examples\n        --------\n\n        To use this class, you would typically have a base class that inherits from ``Registrable``::\n\n            class Vocabulary(Registrable):\n                ...\n\n        Then, if you want to register a subclass, you decorate it like this::\n\n            @Vocabulary.register(\"my-vocabulary\")\n            class MyVocabulary(Vocabulary):\n                def __init__(self, param1: int, param2: str):\n                    ...\n\n        Registering a class like this will let you instantiate a class from a config file, where you\n        give ``\"type\": \"my-vocabulary\"``, and keys corresponding to the parameters of the ``__init__``\n        method (note that for this to work, those parameters must have type annotations).\n\n        If you want to have the instantiation from a config file call a method other than the\n        constructor, either because you have several different construction paths that could be\n        taken for the same object (as we do in ``Vocabulary``) or because you have logic you want to\n        happen before you get to the constructor (as we do in ``Embedding``), you can register a\n        specific ``@classmethod`` as the constructor to use, like this::\n\n            @Vocabulary.register(\"my-vocabulary-from-instances\", constructor=\"from_instances\")\n            @Vocabulary.register(\"my-vocabulary-from-files\", constructor=\"from_files\")\n            class MyVocabulary(Vocabulary):\n                def __init__(self, some_params):\n                    ...\n\n                @classmethod\n                def from_instances(cls, some_other_params) -> MyVocabulary:\n                    ...  # construct some_params from instances\n                    return cls(some_params)\n\n                @classmethod\n                def from_files(cls, still_other_params) -> MyVocabulary:\n                    ...  
# construct some_params from files\n                    return cls(some_params)\n        \"\"\"\n\n        if _cls_is_step(cls) and name == \"ref\":\n            raise ConfigurationError(\n                \"You cannot use the name 'ref' to register a step. This name is reserved.\"\n            )\n\n        registry = Registrable._registry[cls]\n\n        def add_subclass_to_registry(subclass: Type[_T]) -> Type[_T]:\n            # Add to registry, raise an error if key has already been used.\n            if name in registry:\n                already_in_use_for = registry[name][0]\n                if already_in_use_for.__module__ == \"__main__\":\n                    # Sometimes the same class shows up under module.submodule.Class and __main__.Class, and we\n                    # don't want to make a fuss in that case. We prefer the class without __main__, so we go\n                    # ahead and overwrite the entry.\n                    pass\n                elif subclass.__module__ == \"__main__\":\n                    # We don't want to overwrite the entry because the new one comes from the __main__ module.\n                    return already_in_use_for\n                elif exist_ok:\n                    message = (\n                        f\"Registering {_fullname(subclass)} as a {_fullname(cls)} under the name {name} \"\n                        f\"overwrites existing entry {_fullname(already_in_use_for)}, which is fine because \"\n                        \"you said exist_ok=True.\"\n                    )\n                    logger.info(message)\n                else:\n                    message = (\n                        f\"Attempting to register {_fullname(subclass)} as a {_fullname(cls)} under the name \"\n                        f\"'{name}' failed. 
{_fullname(already_in_use_for)} is already registered under that name.\"\n                    )\n                    raise ConfigurationError(message)\n            registry[name] = (subclass, constructor)\n            return subclass\n\n        return add_subclass_to_registry\n\n    @classmethod\n    def by_name(cls: Type[_RegistrableT], name: str) -> Callable[..., _RegistrableT]:\n        \"\"\"\n        Returns a callable function that constructs an instance of the registered class.  Because\n        you can register particular functions as constructors for specific names, this isn't\n        necessarily the ``__init__`` method of some class.\n        \"\"\"\n        logger.debug(f\"instantiating registered subclass {name} of {cls}\")\n        subclass, constructor = cls.resolve_class_name(name)\n        if not constructor:\n            return cast(Type[_RegistrableT], subclass)\n        else:\n            return cast(Callable[..., _RegistrableT], getattr(subclass, constructor))\n\n    @classmethod\n    def search_modules(cls: Type[_RegistrableT], name: str):\n        \"\"\"\n        Search for and import modules where ``name`` might be registered.\n        \"\"\"\n        if (\n            could_be_class_name(name)\n            or name in Registrable._registry[cls]\n            or (_cls_is_step(cls) and name == \"ref\")\n        ):\n            return None\n\n        def try_import(module, recursive: bool = True):\n            try:\n                import_module_and_submodules(module, recursive=recursive)\n            except IntegrationMissingError:\n                pass\n            except ImportError as e:\n                if e.name != module:\n                    raise\n\n        integrations = {m.split(\".\")[-1]: m for m in find_integrations()}\n        integrations_imported: Set[str] = set()\n        if name in integrations:\n            try_import(integrations[name], recursive=False)\n            integrations_imported.add(name)\n            if name in 
Registrable._registry[cls]:\n                return None\n\n        if \"::\" in name:\n            # Try to guess the integration that it comes from.\n            maybe_integration = name.split(\"::\")[0]\n            if maybe_integration in integrations:\n                try_import(integrations[maybe_integration], recursive=False)\n                integrations_imported.add(maybe_integration)\n                if name in Registrable._registry[cls]:\n                    return None\n\n        # Check Python files and modules in the current directory.\n        from glob import glob\n        from pathlib import Path\n\n        for pyfile in glob(\"*.py\"):\n            module = str(Path(pyfile).with_suffix(\"\"))\n            if module == \"setup\":\n                continue\n            try:\n                try_import(module)\n                if name in Registrable._registry[cls]:\n                    return None\n            except:  # noqa: E722\n                continue\n        for pyinit in glob(\"**/__init__.py\"):\n            module = str(Path(pyinit).parent)\n            if module == \"tango\" or module.startswith(\"test\"):\n                continue\n            try:\n                try_import(module)\n                if name in Registrable._registry[cls]:\n                    return None\n            except:  # noqa: E722\n                continue\n\n        # Search all other modules in Tango.\n        for module in find_submodules(exclude={\"tango.integrations*\"}, recursive=False):\n            try_import(module)\n            if name in Registrable._registry[cls]:\n                return None\n\n        # Try importing all other integrations.\n        for integration_name, module in integrations.items():\n            if integration_name not in integrations_imported:\n                try_import(module, recursive=False)\n                integrations_imported.add(integration_name)\n                if name in Registrable._registry[cls]:\n                  
  return None\n\n    @classmethod\n    def resolve_class_name(\n        cls: Type[_RegistrableT],\n        name: str,\n        search_modules: bool = True,\n    ) -> Tuple[Type[_RegistrableT], Optional[str]]:\n        \"\"\"\n        Returns the subclass that corresponds to the given ``name``, along with the name of the\n        method that was registered as a constructor for that ``name``, if any.\n\n        This method also allows ``name`` to be a fully-specified module name, instead of a name that\n        was already added to the ``Registry``.  In that case, you cannot use a separate function as\n        a constructor (as you need to call ``cls.register()`` in order to tell us what separate\n        function to use).\n\n        If the ``name`` given is not in the registry and ``search_modules`` is ``True``,\n        it will search for and import modules where the class might be defined according to\n        :meth:`search_modules()`.\n        \"\"\"\n        if name in Registrable._registry[cls]:\n            subclass, constructor = Registrable._registry[cls][name]\n            return subclass, constructor\n        elif could_be_class_name(name):\n            # This might be a fully qualified class name, so we'll try importing its \"module\"\n            # and finding it there.\n            parts = name.split(\".\")\n            submodule = \".\".join(parts[:-1])\n            class_name = parts[-1]\n\n            try:\n                module = importlib.import_module(submodule)\n            except ModuleNotFoundError:\n                raise ConfigurationError(\n                    f\"tried to interpret {name} as a path to a class \"\n                    f\"but unable to import module {submodule}\"\n                )\n\n            try:\n                subclass = getattr(module, class_name)\n                constructor = None\n                return subclass, constructor\n            except AttributeError:\n                raise ConfigurationError(\n             
       f\"tried to interpret {name} as a path to a class \"\n                    f\"but unable to find class {class_name} in {submodule}\"\n                )\n        else:\n            # is not a qualified class name\n            if search_modules:\n                cls.search_modules(name)\n                return cls.resolve_class_name(name, search_modules=False)\n\n            available = cls.list_available()\n            suggestion = _get_suggestion(name, available)\n            raise RegistryKeyError(\n                (\n                    f\"'{name}' is not a registered name for '{cls.__name__}'\"\n                    + (\". \" if not suggestion else f\", did you mean '{suggestion}'? \")\n                )\n                + \"If your registered class comes from custom code, you'll need to import \"\n                \"the corresponding modules. If you're using Tango or AllenNLP from the command-line, \"\n                \"this is done by using the '--include-package' flag, or by specifying your imports \"\n                \"in a 'tango.yml' settings file. \"\n                \"Alternatively, you can specify your choices \"\n                \"\"\"using fully-qualified paths, e.g. 
{\"model\": \"my_module.models.MyModel\"} \"\"\"\n                \"in which case they will be automatically imported correctly.\"\n            )\n\n    @classmethod\n    def list_available(cls) -> List[str]:\n        \"\"\"List default first if it exists\"\"\"\n        keys = list(Registrable._registry[cls].keys())\n        default = cls.default_implementation\n\n        if default is None:\n            return keys\n\n        if default not in keys:\n            cls.search_modules(default)\n\n        keys = list(Registrable._registry[cls].keys())\n        if default not in keys:\n            raise ConfigurationError(f\"Default implementation '{default}' is not registered\")\n        else:\n            return [default] + [k for k in keys if k != default]\n\n\nclass RegistrableFunction(Registrable):\n    \"\"\"\n    A registrable class mimicking a `Callable`. This is to allow\n    referring to functions by their name in tango configurations.\n    \"\"\"\n\n    WRAPPED_FUNC: ClassVar[Callable]\n\n    def __call__(self, *args, **kwargs):\n        return self.__class__.WRAPPED_FUNC(*args, **kwargs)\n\n\ndef make_registrable(name: Optional[str] = None, *, exist_ok: bool = False):\n    \"\"\"\n    A decorator to create a :class:`RegistrableFunction` from a function.\n\n    :param name: A name to register the function under. By default the name of the function is used.\n    :param exist_ok:\n        If True, overwrites any existing function registered under the same ``name``. 
Else,\n        throws an error if a function is already registered under ``name``.\n    \"\"\"\n\n    def function_wrapper(func):\n        @RegistrableFunction.register(name or func.__name__, exist_ok=exist_ok)\n        class WrapperFunc(RegistrableFunction):\n            WRAPPED_FUNC = func\n\n        return WrapperFunc()\n\n    return function_wrapper\n\n\ndef _get_suggestion(name: str, available: List[str]) -> Optional[str]:\n    # Check for simple mistakes like using '-' instead of '_', or vice-versa.\n    for ch, repl_ch in ((\"_\", \"-\"), (\"-\", \"_\")):\n        suggestion = name.replace(ch, repl_ch)\n        if suggestion in available:\n            return suggestion\n    return None\n\n\ndef _fullname(c: type) -> str:\n    return f\"{c.__module__}.{c.__qualname__}\"\n\n\ndef _cls_is_step(c: type) -> bool:\n    # NOTE (epwalsh): importing the actual Step class here would result in a circular\n    # import, even though the import wouldn't be at the top of the module (believe me, I've tried).\n    # So instead we just check the fully qualified name of the class.\n    return _fullname(c) == \"tango.step.Step\"\n"
  },
  {
    "path": "tango/common/remote_utils.py",
    "content": "import logging\nfrom typing import Union\n\nfrom tango.step import Step\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\nclass RemoteConstants:\n    \"\"\"\n    Common constants to be used as prefixes and filenames in remote workspaces.\n    \"\"\"\n\n    RUN_ARTIFACT_PREFIX: str = \"tango-run-\"\n    RUN_DATA_FNAME: str = \"run.json\"\n    STEP_ARTIFACT_PREFIX: str = \"tango-step-\"\n    STEP_INFO_FNAME: str = \"step_info.json\"\n    STEP_RESULT_DIR: str = \"result\"\n    STEP_GRAPH_ARTIFACT_PREFIX: str = \"tango-step-graph-\"\n    STEP_EXPERIMENT_PREFIX: str = \"tango-step-\"\n    STEP_GRAPH_FILENAME: str = \"config.json\"\n    GITHUB_TOKEN_SECRET_NAME: str = \"TANGO_GITHUB_TOKEN\"\n    RESULTS_DIR: str = \"/tango/output\"\n    INPUT_DIR: str = \"/tango/input\"\n    LOCK_ARTIFACT_SUFFIX: str = \"-lock\"\n\n    @classmethod\n    def step_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str:\n        return f\"{cls.STEP_ARTIFACT_PREFIX}{step if isinstance(step, str) else step.unique_id}\"\n\n    @classmethod\n    def step_lock_artifact_name(cls, step: Union[str, StepInfo, Step]) -> str:\n        return f\"{cls.step_artifact_name(step)}{cls.LOCK_ARTIFACT_SUFFIX}\"\n\n    @classmethod\n    def run_artifact_name(cls, name: str) -> str:\n        return f\"{cls.RUN_ARTIFACT_PREFIX}{name}\"\n"
  },
  {
    "path": "tango/common/sequences.py",
    "content": "import bisect\nimport os\nimport random\nimport shutil\nfrom collections import abc\nfrom os import PathLike\nfrom typing import Any, Callable, Iterable, MutableSequence, Optional, Sequence, Union\n\n\nclass ShuffledSequence(abc.Sequence):\n    \"\"\"\n    Produces a shuffled view of a sequence, such as a list.\n\n    This assumes that the inner sequence never changes. If it does, the results\n    are undefined.\n\n    :param inner_sequence: the inner sequence that's being shuffled\n    :param indices: Optionally, you can specify a list of indices here. If you don't, we'll just shuffle the\n                    inner sequence randomly. If you do specify indices, element ``n`` of the output sequence\n                    will be ``inner_sequence[indices[n]]``. This gives you great flexibility. You can repeat\n                    elements, leave them out completely, or slice the list. A Python :class:`slice` object is\n                    an acceptable input for this parameter, and so are other sequences from this module.\n\n    Example:\n\n    .. testcode::\n        :hide:\n\n        import random\n        random.seed(42)\n\n    .. testcode::\n\n        from tango.common.sequences import ShuffledSequence\n        l = [1, 2, 3, 4, 5, 6, 7, 8, 9]\n        shuffled_l = ShuffledSequence(l)\n\n        print(shuffled_l[0])\n        print(shuffled_l[1])\n        print(shuffled_l[2])\n        assert len(shuffled_l) == len(l)\n\n    This will print something like the following:\n\n    .. 
testoutput::\n\n        4\n        7\n        8\n    \"\"\"\n\n    def __init__(self, inner_sequence: Sequence, indices: Optional[Sequence[int]] = None):\n        self.inner = inner_sequence\n        self.indices: Sequence[int]\n        if indices is None:\n            self.indices = list(range(len(inner_sequence)))\n            random.shuffle(self.indices)\n        else:\n            self.indices = indices\n\n    def __len__(self) -> int:\n        return len(self.indices)\n\n    def __getitem__(self, i: Union[int, slice]):\n        if isinstance(i, int):\n            return self.inner[self.indices[i]]\n        else:\n            return ShuffledSequence(self.inner, self.indices[i])\n\n    def __contains__(self, item) -> bool:\n        for i in self.indices:\n            if self.inner[i] == item:\n                return True\n        return False\n\n\nclass SlicedSequence(ShuffledSequence):\n    \"\"\"\n    Produces a sequence that's a slice into another sequence, without copying the elements.\n\n    This assumes that the inner sequence never changes. If it does, the results\n    are undefined.\n\n    :param inner_sequence: the inner sequence that's being sliced\n    :param s: the :class:`~slice` to slice the input with.\n\n    Example:\n\n    .. testcode::\n\n        from tango.common.sequences import SlicedSequence\n        l = [1, 2, 3, 4, 5, 6, 7, 8, 9]\n        sliced_l = SlicedSequence(l, slice(1, 4))\n\n        print(sliced_l[0])\n        print(sliced_l[1])\n        print(sliced_l[2])\n        assert len(sliced_l) == 3\n\n    This will print the following:\n\n    .. testoutput::\n\n        2\n        3\n        4\n\n    \"\"\"\n\n    def __init__(self, inner_sequence: Sequence, s: slice):\n        super().__init__(inner_sequence, range(*s.indices(len(inner_sequence))))\n\n\nclass ConcatenatedSequence(abc.Sequence):\n    \"\"\"\n    Produces a sequence that's the lazy concatenation of multiple other sequences. 
It does not copy\n    any of the elements of the original sequences.\n\n    This assumes that the inner sequences never change. If they do, the results are undefined.\n\n    :param sequences: the inner sequences to concatenate\n\n    Example:\n\n    .. testcode::\n\n        from tango.common.sequences import ConcatenatedSequence\n        l1 = [1, 2, 3]\n        l2 = [4, 5]\n        l3 = [6]\n        cat_l = ConcatenatedSequence(l1, l2, l3)\n\n        assert len(cat_l) == 6\n        for i in cat_l:\n            print(i)\n\n    This will print the following:\n\n    .. testoutput::\n\n        1\n        2\n        3\n        4\n        5\n        6\n    \"\"\"\n\n    def __init__(self, *sequences: Sequence):\n        self.sequences = sequences\n        self.cumulative_sequence_lengths = [0]\n        for sequence in sequences:\n            self.cumulative_sequence_lengths.append(\n                self.cumulative_sequence_lengths[-1] + len(sequence)\n            )\n\n    def __len__(self):\n        return self.cumulative_sequence_lengths[-1]\n\n    def __getitem__(self, i: Union[int, slice]):\n        if isinstance(i, int):\n            if i < 0:\n                i += len(self)\n            if i < 0 or i >= len(self):\n                raise IndexError(\"list index out of range\")\n            sequence_index = bisect.bisect_right(self.cumulative_sequence_lengths, i) - 1\n            i -= self.cumulative_sequence_lengths[sequence_index]\n            return self.sequences[sequence_index][i]\n        else:\n            return SlicedSequence(self, i)\n\n    def __contains__(self, item) -> bool:\n        return any(s.__contains__(item) for s in self.sequences)\n\n\nclass MappedSequence(abc.Sequence):\n    \"\"\"\n    Produces a sequence that applies a function to every element of another sequence.\n\n    This is similar to Python's :func:`map`, but it returns a sequence instead of a :class:`map` object.\n\n    :param fn: the function to apply to every element of the inner 
sequence. The function should take\n               one argument.\n    :param inner_sequence: the inner sequence to map over\n\n    Example:\n\n    .. testcode::\n\n        from tango.common.sequences import MappedSequence\n\n        def square(x):\n            return x * x\n\n        l = [1, 2, 3, 4]\n        map_l = MappedSequence(square, l)\n\n        assert len(map_l) == len(l)\n        for i in map_l:\n            print(i)\n\n    This will print the following:\n\n    .. testoutput::\n\n        1\n        4\n        9\n        16\n\n    \"\"\"\n\n    def __init__(self, fn: Callable, inner_sequence: Sequence):\n        self.inner = inner_sequence\n        self.fn = fn\n\n    def __getitem__(self, item):\n        if isinstance(item, slice):\n            new_inner = None\n            try:\n                # special case for a special library\n                from datasets import Dataset\n\n                if isinstance(self.inner, Dataset):\n                    new_inner = self.inner.select(range(*item.indices(len(self.inner))))\n            except ImportError:\n                pass\n            if new_inner is None:\n                new_inner = self.inner[item]\n            return MappedSequence(self.fn, new_inner)\n        else:\n            item = self.inner.__getitem__(item)\n            return self.fn(item)\n\n    def __len__(self):\n        return self.inner.__len__()\n\n    def __contains__(self, item):\n        return any(e == item for e in self)\n\n\nclass SqliteSparseSequence(MutableSequence[Any]):\n    \"\"\"\n    This is a sparse sequence that pickles elements to a Sqlite database.\n\n    When you read from the sequence, elements are retrieved and unpickled lazily. That means creating/opening\n    a sequence is very fast and does not depend on the length of the sequence.\n\n    This is a \"sparse sequence\" because you can set element ``n`` before you set element ``n-1``:\n\n    .. 
testcode::\n        :hide:\n\n        from tango.common.sequences import SqliteSparseSequence\n        import tempfile\n        dir = tempfile.TemporaryDirectory()\n        from pathlib import Path\n        filename = Path(dir.name) / \"test.sqlite\"\n\n    .. testcode::\n\n        s = SqliteSparseSequence(filename)\n        element = \"Big number, small database.\"\n        s[2**32] = element\n        assert len(s) == 2**32 + 1\n        assert s[2**32] == element\n        assert s[1000] is None\n        s.close()\n\n    .. testcode::\n        :hide:\n\n        dir.cleanup()\n\n    You can use a ``SqliteSparseSequence`` from multiple processes at the same time. This is useful, for example,\n    if you're filling out a sequence and you are partitioning ranges to processes.\n\n    :param filename: the filename at which to store the data\n    :param read_only: Set this to ``True`` if you only want to read.\n    \"\"\"\n\n    def __init__(self, filename: Union[str, PathLike], read_only: bool = False):\n        from sqlitedict import SqliteDict\n\n        self.table = SqliteDict(filename, \"sparse_sequence\", flag=\"r\" if read_only else \"c\")\n\n    def __del__(self):\n        if self.table is not None:\n            self.table.close(force=True)\n            self.table = None\n\n    def __getitem__(self, i: Union[int, slice]) -> Any:\n        if isinstance(i, int):\n            try:\n                return self.table[str(i)]\n            except KeyError:\n                current_length = len(self)\n                if i >= current_length or current_length <= 0:\n                    raise IndexError(\"list index out of range\")\n                elif i < 0 < current_length:\n                    return self.__getitem__(i % current_length)\n                else:\n                    return None\n        elif isinstance(i, slice):\n            return SlicedSequence(self, i)\n        else:\n            raise TypeError(f\"list indices must be integers or slices, not 
{i.__class__.__name__}\")\n\n    def __setitem__(self, i: Union[int, slice], value: Any):\n        if isinstance(i, int):\n            current_length = len(self)\n            if i < 0:\n                i %= current_length\n            self.table[str(i)] = value\n            self.table[\"_len\"] = max(i + 1, current_length)\n            self.table.commit()\n        else:\n            raise TypeError(f\"list indices must be integers, not {i.__class__.__name__}\")\n\n    def __delitem__(self, i: Union[int, slice]):\n        current_length = len(self)\n        if isinstance(i, int):\n            if i < 0:\n                i %= current_length\n            if i >= current_length:\n                raise IndexError(\"list assignment index out of range\")\n            for index in range(i + 1, current_length):\n                self.table[str(index - 1)] = self.table.get(str(index))\n            del self.table[str(current_length - 1)]\n            self.table[\"_len\"] = current_length - 1\n            self.table.commit()\n        elif isinstance(i, slice):\n            # This isn't very efficient for continuous slices.\n            for index in reversed(range(*i.indices(current_length))):\n                del self[index]\n        else:\n            raise TypeError(f\"list indices must be integers or slices, not {i.__class__.__name__}\")\n\n    def extend(self, values: Iterable[Any]) -> None:\n        current_length = len(self)\n        index = -1\n        for index, value in enumerate(values):\n            self.table[str(index + current_length)] = value\n        if index < 0:\n            return\n        self.table[\"_len\"] = current_length + index + 1\n        self.table.commit()\n\n    def insert(self, i: int, value: Any) -> None:\n        current_length = len(self)\n        for index in reversed(range(i, current_length)):\n            self.table[str(index + 1)] = self.table.get(str(index))\n        self.table[str(i)] = value\n        self.table[\"_len\"] = max(i + 1, 
current_length + 1)\n        self.table.commit()\n\n    def __len__(self) -> int:\n        try:\n            return self.table[\"_len\"]\n        except KeyError:\n            return 0\n\n    def clear(self) -> None:\n        \"\"\"\n        Clears the entire sequence\n        \"\"\"\n        self.table.clear()\n        self.table.commit()\n\n    def close(self) -> None:\n        \"\"\"\n        Closes the underlying Sqlite table. Do not use this sequence afterwards!\n        \"\"\"\n        if self.table is not None:\n            self.table.close()\n            self.table = None\n\n    def copy_to(self, target: Union[str, PathLike]):\n        \"\"\"\n        Make a copy of this sequence at a new location.\n\n        :param target: the location of the copy\n\n        This will attempt to make a hardlink, which is very fast, but only works on Linux and if ``target`` is\n        on the same drive. If making a hardlink fails, it falls back to making a regular copy. As a result,\n        there is no guarantee whether you will get a hardlink or a copy. If you get a hardlink, future edits\n        in the source sequence will also appear in the target sequence. This is why we recommend to not use\n        :meth:`copy_to()` until you are done with the sequence. This is not ideal, but it is a compromise we make\n        for performance.\n        \"\"\"\n        try:\n            os.link(self.table.filename, target)\n        except OSError as e:\n            if e.errno == 18:  # Cross-device link\n                shutil.copy(self.table.filename, target)\n            else:\n                raise\n"
  },
  {
    "path": "tango/common/testing/__init__.py",
    "content": "import logging\nimport os\nimport shutil\nimport tempfile\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Union, cast\n\nfrom tango.common.aliases import EnvVarNames, PathOrStr\nfrom tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.params import Params\nfrom tango.settings import TangoGlobalSettings\n\n\nclass TangoTestCase:\n    \"\"\"\n    A custom testing class that\n\n    * disables some of the more verbose logging,\n    * creates and destroys a temp directory as a test fixture, and\n    * restores the internal state of the `Registrable` class at the end of each test method.\n\n    \"\"\"\n\n    PROJECT_ROOT = (Path(__file__).parent / \"..\" / \"..\" / \"..\").resolve()\n    \"\"\"\n    Root of the git repository.\n    \"\"\"\n\n    # to run test suite with finished package, which does not contain\n    # tests & fixtures, we must be able to look them up somewhere else\n    PROJECT_ROOT_FALLBACK = (\n        # users wanting to run test suite for installed package\n        Path(os.environ[\"TANGO_SRC_DIR\"])\n        if \"TANGO_SRC_DIR\" in os.environ\n        else (\n            # fallback for conda packaging\n            Path(os.environ[\"SRC_DIR\"])\n            if \"CONDA_BUILD\" in os.environ\n            # stay in-tree\n            else PROJECT_ROOT\n        )\n    )\n\n    MODULE_ROOT = PROJECT_ROOT_FALLBACK / \"tango\"\n    \"\"\"\n    Root of the tango module.\n    \"\"\"\n\n    TESTS_ROOT = PROJECT_ROOT_FALLBACK / \"tests\"\n    \"\"\"\n    Root of the tests directory.\n    \"\"\"\n\n    FIXTURES_ROOT = PROJECT_ROOT_FALLBACK / \"test_fixtures\"\n    \"\"\"\n    Root of the test fixtures directory.\n    \"\"\"\n\n    def setup_method(self):\n        logging.basicConfig(\n            format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\", level=logging.DEBUG\n        )\n\n        # Disabling some of the more verbose logging 
statements that typically aren't very helpful\n        # in tests.\n        logging.getLogger(\"urllib3.connectionpool\").disabled = True\n\n        # Create a temporary scratch directory.\n        self.TEST_DIR = Path(tempfile.mkdtemp(prefix=\"tango_tests\"))\n        os.makedirs(self.TEST_DIR, exist_ok=True)\n\n        # Set an artificial console width so logs are not mangled.\n        os.environ[EnvVarNames.CONSOLE_WIDTH.value] = str(300)\n\n    def teardown_method(self):\n        shutil.rmtree(self.TEST_DIR)\n        if EnvVarNames.CONSOLE_WIDTH.value in os.environ:\n            del os.environ[EnvVarNames.CONSOLE_WIDTH.value]\n\n    def run(\n        self,\n        config: Union[PathOrStr, Dict[str, Any], Params],\n        overrides: Optional[Union[Dict[str, Any], str]] = None,\n        include_package: Optional[List[str]] = None,\n        workspace_url: Optional[str] = None,\n        step_name: Optional[str] = None,\n        parallelism: Optional[int] = 1,\n        multicore: Optional[bool] = False,\n        name: Optional[str] = None,\n        settings: Optional[TangoGlobalSettings] = None,\n    ) -> Path:\n        from tango.__main__ import _run\n\n        if isinstance(config, (dict, Params)):\n            params = config if isinstance(config, Params) else Params(config)\n            config = self.TEST_DIR / \"config.json\"\n            params.to_file(cast(Path, config))\n\n        if isinstance(overrides, dict):\n            import json\n\n            overrides = json.dumps(overrides)\n\n        run_name = _run(\n            settings or TangoGlobalSettings(),\n            str(config),\n            workspace_url=workspace_url or \"local://\" + str(self.TEST_DIR / \"workspace\"),\n            overrides=overrides,\n            include_package=include_package,\n            step_names=None if not step_name else [step_name],\n            parallelism=parallelism,\n            multicore=multicore,\n            name=name,\n        )\n\n        return self.TEST_DIR 
/ \"workspace\" / \"runs\" / run_name\n\n\n@contextmanager\ndef run_experiment(\n    config: Union[PathOrStr, Dict[str, Any], Params],\n    overrides: Optional[Union[Dict[str, Any], str]] = None,\n    file_friendly_logging: bool = True,\n    include_package: Optional[List[str]] = None,\n    workspace_url: Optional[str] = None,\n    parallelism: Optional[int] = 1,\n    multicore: Optional[bool] = False,\n    name: Optional[str] = None,\n    settings: Optional[TangoGlobalSettings] = None,\n):\n    \"\"\"\n    A context manager to make testing experiments easier. On ``__enter__`` it runs\n    the experiment and returns the path to the run directory, a temporary directory that will be\n    cleaned up on ``__exit__``.\n    \"\"\"\n    initialize_logging(enable_cli_logs=True, file_friendly_logging=file_friendly_logging)\n    test_case = TangoTestCase()\n    try:\n        test_case.setup_method()\n        yield test_case.run(\n            config,\n            overrides=overrides,\n            include_package=include_package,\n            workspace_url=workspace_url,\n            parallelism=parallelism,\n            multicore=multicore,\n            name=name,\n            settings=settings,\n        )\n    finally:\n        test_case.teardown_method()\n        teardown_logging()\n\n\ndef requires_gpus(test_method):\n    \"\"\"\n    Decorator to indicate that a test requires multiple GPU devices.\n    \"\"\"\n    import pytest\n    import torch\n\n    return pytest.mark.gpu(\n        pytest.mark.skipif(torch.cuda.device_count() < 2, reason=\"2 or more GPUs required.\")(\n            test_method\n        )\n    )\n"
  },
  {
    "path": "tango/common/testing/steps.py",
    "content": "import logging\nimport multiprocessing as mp\nimport random\nimport time\nfrom string import ascii_letters\nfrom typing import List\n\nimport tango.common.logging as common_logging\nfrom tango import Step\nfrom tango.common import Tqdm\n\n\n@Step.register(\"float\")\nclass FloatStep(Step):\n    CACHEABLE = True\n    DETERMINISTIC = True\n\n    def run(self, result: float) -> float:  # type: ignore\n        return result\n\n\n@Step.register(\"string\")\nclass StringStep(Step):\n    CACHEABLE = True\n    DETERMINISTIC = True\n\n    def run(self, result: str) -> str:  # type: ignore\n        return result\n\n\n@Step.register(\"concat_strings\")\nclass ConcatStringsStep(Step):\n    CACHEABLE = True\n    DETERMINISTIC = True\n\n    def run(self, string1: str, string2: str, join_with: str = \" \") -> str:  # type: ignore\n        return join_with.join([string1, string2])\n\n\n@Step.register(\"noisy_step\")\nclass NoisyStep(Step):\n    CACHEABLE = True\n    DETERMINISTIC = True\n\n    def run(self, raise_error: bool = False) -> None:  # type: ignore\n        self.logger.debug(\"debug message\")\n        common_logging.cli_logger.debug(\"debug message from cli_logger\")\n\n        self.logger.info(\"info message\")\n        common_logging.cli_logger.info(\"info message from cli_logger\")\n\n        self.logger.warning(\"warning message\")\n        common_logging.cli_logger.warning(\"warning message from cli_logger\")\n\n        self.logger.error(\"error message\")\n        common_logging.cli_logger.error(\"error message from cli_logger\")\n\n        if raise_error:\n            raise ValueError(\"Oh no!\")\n\n\n@Step.register(\"random_string\")\nclass RandomStringStep(Step):\n    def run(self, length: int = 10) -> str:  # type: ignore\n        return \"\".join([random.choice(ascii_letters) for _ in range(length)])\n\n\n@Step.register(\"add_numbers\")\nclass AddNumbersStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, a_number: 
int, b_number: int) -> int:  # type: ignore\n        return a_number + b_number\n\n\n@Step.register(\"sleep-print-maybe-fail\")\nclass SleepPrintMaybeFail(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, string: str, seconds: int = 5, fail: bool = False) -> str:  # type: ignore\n        time.sleep(seconds)\n        self.logger.info(f\"Step {self.name} is awake.\")\n        print(string)\n        if fail:\n            raise RuntimeError(\"Step had to fail!\")\n        return string\n\n\n@Step.register(\"logging-step\")\nclass LoggingStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, string: str, num_log_lines: int = 50) -> str:  # type: ignore\n        for i in Tqdm.tqdm(list(range(num_log_lines)), desc=\"log progress\"):\n            time.sleep(0.1)\n            self.logger.info(f\"{i} - {string}\")\n        return string\n\n\n@Step.register(\"make_number\")\nclass MakeNumber(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, what_number: int) -> int:  # type: ignore\n        return what_number\n\n\n@Step.register(\"store_number_in_file\")\nclass StoreNumberInFile(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self, number: int, file_name: str) -> None:  # type: ignore\n        # Note: this is only for testing if the uncacheable step\n        # ran in the multicore setting. 
Normally, a step like this\n        # would be marked as CACHEABLE.\n        with open(file_name, \"w\") as file_ref:\n            file_ref.write(str(number))\n\n\n@Step.register(\"multiprocessing_step\")\nclass MultiprocessingStep(Step):\n    \"\"\"\n    Mainly used to test that logging works properly in child processes.\n    \"\"\"\n\n    def run(self, num_proc: int = 2) -> bool:  # type: ignore\n        for _ in Tqdm.tqdm(list(range(10)), desc=\"progress from main process\"):\n            time.sleep(0.1)\n\n        workers = []\n        for i in range(num_proc):\n            worker = mp.Process(target=_worker_function, args=(i,))\n            workers.append(worker)\n            worker.start()\n\n        for worker in workers:\n            worker.join()\n\n        return True\n\n\n@Step.register(\"range_step\")\nclass RangeOutput(Step):\n    def run(self, start: int, end: int) -> List[int]:  # type: ignore\n        return list(range(start, end))\n\n\ndef _worker_function(worker_id: int):\n    common_logging.initialize_worker_logging(worker_id)\n    logger = logging.getLogger(MultiprocessingStep.__name__)\n    logger.info(\"Hello from worker %d!\", worker_id)\n    common_logging.cli_logger.info(\"Hello from the cli logger in worker %d!\", worker_id)\n    for _ in Tqdm.tqdm(list(range(10)), desc=\"progress from worker\", disable=worker_id > 0):\n        time.sleep(0.1)\n"
  },
  {
    "path": "tango/common/tqdm.py",
    "content": "\"\"\"\nCopied over from ``allennlp.common.tqdm.Tqdm``.\n\nWraps tqdm so we can add configurable global defaults for certain tqdm parameters.\n\"\"\"\n\nimport logging\nimport sys\nfrom contextlib import contextmanager\nfrom time import time\nfrom typing import Optional\n\ntry:\n    SHELL = str(type(get_ipython()))  # type:ignore # noqa: F821\nexcept:  # noqa: E722\n    SHELL = \"\"\n\nif \"zmqshell.ZMQInteractiveShell\" in SHELL:\n    from tqdm import tqdm_notebook as _tqdm\nelse:\n    from tqdm import tqdm as _tqdm\n\nfrom tango.common import logging as common_logging\n\n# This is necessary to stop tqdm from hanging\n# when exceptions are raised inside iterators.\n# It should have been fixed in 4.2.1, but it still\n# occurs.\n# TODO(Mark): Remove this once tqdm cleans up after itself properly.\n# https://github.com/tqdm/tqdm/issues/469\n_tqdm.monitor_interval = 0\n\n\nlogger = logging.getLogger(\"tqdm\")\nlogger.propagate = False\n\n\ndef replace_cr_with_newline(message: str) -> str:\n    \"\"\"\n    TQDM and requests use carriage returns to get the training line to update for each batch\n    without adding more lines to the terminal output. 
Displaying those in a file won't work\n    correctly, so we'll just make sure that each batch shows up on its own line.\n    \"\"\"\n    # In addition to carriage returns, nested progress-bars will contain extra new-line\n    # characters and this special control sequence which tells the terminal to move the\n    # cursor one line up.\n    message = message.replace(\"\\r\", \"\").replace(\"\\n\", \"\").replace(\"\u001b[A\", \"\")\n    if message and message[-1] != \"\\n\":\n        message += \"\\n\"\n    return message\n\n\nclass TqdmToLogsWriter:\n    def __init__(self):\n        self.last_message_written_time = 0.0\n\n    def write(self, message):\n        file_friendly_message: Optional[str] = None\n        if common_logging.FILE_FRIENDLY_LOGGING:\n            file_friendly_message = replace_cr_with_newline(message)\n            if file_friendly_message.strip():\n                sys.stderr.write(file_friendly_message)\n        else:\n            sys.stderr.write(message)\n\n        # Every 10 seconds we also log the message.\n        now = time()\n        if now - self.last_message_written_time >= 10 or \"100%\" in message:\n            if file_friendly_message is None:\n                file_friendly_message = replace_cr_with_newline(message)\n            for message in file_friendly_message.split(\"\\n\"):\n                message = message.strip()\n                if len(message) > 0:\n                    logger.info(message)\n                    self.last_message_written_time = now\n\n    def flush(self):\n        sys.stderr.flush()\n\n\nclass Tqdm:\n    \"\"\"\n    A `tqdm <https://tqdm.github.io/>`_ wrapper that respects\n    :data:`~tango.common.logging.FILE_FRIENDLY_LOGGING` and other Tango logging configurations.\n    \"\"\"\n\n    @staticmethod\n    def tqdm(*args, **kwargs):\n        new_kwargs = Tqdm.get_updated_kwargs(**kwargs)\n        return _tqdm(*args, **new_kwargs)\n\n    @staticmethod\n    @contextmanager\n    def wrapattr(*args, **kwargs):\n  
      new_kwargs = Tqdm.get_updated_kwargs(**kwargs)\n        with _tqdm.wrapattr(*args, **new_kwargs) as t:\n            yield t\n\n    @staticmethod\n    def get_updated_kwargs(**kwargs):\n        # Use a slower interval when FILE_FRIENDLY_LOGGING is set.\n        default_mininterval = 2.0 if common_logging.FILE_FRIENDLY_LOGGING else 0.1\n        return {\n            \"file\": TqdmToLogsWriter(),\n            \"mininterval\": default_mininterval,\n            **kwargs,\n        }\n\n    @staticmethod\n    def set_lock(lock):\n        _tqdm.set_lock(lock)\n\n    @staticmethod\n    def get_lock():\n        return _tqdm.get_lock()\n"
  },
  {
    "path": "tango/common/util.py",
    "content": "import importlib\nimport pkgutil\nimport signal\nimport string\nimport sys\nimport traceback\nfrom collections import OrderedDict\nfrom dataclasses import asdict, is_dataclass\nfrom datetime import datetime, tzinfo\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Any, Iterable, Optional, Set, Tuple, Union\n\nimport pytz\n\nfrom .exceptions import SigTermReceived\n\n\ndef tango_cache_dir() -> Path:\n    \"\"\"\n    Returns a directory suitable for caching things from Tango, defaulting\n    to ``$HOME/.cache/tango``.\n    \"\"\"\n    cache_dir = Path.home() / \".cache\" / \"tango\"\n    cache_dir.mkdir(parents=True, exist_ok=True)\n    return cache_dir\n\n\ndef _handle_sigterm(sig, frame):\n    raise SigTermReceived\n\n\ndef install_sigterm_handler():\n    signal.signal(signal.SIGTERM, _handle_sigterm)\n\n\n_extra_imported_modules: Set[str] = set()\n\n\ndef get_extra_imported_modules() -> Set[str]:\n    return _extra_imported_modules\n\n\ndef import_extra_module(package_name: str) -> None:\n    global _extra_imported_modules\n    import_module_and_submodules(package_name)\n    _extra_imported_modules.add(package_name)\n\n\ndef resolve_module_name(package_name: str) -> Tuple[str, Path]:\n    base_path = Path(\".\")\n    package_path = Path(package_name)\n    if not package_path.exists():\n        raise ValueError(f\"'{package_path}' looks like a path, but the path does not exist\")\n\n    parent = package_path.parent\n    while parent != parent.parent:\n        if (parent / \"__init__.py\").is_file():\n            parent = parent.parent\n        else:\n            base_path = parent\n            break\n\n    package_name = str(package_path.relative_to(base_path)).replace(\"/\", \".\")\n\n    if package_path.is_file():\n        if package_path.name == \"__init__.py\":\n            # If `__init__.py` file, resolve to the parent module.\n            package_name = package_name[: -len(\".__init__.py\")]\n        elif 
package_name.endswith(\".py\"):\n            package_name = package_name[:-3]\n\n        if not package_name:\n            raise ValueError(f\"invalid package path '{package_path}'\")\n\n    return package_name, base_path\n\n\ndef import_module_and_submodules(\n    package_name: str, exclude: Optional[Set[str]] = None, recursive: bool = True\n) -> None:\n    \"\"\"\n    Import all submodules under the given package.\n\n    Primarily useful so that people using tango can specify their own custom packages\n    and have their custom classes get loaded and registered.\n    \"\"\"\n    # If `package_name` is in the form of a path, convert to the module format.\n    if \"/\" in package_name or package_name.endswith(\".py\"):\n        package_name, base_path = resolve_module_name(package_name)\n    else:\n        base_path = Path(\".\")\n    base_path = base_path.resolve()\n\n    if exclude and package_name in exclude:\n        return\n\n    importlib.invalidate_caches()\n\n    # Ensure `base_path` is first in `sys.path`.\n    if str(base_path) not in sys.path:\n        sys.path.insert(0, str(base_path))\n    else:\n        sys.path.insert(0, sys.path.pop(sys.path.index(str(base_path))))\n\n    # Certain packages might mess with sys.excepthook which we don't like since\n    # we mess with sys.excepthook ourselves. 
If it looks like the package is overriding\n    # the hook for a reason, we'll leave it be but also make sure our hook is still called.\n    excepthook = sys.excepthook\n\n    # Import at top level\n    try:\n        module = importlib.import_module(package_name)\n    finally:\n        if sys.excepthook != excepthook:\n            if sys.excepthook.__module__.startswith(\"rich\"):\n                # We definitely don't want rich's traceback hook because that will mess\n                # with our exception handling.\n                sys.excepthook = excepthook\n            else:\n                new_hook = sys.excepthook\n\n                def excepthook_wrapper(exctype, value, traceback):\n                    excepthook(exctype, value, traceback)\n                    new_hook(exctype, value, traceback)\n\n                sys.excepthook = excepthook_wrapper\n\n    path = getattr(module, \"__path__\", [])\n    path_string = \"\" if not path else path[0]\n\n    if recursive:\n        # walk_packages only finds immediate children, so need to recurse.\n        for module_finder, name, _ in pkgutil.walk_packages(path):\n            # Sometimes when you import third-party libraries that are on your path,\n            # `pkgutil.walk_packages` returns those too, so we need to skip them.\n            if path_string and module_finder.path != path_string:  # type: ignore[union-attr]\n                continue\n            subpackage = f\"{package_name}.{name}\"\n            import_module_and_submodules(subpackage, exclude=exclude)\n\n\ndef _parse_bool(value: Union[bool, str]) -> bool:\n    if isinstance(value, bool):\n        return value\n    if value in {\"1\", \"true\", \"True\", \"TRUE\"}:\n        return True\n    return False\n\n\ndef _parse_optional_int(value: Optional[str]) -> Optional[int]:\n    if value is not None:\n        return int(value)\n    return None\n\n\ndef find_submodules(\n    module: Optional[str] = None,\n    match: Optional[Set[str]] = None,\n    
exclude: Optional[Set[str]] = None,\n    recursive: bool = True,\n) -> Iterable[str]:\n    \"\"\"\n    Find tango submodules.\n    \"\"\"\n    from fnmatch import fnmatch\n\n    root = Path(__file__).parent / \"..\"\n    if module:\n        if module.startswith(\"tango.\"):\n            module = module.replace(\"tango.\", \"\", 1)\n        for m in module.split(\".\"):\n            root = root / m\n        module = f\"tango.{module}\"\n    else:\n        module = \"tango\"\n    for path in root.iterdir():\n        if path.name[0] in {\"_\", \".\"}:\n            continue\n        submodule: str\n        if path.is_dir():\n            submodule = path.name\n        elif path.suffix == \".py\":\n            submodule = path.name[:-3]\n        else:\n            continue\n        submodule = f\"{module}.{submodule}\"\n        if exclude and any((fnmatch(submodule, pat) for pat in exclude)):\n            continue\n        if match and not any((fnmatch(submodule, pat) for pat in match)):\n            continue\n        yield submodule\n        if recursive and path.is_dir():\n            yield from find_submodules(submodule, match=match, exclude=exclude)\n\n\ndef find_integrations() -> Iterable[str]:\n    \"\"\"\n    Find all tango integration modules.\n    \"\"\"\n    yield from find_submodules(\"tango.integrations\", recursive=False)\n\n\nSAFE_FILENAME_CHARS = frozenset(\"-_.%s%s\" % (string.ascii_letters, string.digits))\n\n\ndef filename_is_safe(filename: str) -> bool:\n    return all(c in SAFE_FILENAME_CHARS for c in filename)\n\n\ndef make_safe_filename(name: str) -> str:\n    if filename_is_safe(name):\n        return name\n    else:\n        from tango.common.det_hash import det_hash\n\n        name_hash = det_hash(name)\n        name = name.replace(\" \", \"-\").replace(\"/\", \"--\")\n        return \"\".join(c for c in name if c in SAFE_FILENAME_CHARS) + f\"-{name_hash[:7]}\"\n\n\ndef could_be_class_name(name: str) -> bool:\n    if \".\" in name and not 
name.endswith(\".\"):\n        return all([_is_valid_python_name(part) for part in name.split(\".\")])\n    else:\n        return False\n\n\ndef _is_valid_python_name(name: str) -> bool:\n    return bool(name and name[0].isalpha() and name.replace(\"_\", \"\").isalnum())\n\n\ndef threaded_generator(g, queue_size: int = 16):\n    \"\"\"\n    Puts the generating side of a generator into its own thread\n\n    Let's say you have a generator that reads records from disk, and something that consumes the\n    generator that spends most of its time in PyTorch. Wouldn't it be great if you could read more\n    records while the PyTorch code runs? If you wrap your record-reading generator with\n    ``threaded_generator(inner)``, that's exactly what happens. The reading code will run in a new thread,\n    while the consuming code runs in the main thread as normal. ``threaded_generator()`` uses a queue\n    to hand off items.\n\n    :param queue_size: the maximum queue size for hand-offs between the main thread and the generator thread\n    \"\"\"\n    from queue import Queue\n    from threading import Thread\n\n    q: Queue = Queue(maxsize=queue_size)\n\n    sentinel = object()\n\n    def fill_queue():\n        try:\n            for value in g:\n                q.put(value)\n        finally:\n            q.put(sentinel)\n\n    thread = Thread(name=repr(g), target=fill_queue, daemon=True)\n    thread.start()\n\n    yield from iter(q.get, sentinel)\n\n    thread.join()\n\n\ndef exception_to_string(e: BaseException) -> str:\n    \"\"\"\n    Generates a string that contains an exception plus stack frames based on an exception.\n\n    This became trivial in Python 3.10, but we need to run on Python 3.8 as well.\n    \"\"\"\n    if sys.version_info >= (3, 10):\n        formatted = traceback.format_exception(e)\n    else:\n        formatted = traceback.format_exception(etype=type(e), value=e, tb=e.__traceback__)\n    return \"\".join(formatted)\n\n\ndef utc_now_datetime() -> 
datetime:\n    return datetime.utcnow().replace(tzinfo=pytz.utc)\n\n\ndef local_timezone() -> Optional[tzinfo]:\n    return datetime.now().astimezone().tzinfo\n\n\ndef replace_steps_with_unique_id(o: Any):\n    from tango.step import Step, StepIndexer\n\n    if isinstance(o, Step):\n        return {\"type\": \"ref\", \"ref\": o.unique_id}\n    elif isinstance(o, StepIndexer):\n        return {\"type\": \"ref\", \"ref\": o.step.unique_id, \"key\": o.key}\n    elif isinstance(o, (list, tuple, set)):\n        return o.__class__(replace_steps_with_unique_id(i) for i in o)\n    elif isinstance(o, dict):\n        return {key: replace_steps_with_unique_id(value) for key, value in o.items()}\n    else:\n        return o\n\n\ndef jsonify(o: Any) -> Any:\n    \"\"\"\n    Transform an object into a JSON-serializable equivalent (if there is one)\n    in a deterministic way. For example, tuples and sets are turned into lists,\n    dictionaries are turned into ordered dictionaries where the order depends on the sorting\n    of the keys, and datetimes are turned into formatted strings.\n    \"\"\"\n    if isinstance(o, (tuple, set)):\n        return [jsonify(x) for x in o]\n    elif isinstance(o, dict):\n        return OrderedDict((k, jsonify(v)) for k, v in sorted(o.items(), key=lambda x: x[0]))\n    elif isinstance(o, datetime):\n        return o.strftime(\"%Y-%m-%dT%H:%M:%S\")\n    elif is_dataclass(o):\n        return jsonify(asdict(o))\n    elif isinstance(o, Path):\n        return str(o)\n    else:\n        return o\n\n\nclass StrEnum(str, Enum):\n    def __str__(self) -> str:\n        return self.value\n"
  },
  {
    "path": "tango/executor.py",
    "content": "import logging\nimport warnings\nfrom dataclasses import dataclass, field\nfrom typing import TYPE_CHECKING, Dict, List, Optional, Sequence, TypeVar\n\nfrom rich import get_console\nfrom rich.table import Table\n\nfrom .common.logging import cli_logger, log_exception\nfrom .common.registrable import Registrable\nfrom .common.util import import_extra_module\nfrom .step_graph import StepGraph\nfrom .workspace import Workspace\n\nif TYPE_CHECKING:\n    from .step import Step\n\nlogger = logging.getLogger(__name__)\n\n\nT = TypeVar(\"T\")\n\n\n@dataclass\nclass ExecutionMetadata:\n    logs_location: Optional[str] = None\n    \"\"\"\n    Path or URL to the logs for the step's execution.\n    \"\"\"\n\n    result_location: Optional[str] = None\n    \"\"\"\n    Path or URL to the result of the step's execution.\n    \"\"\"\n\n\n@dataclass\nclass ExecutorOutput:\n    \"\"\"\n    Describes the outcome of the execution.\n    \"\"\"\n\n    successful: Dict[str, ExecutionMetadata] = field(default_factory=dict)\n    \"\"\"Steps which ran successfully or were found in the cache.\"\"\"\n\n    failed: Dict[str, ExecutionMetadata] = field(default_factory=dict)\n    \"\"\"Steps that failed.\"\"\"\n\n    not_run: Dict[str, ExecutionMetadata] = field(default_factory=dict)\n    \"\"\"Steps that were ignored (usually because of failed dependencies).\"\"\"\n\n    def display(self) -> None:\n        table = Table(caption_style=\"\")\n        table.add_column(\"Step Name\", justify=\"left\", style=\"cyan\")\n        table.add_column(\"Status\", justify=\"left\")\n        table.add_column(\"Results\", justify=\"left\")\n        all_steps = dict(self.successful)\n        all_steps.update(self.failed)\n        all_steps.update(self.not_run)\n        for step_name in sorted(all_steps):\n            status_str: str\n            result_str: str = \"[grey62]N/A[/]\"\n            if step_name in self.failed:\n                status_str = \"[red]\\N{ballot x} failed[/]\"\n           
     execution_metadata = self.failed[step_name]\n                if execution_metadata.logs_location is not None:\n                    result_str = f\"[cyan]{execution_metadata.logs_location}[/]\"\n            elif step_name in self.not_run:\n                status_str = \"[yellow]- not run[/]\"\n            elif step_name in self.successful:\n                status_str = \"[green]\\N{check mark} succeeded[/]\"\n                execution_metadata = self.successful[step_name]\n                if execution_metadata.result_location is not None:\n                    result_str = f\"[cyan]{execution_metadata.result_location}[/]\"\n                elif execution_metadata.logs_location is not None:\n                    result_str = f\"[cyan]{execution_metadata.logs_location}[/]\"\n            else:\n                continue\n\n            table.add_row(step_name, status_str, result_str)\n\n        caption_parts: List[str] = []\n        if self.failed:\n            caption_parts.append(f\"[red]\\N{ballot x}[/] [italic]{len(self.failed)} failed[/]\")\n        if self.successful:\n            caption_parts.append(\n                f\"[green]\\N{check mark}[/] [italic]{len(self.successful)} succeeded[/]\"\n            )\n        if self.not_run:\n            caption_parts.append(f\"[italic]{len(self.not_run)} not run[/]\")\n        table.caption = \", \".join(caption_parts)\n\n        if logger.isEnabledFor(logging.INFO):\n            logger.info(table)\n        elif cli_logger.isEnabledFor(logging.INFO):\n            cli_logger.info(table)\n        else:\n            get_console().print(table)\n\n\nclass Executor(Registrable):\n    \"\"\"\n    An ``Executor`` is a class that is responsible for running steps and caching their results.\n\n    This is the base class and default implementation, registered as \"default\".\n\n    .. 
note::\n        The ``parallelism`` parameter has no effect with this default :class:`Executor`,\n        but is part of the API because most subclass implementations allow configuring\n        parallelism.\n    \"\"\"\n\n    default_implementation = \"default\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        include_package: Optional[Sequence[str]] = None,\n        parallelism: Optional[int] = None,\n    ) -> None:\n        self.workspace = workspace\n        self.include_package = include_package\n        self.parallelism = parallelism\n\n    def execute_step(self, step: \"Step\") -> None:\n        # Import included packages to find registered components.\n        if self.include_package is not None:\n            for package_name in self.include_package:\n                import_extra_module(package_name)\n\n        if step.cache_results:\n            step.ensure_result(self.workspace)\n        else:\n            step.result(self.workspace)\n\n    def execute_step_graph(\n        self, step_graph: StepGraph, run_name: Optional[str] = None\n    ) -> ExecutorOutput:\n        \"\"\"\n        Execute a :class:`~tango.step_graph.StepGraph`. This attempts to execute\n        every step in order. If a step fails, its dependent steps are not run,\n        but unrelated steps are still executed. Step failures will be logged, but\n        no exceptions will be raised.\n        \"\"\"\n        if self.parallelism is not None:\n            warnings.warn(\n                \"The 'parallelism' parameter has no effect with the default Executor. 
\"\n                \"If you want to run steps in parallel, consider using the MulticoreExecutor.\",\n                UserWarning,\n            )\n\n        successful: Dict[str, ExecutionMetadata] = {}\n        failed: Dict[str, ExecutionMetadata] = {}\n        not_run: Dict[str, ExecutionMetadata] = {}\n        uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()\n\n        for step in step_graph.values():\n            if not step.cache_results and step not in uncacheable_leaf_steps:\n                # If a step is uncacheable and required for another step, it will be\n                # executed as part of the downstream step's execution.\n                continue\n            if any(dep.name in failed for dep in step.recursive_dependencies):\n                not_run[step.name] = ExecutionMetadata()\n            else:\n                try:\n                    self.execute_step(step)\n                    successful[step.name] = ExecutionMetadata(\n                        result_location=self.workspace.step_info(step).result_location\n                    )\n                except Exception as exc:\n                    failed[step.name] = ExecutionMetadata()\n                    log_exception(exc, logger)\n\n        return ExecutorOutput(successful=successful, failed=failed, not_run=not_run)\n\n    # NOTE: The reason for having this method instead of just using `execute_step()` to run\n    # a single step is that certain executors, such as the BeakerExecutor, need to\n    # serialize steps somehow, and the easiest way to serialize a step is by serializing the\n    # whole step config (which can be accessed via the step graph).\n\n    def execute_sub_graph_for_steps(\n        self, step_graph: StepGraph, *step_names: str, run_name: Optional[str] = None\n    ) -> ExecutorOutput:\n        \"\"\"\n        Execute the sub-graph associated with a particular step in a\n        :class:`~tango.step_graph.StepGraph`.\n        \"\"\"\n        sub_graph = 
step_graph.sub_graph(*step_names)\n        return self.execute_step_graph(sub_graph, run_name=run_name)\n\n\nExecutor.register(\"default\")(Executor)\n"
  },
  {
    "path": "tango/executors/__init__.py",
    "content": "\"\"\"\nBuilt-in :class:`~tango.executor.Executor` implementations.\n\"\"\"\nfrom .multicore_executor import MulticoreExecutor\n"
  },
  {
    "path": "tango/executors/multicore_executor.py",
    "content": "import logging\nimport os\nimport subprocess\nimport time\nfrom tempfile import NamedTemporaryFile\nfrom typing import Dict, List, Optional, OrderedDict, Sequence, Set, TypeVar\n\nfrom tango.executor import ExecutionMetadata, Executor, ExecutorOutput\nfrom tango.step import Step\nfrom tango.step_graph import StepGraph\nfrom tango.step_info import StepState\nfrom tango.workspace import Workspace\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\")\n\n\n@Executor.register(\"multicore\")\nclass MulticoreExecutor(Executor):\n    \"\"\"\n    A ``MulticoreExecutor`` runs the steps in parallel and caches their results.\n    \"\"\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        include_package: Optional[Sequence[str]] = None,\n        parallelism: Optional[int] = 1,\n        num_tries_to_sync_states: int = 3,\n        wait_seconds_to_sync_states: int = 3,\n    ) -> None:\n        super().__init__(workspace, include_package=include_package, parallelism=parallelism or 1)\n        assert self.parallelism is not None\n        if self.parallelism < 0:\n            self.parallelism = min(32, os.cpu_count() or 1)\n\n        # Perhaps there's a better way to do this without these being passed as args.\n        self._num_tries_to_sync_states = num_tries_to_sync_states\n        self._wait_seconds_to_sync_states = wait_seconds_to_sync_states\n\n    def execute_step_graph(\n        self, step_graph: StepGraph, run_name: Optional[str] = None\n    ) -> ExecutorOutput:\n        \"\"\"\n        Execute a :class:`tango.step_graph.StepGraph`. 
This attempts to execute steps in parallel.\n        If a step fails, its dependent steps are not run, but unrelated steps are still executed.\n        Step failures will be logged, but no exceptions will be raised.\n        \"\"\"\n\n        _running: OrderedDict[str, subprocess.Popen] = OrderedDict({})\n        _successful: Dict[str, ExecutionMetadata] = {}\n        _failed: Dict[str, ExecutionMetadata] = {}\n        _queued_steps: List[str] = []\n\n        uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()\n\n        def _sync_step_states() -> Dict[str, StepState]:\n            \"\"\"\n            Update the StepState info.\n            This is not really elegant. The issue is as follows: the main multicore executor process\n            queues a step and starts step execution in a different process. If we try to read the StepState\n            before that process has had time to update the StepState, the Workspace will throw the out-of-sync\n            error (IOError: process should be running but it's considered incomplete...).\n\n            Hence, we try to read a few times, so that the child process has time to update the step's state.\n            \"\"\"\n\n            attempts = 0\n            while attempts < self._num_tries_to_sync_states:\n                attempts += 1\n                try:\n                    step_states = {step.name: self._get_state(step) for step in step_graph.values()}\n                    break\n                except IOError:\n                    if attempts == self._num_tries_to_sync_states:\n                        raise\n                    step_states = {}\n                    time.sleep(self._wait_seconds_to_sync_states)\n            return step_states\n\n        def _has_incomplete_steps(step_states: Dict[str, StepState]) -> bool:\n            \"\"\"\n            Are there any steps in the graph that are currently:\n            1) running, or\n            2) queued, or\n            3) incomplete (with 
no failed dependencies).\n\n            If there are any failed dependencies for a step, it will never manage to run.\n            \"\"\"\n\n            def _failed_dependencies(step: Step) -> bool:\n                for dependency in step.recursive_dependencies:\n                    if (\n                        step_states[dependency.name] == StepState.FAILED\n                        or dependency.name in _failed\n                    ):\n                        return True\n                return False\n\n            uncacheable_leaf_step_names = {step.name for step in uncacheable_leaf_steps}\n            for step_name, step_state in step_states.items():\n                if (\n                    step_name in _running\n                    or step_name in _queued_steps\n                    or (\n                        # If the workspace already has a previous run, we disregard the failure.\n                        step_state in [StepState.INCOMPLETE, StepState.FAILED]\n                        and not _failed_dependencies(step_graph[step_name])\n                        # We check for failures in this run.\n                        and step_name not in _failed\n                    )\n                    or (\n                        # Uncacheable leaf steps need to run, but their StepState will always be UNCACHEABLE.\n                        step_name in uncacheable_leaf_step_names\n                        and step_name not in _successful\n                        and step_name not in _failed\n                        and not _failed_dependencies(step_graph[step_name])\n                    )\n                ):\n                    return True\n            return False\n\n        def _update_running_steps(step_states: Dict[str, StepState]) -> List[str]:\n            \"\"\"\n            Check the running processes for status. 
We use poll_status to check if the process ended,\n            but the StepState for checking completion/failure status, because after the process ends,\n            the lock release etc. sometimes takes a beat longer.\n            \"\"\"\n            done = []\n            errors = []\n            for step_name, process in _running.items():\n                poll_status = process.poll()\n                if poll_status is not None:\n                    # The step may have finished since we synced step states.\n                    if step_states[step_name] == StepState.RUNNING:\n                        step_states[step_name] = self._get_state(step_graph[step_name])\n\n                    if step_states[step_name] == StepState.UNCACHEABLE:\n                        if poll_status == 0:\n                            done.append(step_name)\n                        else:\n                            errors.append(step_name)\n                    elif step_states[step_name] == StepState.COMPLETED:\n                        done.append(step_name)\n                    elif (\n                        step_states[step_name] == StepState.FAILED\n                        or step_states[step_name] == StepState.INCOMPLETE\n                    ):\n                        # TODO: look into why the step status changes from running back to incomplete sometimes.\n                        # Possibly it's due to the workspace being aggressive in marking it as incomplete when\n                        # it thinks that the process is not running.\n                        errors.append(step_name)\n                    else:\n                        raise RuntimeError(\n                            f\"Step '{step_name}' has the state {step_states[step_name]}, \"\n                            \"but the corresponding process has ended!\"\n                        )\n\n            for step_name in done + errors:\n                _running.pop(step_name)\n\n            for step_name in done:\n              
  step = step_graph[step_name]\n                _successful[step_name] = ExecutionMetadata(\n                    result_location=None\n                    if not step.cache_results\n                    else self.workspace.step_info(step).result_location\n                )\n\n            for step_name in errors:\n                _failed[step_name] = ExecutionMetadata()\n\n            return errors\n\n        def _get_steps_to_run(step_states: Dict[str, StepState]) -> Set[str]:\n            \"\"\"\n            Returns the steps that can be queued to run immediately.\n            Criteria:\n                1) All dependencies are available.\n                2) Step is not already running or queued.\n                3) Step has not run in the past and failed.\n                4) Step's state is INCOMPLETE (or FAILED from a previous run), or\n                   step's state is UNCACHEABLE and it is a leaf step.\n\n            (We only run uncacheable steps if they are needed for another step downstream,\n            as part of the downstream step).\n            \"\"\"\n\n            def _are_dependencies_available(step: Step) -> bool:\n                for dependency in step.dependencies:\n                    if step_states[dependency.name] not in [\n                        StepState.COMPLETED,\n                        StepState.UNCACHEABLE,\n                    ]:\n                        return False\n                return True\n\n            to_run: Set[str] = set()\n            for step in step_graph.values():\n                if (\n                    _are_dependencies_available(step)\n                    and step.name not in _running  # Not already running.\n                    and step.name not in _queued_steps  # Not queued to run.\n                    and step.name not in _failed  # Not already failed.\n                    # See comment in _has_incomplete_steps\n                    and (\n                        step_states[step.name] in [StepState.INCOMPLETE, 
StepState.FAILED]\n                        or (\n                            step_states[step.name] == StepState.UNCACHEABLE\n                            and step in uncacheable_leaf_steps\n                            and step.name not in _successful\n                        )\n                    )\n                ):\n                    to_run.add(step.name)\n            return to_run\n\n        def _queue_step(step_name: str) -> None:\n            _queued_steps.append(step_name)\n            logger.debug(f\"Step {step_name} added to the queue for execution.\")\n\n        def _try_to_execute_next_step(config_path: str, run_name: Optional[str] = None) -> None:\n            \"\"\"\n            If there are queued steps, try to start processes for them (limited by `parallelism`).\n            \"\"\"\n            if len(_queued_steps) == 0:\n                logger.debug(\"No steps in queue!\")\n                return\n            if len(_running) < (self.parallelism or 1):\n                step_name = _queued_steps.pop(0)\n                command: List[str] = [\n                    \"tango\",\n                    \"--called-by-executor\",\n                    \"run\",\n                    config_path,\n                    \"-s\",\n                    step_name,\n                    \"-w\",\n                    self.workspace.url,\n                ]\n                if self.include_package is not None:\n                    for package in self.include_package:\n                        command += [\"-i\", package]\n                if run_name is not None:\n                    command += [\"-n\", run_name]\n                process = subprocess.Popen(command, shell=False)\n                _running[step_name] = process\n            else:\n                logger.debug(\n                    f\"{self.parallelism or 1} steps are already running. Will attempt to execute later.\"\n                )\n\n        # Creates a temporary file in which to store the config. 
This is passed as a command line\n        # argument to child step processes.\n        with NamedTemporaryFile(prefix=\"step-graph-to-file-run\", suffix=\".jsonnet\") as file_ref:\n            step_graph.to_file(file_ref.name, include_unique_id=True)\n            assert os.path.exists(file_ref.name)\n\n            step_states = _sync_step_states()\n\n            while _has_incomplete_steps(step_states):\n                # Cleanup previously running steps.\n                _update_running_steps(step_states)\n\n                # Get steps that are ready to run.\n                to_run = _get_steps_to_run(step_states)\n                if to_run:\n                    logger.debug(f\"Steps ready to run: {to_run}\")\n\n                for step_name in to_run:\n                    _queue_step(step_name)\n\n                # Begin processes for any queued steps (if not enough processes are already running).\n                while len(_queued_steps) > 0 and len(_running) < (self.parallelism or 1):\n                    _try_to_execute_next_step(config_path=file_ref.name, run_name=run_name)\n\n                # Re-sync the StepState info.\n                step_states = _sync_step_states()\n\n        assert not _running and not _queued_steps\n        _not_run: Dict[str, ExecutionMetadata] = {}\n        for step_name, step in step_graph.items():\n            if step_name in _successful or step_name in _failed:\n                # tried to execute directly\n                continue\n            elif not step.cache_results and step not in uncacheable_leaf_steps:\n                # uncacheable interior step; didn't execute directly.\n                continue\n            elif (\n                step.cache_results\n                and step_name in step_states\n                and step_states[step_name] == StepState.COMPLETED\n            ):\n                # step result was found in cache.\n                # NOTE: since neither `Step.result()` nor `Step.ensure_result()` will have 
been\n                # called, we invoke the CLI logger here to let users know that we didn't run this\n                # step because we found it in the cache.\n                step.log_cache_hit()\n                _successful[step_name] = ExecutionMetadata(\n                    result_location=self.workspace.step_info(step_graph[step_name]).result_location\n                )\n            else:\n                # step wasn't executed because parents failed, or\n                # step is uncacheable leaf step, so we do care about what happened to it.\n                _not_run[step_name] = ExecutionMetadata()\n\n        return ExecutorOutput(successful=_successful, failed=_failed, not_run=_not_run)\n\n    def _get_state(self, step: Step) -> StepState:\n        \"\"\"\n        Returns the StepState as determined by the workspace.\n        \"\"\"\n        return self.workspace.step_info(step).state\n"
  },
  {
    "path": "tango/format.py",
    "content": "import bz2\nimport dataclasses\nimport gzip\nimport importlib\nimport json\nimport logging\nimport lzma\nfrom abc import abstractmethod\nfrom os import PathLike\nfrom pathlib import Path\nfrom typing import (\n    IO,\n    Any,\n    Callable,\n    Dict,\n    Generic,\n    Iterable,\n    Iterator,\n    List,\n    Optional,\n    Sequence,\n    TypeVar,\n    Union,\n    cast,\n)\n\nimport dill\n\nfrom tango.common import DatasetDict, filename_is_safe\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.registrable import Registrable\nfrom tango.common.sequences import SqliteSparseSequence\n\nT = TypeVar(\"T\")\n\n\nclass Format(Registrable, Generic[T]):\n    \"\"\"\n    Formats write objects to directories and read them back out.\n\n    In the context of Tango, the objects that are written by formats are usually\n    the result of a :class:`~tango.step.Step`.\n    \"\"\"\n\n    VERSION: str = NotImplemented\n    \"\"\"\n    Formats can have versions. 
Versions are part of a step's unique signature, part of\n    :attr:`~tango.step.Step.unique_id`, so when a step's format changes,\n    that will cause the step to be recomputed.\n    \"\"\"\n\n    default_implementation = \"dill\"\n\n    @abstractmethod\n    def write(self, artifact: T, dir: PathOrStr):\n        \"\"\"Writes the ``artifact`` to the directory at ``dir``.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def read(self, dir: PathOrStr) -> T:\n        \"\"\"Reads an artifact from the directory at ``dir`` and returns it.\"\"\"\n        raise NotImplementedError()\n\n    def _to_params(self) -> Dict[str, Any]:\n        params_dict = super()._to_params()\n        for key in [\"logger\", \"__orig_class__\"]:\n            params_dict.pop(key, None)  # Removing unnecessary keys.\n        params_dict[\"type\"] = self.__module__ + \".\" + self.__class__.__qualname__\n        return params_dict\n\n\n_OPEN_FUNCTIONS: Dict[Optional[str], Callable[[PathLike, str], IO]] = {\n    None: open,\n    \"None\": open,\n    \"none\": open,\n    \"null\": open,\n    \"gz\": gzip.open,  # type: ignore\n    \"gzip\": gzip.open,  # type: ignore\n    \"bz\": bz2.open,  # type: ignore\n    \"bz2\": bz2.open,  # type: ignore\n    \"bzip\": bz2.open,  # type: ignore\n    \"bzip2\": bz2.open,  # type: ignore\n    \"lzma\": lzma.open,\n}\n\n_SUFFIXES: Dict[Callable, str] = {\n    open: \"\",\n    gzip.open: \".gz\",\n    bz2.open: \".bz2\",\n    lzma.open: \".xz\",\n}\n\n\ndef _open_compressed(filename: PathOrStr, mode: str) -> IO:\n    open_fn: Callable\n    filename = str(filename)\n    for open_fn, suffix in _SUFFIXES.items():\n        if len(suffix) > 0 and filename.endswith(suffix):\n            break\n    else:\n        open_fn = open\n    return open_fn(filename, mode)\n\n\n@Format.register(\"dill\")\nclass DillFormat(Format[T], Generic[T]):\n    \"\"\"\n    This format writes the artifact as a single file called \"data.dill\" using dill\n    (a drop-in 
replacement for pickle). Optionally, it can compress the data.\n\n    This is very flexible, but not always the fastest.\n\n    .. tip::\n        This format has special support for iterables. If you write an iterator, it will consume the\n        iterator. If you read an iterator, it will read the iterator lazily.\n\n    \"\"\"\n\n    VERSION = \"001\"\n\n    def __init__(self, compress: Optional[str] = None):\n        if compress not in _OPEN_FUNCTIONS:\n            raise ConfigurationError(f\"The {compress} compression format does not exist.\")\n        self.compress = compress\n\n    def write(self, artifact: T, dir: PathOrStr):\n        filename = self._get_artifact_path(dir)\n        open_method = _OPEN_FUNCTIONS[self.compress]\n        with open_method(filename, \"wb\") as f:\n            pickler = dill.Pickler(file=f)\n            pickler.dump(self.VERSION)\n            if hasattr(artifact, \"__next__\"):\n                pickler.dump(True)\n                for item in cast(Iterable, artifact):\n                    pickler.dump(item)\n            else:\n                pickler.dump(False)\n                pickler.dump(artifact)\n\n    def read(self, dir: PathOrStr) -> T:\n        filename = self._get_artifact_path(dir)\n        open_method = _OPEN_FUNCTIONS[self.compress]\n        with open_method(filename, \"rb\") as f:\n            unpickler = dill.Unpickler(file=f)\n            version = unpickler.load()\n            if version > self.VERSION:\n                raise ValueError(\n                    f\"File {filename} is too recent for this version of {self.__class__}.\"\n                )\n            iterator = unpickler.load()\n            if iterator:\n                return DillFormatIterator(filename)  # type: ignore\n            else:\n                return unpickler.load()\n\n    def _get_artifact_path(self, dir: PathOrStr) -> Path:\n        return Path(dir) / (\"data.dill\" + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]])\n\n\nclass 
DillFormatIterator(Iterator[T], Generic[T]):\n    \"\"\"\n    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.DillFormat.read`.\n    \"\"\"\n\n    def __init__(self, filename: PathOrStr):\n        self.f: Optional[IO[Any]] = _open_compressed(filename, \"rb\")\n        self.unpickler = dill.Unpickler(self.f)\n        version = self.unpickler.load()\n        if version > DillFormat.VERSION:\n            raise ValueError(f\"File {filename} is too recent for this version of {self.__class__}.\")\n        iterator = self.unpickler.load()\n        if not iterator:\n            raise ValueError(\n                f\"Tried to open {filename} as an iterator, but it does not store an iterator.\"\n            )\n\n    def __iter__(self) -> Iterator[T]:\n        return self\n\n    def __next__(self) -> T:\n        if self.f is None:\n            raise StopIteration()\n        try:\n            return self.unpickler.load()\n        except EOFError:\n            self.f.close()\n            self.f = None\n            raise StopIteration()\n\n\n@Format.register(\"json\")\nclass JsonFormat(Format[T], Generic[T]):\n    \"\"\"This format writes the artifact as a single file in json format.\n    Optionally, it can compress the data. This is very flexible, but not always the fastest.\n\n    .. tip::\n        This format has special support for iterables. If you write an iterator, it will consume the\n        iterator. 
If you read an iterator, it will read the iterator lazily.\n    \"\"\"\n\n    VERSION = \"002\"\n\n    def __init__(self, compress: Optional[str] = None):\n        self.logger = logging.getLogger(self.__class__.__name__)\n        if compress not in _OPEN_FUNCTIONS:\n            raise ConfigurationError(f\"The {compress} compression format does not exist.\")\n        self.compress = compress\n\n    @staticmethod\n    def _encoding_fallback(unencodable: Any):\n        try:\n            import torch\n\n            if isinstance(unencodable, torch.Tensor):\n                if len(unencodable.shape) == 0:\n                    return unencodable.item()\n                else:\n                    raise TypeError(\n                        \"Tensors must have 1 element and no dimensions to be JSON serializable.\"\n                    )\n        except ImportError:\n            pass\n\n        if dataclasses.is_dataclass(unencodable):\n            result = dataclasses.asdict(unencodable)\n            module = type(unencodable).__module__\n            qualname = type(unencodable).__qualname__\n            if module == \"builtins\":\n                result[\"_dataclass\"] = qualname\n            else:\n                result[\"_dataclass\"] = [module, qualname]\n            return result\n\n        raise TypeError(f\"Object of type {type(unencodable)} is not JSON serializable\")\n\n    @staticmethod\n    def _decoding_fallback(o: Dict) -> Any:\n        if \"_dataclass\" in o:\n            classname: Union[str, List[str]] = o.pop(\"_dataclass\")\n            if isinstance(classname, list) and len(classname) == 2:\n                module, classname = classname\n                constructor: Callable = importlib.import_module(module)  # type: ignore\n                for item in classname.split(\".\"):\n                    constructor = getattr(constructor, item)\n            elif isinstance(classname, str):\n                constructor = globals()[classname]\n            else:\n   
             raise RuntimeError(f\"Could not parse {classname} as the name of a dataclass.\")\n            return constructor(**o)\n        return o\n\n    def write(self, artifact: T, dir: PathOrStr):\n        open_method = _OPEN_FUNCTIONS[self.compress]\n        if hasattr(artifact, \"__next__\"):\n            filename = self._get_artifact_path(dir, iterator=True)\n            with open_method(filename, \"wt\") as f:\n                for item in cast(Iterable, artifact):\n                    json.dump(item, f, default=self._encoding_fallback)\n                    f.write(\"\\n\")\n        else:\n            filename = self._get_artifact_path(dir, iterator=False)\n            with open_method(filename, \"wt\") as f:\n                json.dump(artifact, f, default=self._encoding_fallback)\n\n    def read(self, dir: PathOrStr) -> T:\n        iterator_filename = self._get_artifact_path(dir, iterator=True)\n        iterator_exists = iterator_filename.exists()\n        non_iterator_filename = self._get_artifact_path(dir, iterator=False)\n        non_iterator_exists = non_iterator_filename.exists()\n\n        if iterator_exists and non_iterator_exists:\n            self.logger.warning(\n                \"Both %s and %s exist. 
Ignoring %s.\",\n                iterator_filename,\n                non_iterator_filename,\n                iterator_filename,\n            )\n            iterator_exists = False\n\n        if not iterator_exists and not non_iterator_exists:\n            raise IOError(f\"Attempting to read non-existing data from {dir}\")\n        if iterator_exists and not non_iterator_exists:\n            return JsonFormatIterator(iterator_filename)  # type: ignore\n        elif not iterator_exists and non_iterator_exists:\n            open_method = _OPEN_FUNCTIONS[self.compress]\n            with open_method(non_iterator_filename, \"rt\") as f:\n                return json.load(f, object_hook=self._decoding_fallback)\n        else:\n            raise RuntimeError(\"This should be impossible.\")\n\n    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:\n        return Path(dir) / (\n            (\"data.jsonl\" if iterator else \"data.json\") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]\n        )\n\n\nclass JsonFormatIterator(Iterator[T], Generic[T]):\n    \"\"\"\n    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.JsonFormat.read`.\n    \"\"\"\n\n    def __init__(self, filename: PathOrStr):\n        self.f: Optional[IO[Any]] = _open_compressed(filename, \"rt\")\n\n    def __iter__(self) -> Iterator[T]:\n        return self\n\n    def __next__(self) -> T:\n        if self.f is None:\n            raise StopIteration()\n        try:\n            line = self.f.readline()\n            if len(line) <= 0:\n                raise EOFError()\n            return json.loads(line, object_hook=JsonFormat._decoding_fallback)\n        except EOFError:\n            self.f.close()\n            self.f = None\n            raise StopIteration()\n\n\n@Format.register(\"text\")\nclass TextFormat(Format[Union[str, Iterable[str]]]):\n    \"\"\"This format writes the artifact as a single file in text format.\n    Optionally, it can compress the 
data. This is very flexible, but not always the fastest.\n\n    This format can only write strings or iterables of strings.\n\n    .. tip::\n        This format has special support for iterables. If you write an iterator, it will consume the\n        iterator. If you read an iterator, it will read the iterator lazily.\n\n        Be aware that if your strings contain newlines, you will read out more strings than you wrote.\n        For this reason, it's often advisable to use :class:`JsonFormat` instead. With :class:`JsonFormat`,\n        all special characters are escaped, strings are quoted, but it's all still human-readable.\n    \"\"\"\n\n    VERSION = \"001\"\n\n    def __init__(self, compress: Optional[str] = None):\n        self.logger = logging.getLogger(self.__class__.__name__)\n        if compress not in _OPEN_FUNCTIONS:\n            raise ConfigurationError(f\"The {compress} compression format does not exist.\")\n        self.compress = compress\n\n    def write(self, artifact: Union[str, Iterable[str]], dir: PathOrStr):\n        open_method = _OPEN_FUNCTIONS[self.compress]\n        if hasattr(artifact, \"__next__\"):\n            filename = self._get_artifact_path(dir, iterator=True)\n            with open_method(filename, \"wt\") as f:\n                for item in cast(Iterable, artifact):\n                    f.write(str(item))\n                    f.write(\"\\n\")\n        else:\n            filename = self._get_artifact_path(dir, iterator=False)\n            with open_method(filename, \"wt\") as f:\n                f.write(str(artifact))\n\n    def read(self, dir: PathOrStr) -> Union[str, Iterable[str]]:\n        iterator_filename = self._get_artifact_path(dir, iterator=True)\n        iterator_exists = iterator_filename.exists()\n        non_iterator_filename = self._get_artifact_path(dir, iterator=False)\n        non_iterator_exists = non_iterator_filename.exists()\n\n        if iterator_exists and non_iterator_exists:\n            
self.logger.warning(\n                \"Both %s and %s exist. Ignoring %s.\",\n                iterator_filename,\n                non_iterator_filename,\n                iterator_filename,\n            )\n            iterator_exists = False\n\n        if not iterator_exists and not non_iterator_exists:\n            raise IOError(f\"Attempting to read non-existing data from {dir}\")\n        if iterator_exists and not non_iterator_exists:\n            return TextFormatIterator(iterator_filename)  # type: ignore\n        elif not iterator_exists and non_iterator_exists:\n            open_method = _OPEN_FUNCTIONS[self.compress]\n            with open_method(non_iterator_filename, \"rt\") as f:\n                return f.read()\n        else:\n            raise RuntimeError(\"This should be impossible.\")\n\n    def _get_artifact_path(self, dir: PathOrStr, iterator: bool = False) -> Path:\n        return Path(dir) / (\n            (\"texts.txt\" if iterator else \"text.txt\") + _SUFFIXES[_OPEN_FUNCTIONS[self.compress]]\n        )\n\n\nclass TextFormatIterator(Iterator[str]):\n    \"\"\"\n    An ``Iterator`` class that is used to return an iterator from :meth:`tango.format.TextFormat.read`.\n    \"\"\"\n\n    def __init__(self, filename: PathOrStr):\n        self.f: Optional[IO[Any]] = _open_compressed(filename, \"rt\")\n\n    def __iter__(self) -> Iterator[str]:\n        return self\n\n    def __next__(self) -> str:\n        if self.f is None:\n            raise StopIteration()\n        try:\n            line = self.f.readline()\n            if len(line) <= 0:\n                raise EOFError()\n            if line.endswith(\"\\n\"):\n                line = line[:-1]\n            return line\n        except EOFError:\n            self.f.close()\n            self.f = None\n            raise StopIteration()\n\n\n@Format.register(\"sqlite_sequence\")\nclass SqliteSequenceFormat(Format[Sequence[T]]):\n    VERSION = \"003\"\n\n    FILENAME = \"data.sqlite\"\n\n    def 
write(self, artifact: Sequence[T], dir: Union[str, PathLike]):\n        dir = Path(dir)\n        try:\n            (dir / self.FILENAME).unlink()\n        except FileNotFoundError:\n            pass\n        if isinstance(artifact, SqliteSparseSequence):\n            artifact.copy_to(dir / self.FILENAME)\n        else:\n            sqlite = SqliteSparseSequence(dir / self.FILENAME)\n            sqlite.extend(artifact)\n\n    def read(self, dir: Union[str, PathLike]) -> Sequence[T]:\n        dir = Path(dir)\n        return SqliteSparseSequence(dir / self.FILENAME, read_only=True)\n\n\n@Format.register(\"sqlite\")\nclass SqliteDictFormat(Format[DatasetDict]):\n    \"\"\"This format works specifically on results of type :class:`~tango.common.DatasetDict`. It writes those\n    datasets into Sqlite databases.\n\n    During reading, the advantage is that the dataset can be read lazily. Reading a result that is stored\n    in :class:`SqliteDictFormat` takes milliseconds. No actual reading takes place until you access individual\n    instances.\n\n    During writing, you have to take some care to take advantage of the same trick. Recall that\n    :class:`~tango.DatasetDict` is basically a map, mapping split names to lists of instances. If you ensure\n    that those lists of instances are of type :class:`~tango.common.sequences.SqliteSparseSequence`, then writing\n    the results in :class:`SqliteDictFormat` can in many cases be instantaneous.\n\n    Here is an example of the pattern to use to make writing fast:\n\n    .. code-block:: Python\n\n        @Step.register(\"my_step\")\n        class MyStep(Step[DatasetDict]):\n\n            FORMAT: Format = SqliteDictFormat()\n            VERSION = \"001\"\n\n            def run(self, ...) 
-> DatasetDict:\n                result: Dict[str, Sequence] = {}\n                for split_name in my_list_of_splits:\n                    output_split = SqliteSparseSequence(self.work_dir / f\"{split_name}.sqlite\")\n                    for instance in instances:\n                        output_split.append(instance)\n                    result[split_name] = output_split\n\n                metadata = {}\n                return DatasetDict(result, metadata)\n\n    Observe how for each split, we create a :class:`~tango.common.sequences.SqliteSparseSequence` in the step's\n    work directory (accessible with :meth:`~tango.step.Step.work_dir`). This has the added advantage that if the\n    step fails and you have to re-run it, the previous results that were already written to the\n    :class:`~tango.common.sequences.SqliteSparseSequence` are still there. You could replace the inner ``for``\n    loop like this to take advantage:\n\n    .. code-block:: Python\n\n        output_split = SqliteSparseSequence(self.work_dir / f\"{split_name}.sqlite\")\n        for instance in instances[len(output_split):]:      # <-- here is the difference\n            output_split.append(instance)\n        result[split_name] = output_split\n\n    This works because when you re-run the step, the work directory will still be there, so ``output_split`` is\n    not empty when you open it.\n    \"\"\"\n\n    VERSION = \"003\"\n\n    def write(self, artifact: DatasetDict, dir: Union[str, PathLike]):\n        dir = Path(dir)\n        with gzip.open(dir / \"metadata.dill.gz\", \"wb\") as f:\n            dill.dump(artifact.metadata, f)\n        for split_name, split in artifact.splits.items():\n            filename = f\"{split_name}.sqlite\"\n            if not filename_is_safe(filename):\n                raise ValueError(f\"{split_name} is not a valid name for a split.\")\n            try:\n                (dir / filename).unlink()\n            except FileNotFoundError:\n                pass\n    
        if isinstance(split, SqliteSparseSequence):\n                split.copy_to(dir / filename)\n            else:\n                sqlite = SqliteSparseSequence(dir / filename)\n                sqlite.extend(split)\n\n    def read(self, dir: Union[str, PathLike]) -> DatasetDict:\n        dir = Path(dir)\n        with gzip.open(dir / \"metadata.dill.gz\", \"rb\") as f:\n            metadata = dill.load(f)\n        splits = {\n            filename.stem: SqliteSparseSequence(filename, read_only=True)\n            for filename in dir.glob(\"*.sqlite\")\n        }\n        return DatasetDict(metadata=metadata, splits=splits)\n"
  },
  {
    "path": "tango/integrations/__init__.py",
    "content": "\"\"\"\nIn :mod:`tango.integrations` we provide many ready-to-use `component <../components/index.html>`_\nimplementations for leveraging the functionality from popular libraries.\n\n.. tip::\n    All registered components will be registered under a name that starts with the name of the integration module,\n    possibly followed by a double colon (\"::\") and another identifier if there are multiple registered\n    components of a given type.\n\n    For example, the :class:`~tango.integrations.datasets.LoadDataset` step in the `🤗 Datasets <datasets.html>`_\n    integration is registered under the name \"datasets::load\", and the\n    :class:`~tango.integrations.torch.TorchFormat` format in the `PyTorch <torch.html>`_ integration\n    is registered under the name \"torch\".\n\n\"\"\"\n"
  },
  {
    "path": "tango/integrations/beaker/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"beaker\" extra\n    (e.g. ``pip install tango[beaker]``) or just install the `beaker-py <https://beaker-py.readthedocs.io>`_\n    library after the fact (e.g. ``pip install beaker-py``).\n\nComponents for Tango integration with `Beaker <https://beaker.org/>`_.\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    from beaker import Beaker\nexcept (ModuleNotFoundError, ImportError):\n    raise IntegrationMissingError(\"beaker\", dependencies={\"beaker-py\"})\n\nfrom .executor import (\n    BeakerExecutor,\n    BeakerScheduler,\n    ResourceAssignment,\n    ResourceAssignmentError,\n    SimpleBeakerScheduler,\n    UnrecoverableResourceAssignmentError,\n)\nfrom .step_cache import BeakerStepCache\nfrom .workspace import BeakerWorkspace\n\n__all__ = [\n    \"BeakerStepCache\",\n    \"BeakerWorkspace\",\n    \"BeakerExecutor\",\n    \"BeakerScheduler\",\n    \"SimpleBeakerScheduler\",\n    \"ResourceAssignment\",\n    \"ResourceAssignmentError\",\n    \"UnrecoverableResourceAssignmentError\",\n]\n"
  },
  {
    "path": "tango/integrations/beaker/common.py",
    "content": "import atexit\nimport json\nimport logging\nimport os.path\nimport tempfile\nimport time\nimport urllib\nimport urllib.parse\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nfrom beaker import Beaker\nfrom beaker import Dataset as BeakerDataset\nfrom beaker import DatasetConflict, DatasetNotFound, Experiment, ExperimentNotFound\n\nfrom tango.common.remote_utils import RemoteConstants\nfrom tango.step import Step\nfrom tango.step_info import StepInfo\nfrom tango.version import VERSION\n\nlogger = logging.getLogger(__name__)\n\n\nclass Constants(RemoteConstants):\n    ENTRYPOINT_DATASET_PREFIX = \"tango-entrypoint-\"\n    BEAKER_TOKEN_SECRET_NAME: str = \"BEAKER_TOKEN\"\n    GOOGLE_TOKEN_SECRET_NAME: str = \"GOOGLE_TOKEN\"\n    DEFAULT_GOOGLE_CREDENTIALS_FILE: str = os.path.expanduser(\n        os.path.join(\"~\", \".config\", \"gcloud\", \"application_default_credentials.json\")\n    )\n    ENTRYPOINT_DIR: str = \"/tango/entrypoint\"\n    ENTRYPOINT_FILENAME: str = \"entrypoint.sh\"\n\n\ndef get_client(beaker_workspace: Optional[str] = None, **kwargs) -> Beaker:\n    user_agent = f\"tango v{VERSION}\"\n    if beaker_workspace is not None:\n        return Beaker.from_env(\n            default_workspace=beaker_workspace,\n            session=True,\n            user_agent=user_agent,\n            **kwargs,\n        )\n    else:\n        return Beaker.from_env(session=True, user_agent=user_agent, **kwargs)\n\n\ndef dataset_url(beaker: Beaker, dataset: Optional[str] = None) -> str:\n    # this just creates a string url.\n    workspace_url = beaker.workspace.url()\n    if dataset:\n        return (\n            workspace_url\n            + \"/datasets?\"\n            + urllib.parse.urlencode(\n                {\n                    \"text\": dataset,\n                    \"committed\": \"false\",\n                }\n            )\n        )\n    return workspace_url\n\n\nclass BeakerStepLock:\n    METADATA_FNAME = 
\"metadata.json\"\n\n    def __init__(\n        self,\n        beaker: Beaker,\n        step: Union[str, StepInfo, Step],\n        current_beaker_experiment: Optional[Experiment] = None,\n    ):\n        self._beaker = beaker\n        self._step_id = step if isinstance(step, str) else step.unique_id\n        self._lock_dataset_name = RemoteConstants.step_lock_artifact_name(step)\n        self._lock_dataset: Optional[BeakerDataset] = None\n        self._current_beaker_experiment = current_beaker_experiment\n        self.lock_dataset_url = dataset_url(beaker, self._lock_dataset_name)\n\n    @property\n    def metadata(self) -> Dict[str, Any]:\n        return {\n            \"beaker_experiment\": None\n            if not self._current_beaker_experiment\n            else self._current_beaker_experiment.id\n        }\n\n    def _last_metadata(self) -> Optional[Dict[str, Any]]:\n        try:\n            metadata_bytes = self._beaker.dataset.get_file(\n                self._lock_dataset_name, self.METADATA_FNAME, quiet=True\n            )\n            metadata = json.loads(metadata_bytes)\n            return metadata\n        except (DatasetNotFound, FileNotFoundError):\n            return None\n\n    def _acquiring_job_is_done(self) -> bool:\n        last_metadata = self._last_metadata()\n        if last_metadata is None:\n            return False\n\n        last_experiment_id = last_metadata.get(\"beaker_experiment\")\n        if last_experiment_id is None:\n            return False\n\n        try:\n            last_experiment = self._beaker.experiment.get(last_experiment_id)\n            if (\n                self._current_beaker_experiment is not None\n                and self._current_beaker_experiment.id == last_experiment_id\n            ):\n                # This means a previous job for this experiment was preempted and\n                # it didn't clean up after itself.\n                return True\n            else:\n                job = 
self._beaker.experiment.latest_job(last_experiment)\n                return False if job is None else job.is_done\n        except ExperimentNotFound:\n            # Experiment must have been deleted.\n            return True\n        except ValueError:\n            return False\n\n    def acquire(self, timeout=None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:\n        if self._lock_dataset is not None:\n            return\n        start = time.monotonic()\n        last_logged = None\n        while timeout is None or (time.monotonic() - start < timeout):\n            try:\n                self._lock_dataset = self._beaker.dataset.create(\n                    self._lock_dataset_name, commit=False\n                )\n\n                atexit.register(self.release)\n\n                # Write metadata.\n                with tempfile.TemporaryDirectory() as tmp_dir_name:\n                    tmp_dir = Path(tmp_dir_name)\n                    metadata_path = tmp_dir / self.METADATA_FNAME\n                    with open(metadata_path, \"w\") as f:\n                        json.dump(self.metadata, f)\n                    self._beaker.dataset.sync(self._lock_dataset, metadata_path, quiet=True)\n            except DatasetConflict:\n                # Check if existing lock was created from a Beaker experiment.\n                # If it was, and the experiment is no longer running, we can safely\n                # delete it.\n                if self._acquiring_job_is_done():\n                    self._beaker.dataset.delete(self._lock_dataset_name)\n                    continue\n\n                now = time.monotonic()\n                if last_logged is None or now - last_logged >= log_interval:\n                    logger.warning(\n                        \"Waiting to acquire lock dataset for step '%s':\\n\\n%s\\n\\n\"\n                        \"This probably means the step is being run elsewhere, but if you're sure it isn't \"\n                        \"you 
can just delete the lock dataset.\",\n                        self._step_id,\n                        self.lock_dataset_url,\n                    )\n                    last_logged = now\n                time.sleep(poll_interval)\n                continue\n            else:\n                break\n        else:\n            raise TimeoutError(\n                f\"Timeout error occurred while waiting to acquire dataset lock for step '{self._step_id}':\\n\\n\"\n                f\"{self.lock_dataset_url}\\n\\n\"\n                f\"This probably means the step is being run elsewhere, but if you're sure it isn't you can \"\n                f\"just delete the lock dataset.\"\n            )\n\n    def release(self):\n        if self._lock_dataset is not None:\n            try:\n                self._beaker.dataset.delete(self._lock_dataset)\n            except DatasetNotFound:\n                # Dataset must have been manually deleted.\n                pass\n            self._lock_dataset = None\n            atexit.unregister(self.release)\n\n    def __del__(self):\n        self.release()\n"
  },
  {
    "path": "tango/integrations/beaker/entrypoint.sh",
    "content": "#!/bin/bash\n#\n# This is the entrypoint script that the Beaker Executor uses when it runs a step\n# on Beaker.\n# It will work on any Docker image that has bash and conda / miniconda installed.\n\nset -eo pipefail\n\n# Ensure we have all the environment variables we need.\nfor env_var in \"$GITHUB_TOKEN\" \"$GITHUB_REPO\" \"$GIT_REF\"; do\n    if [[ -z \"$env_var\" ]]; then\n        echo >&2 \"error: required environment variable is empty\"\n        exit 1\n    fi\ndone\n\n# Initialize conda for bash.\n# See https://stackoverflow.com/a/58081608/4151392\neval \"$(command conda 'shell.bash' 'hook' 2> /dev/null)\"\n\necho \"\n[TANGO] [1/3] Installing prerequisites...\n\"\n\n# Install GitHub CLI.\nif ! command -v gh &> /dev/null; then\n    conda install gh --channel conda-forge\nfi\n\n# Configure git to use GitHub CLI as a credential helper so that we can clone private repos.\ngh auth setup-git\n\necho \"\n[TANGO] [2/3] Cloning source code from '$GITHUB_REPO'...\n\"\n\n# Clone the repo and checkout the target commit.\ngh repo clone \"$GITHUB_REPO\" src\ncd src\ngit checkout \"$GIT_REF\"\n\necho \"\n[TANGO] [3/3] Reconstructing Python env...\n\"\n\nif [[ -z \"$VENV_NAME\" ]]; then\n    VENV_NAME=venv\nfi\nif [[ -z \"$CONDA_ENV_FILE\" ]]; then\n    # shellcheck disable=SC2296\n    CONDA_ENV_FILE=\"environment.yml\"\nfi\nif [[ -z \"$PIP_REQUIREMENTS_FILE\" ]]; then\n    # shellcheck disable=SC2296\n    PIP_REQUIREMENTS_FILE=\"requirements.txt\"\nfi\n\nif conda activate $VENV_NAME &>/dev/null; then\n    echo \"[TANGO] Using existing conda environment '$VENV_NAME'\"\n    # The virtual environment already exists. Possibly update it based on an environment file.\n    if [[ -f \"$CONDA_ENV_FILE\" ]]; then\n        echo \"[TANGO] Updating environment from conda env file '$CONDA_ENV_FILE'...\"\n        conda env update -f \"$CONDA_ENV_FILE\"\n    fi\nelse\n    # The virtual environment doesn't exist yet. 
Create it.\n    if [[ -f \"$CONDA_ENV_FILE\" ]]; then\n        # Create from the environment file.\n        echo \"[TANGO] Initializing environment from conda env file '$CONDA_ENV_FILE'...\"\n        conda env create -n \"$VENV_NAME\" -f \"$CONDA_ENV_FILE\"\n    elif [[ -z \"$PYTHON_VERSION\" ]]; then\n        # Create a new empty environment with whatever the default Python version is.\n        echo \"[TANGO] Initializing environment with default Python version...\"\n        conda create -n \"$VENV_NAME\" pip\n    else\n        # Create a new empty environment with the specific Python version.\n        echo \"[TANGO] Initializing environment with Python $PYTHON_VERSION...\"\n        conda create -n \"$VENV_NAME\" \"python=$PYTHON_VERSION\" pip\n    fi\n    conda activate \"$VENV_NAME\"\nfi\n\n# Every time Beaker changes their APIs, we need to upgrade beaker-py. This happens all the\n# time, so we make sure we have the latest.\n# We do this when the conda environment is up, but before the requirements, so that\n# requirements can request a particular beaker-py version if they want.\npip install --upgrade beaker-py\n\nif [[ -z \"$INSTALL_CMD\" ]]; then\n    # Check for a 'requirements.txt' and/or 'setup.py/pyproject.toml/setup.cfg' file.\n    if ( [[ -f 'setup.py' ]] || [[ -f 'pyproject.toml' ]] || [[ -f 'setup.cfg' ]] ) && [[ -f \"$PIP_REQUIREMENTS_FILE\" ]]; then\n        echo \"[TANGO] Installing local project and packages from '$PIP_REQUIREMENTS_FILE'...\"\n        pip install . 
-r \"$PIP_REQUIREMENTS_FILE\"\n    elif ( [[ -f 'setup.py' ]] || [[ -f 'pyproject.toml' ]] || [[ -f 'setup.cfg' ]] ); then\n        echo \"[TANGO] Installing local project...\"\n        pip install .\n    elif [[ -f \"$PIP_REQUIREMENTS_FILE\" ]]; then\n        echo \"[TANGO] Installing packages from '$PIP_REQUIREMENTS_FILE'...\"\n        pip install -r \"$PIP_REQUIREMENTS_FILE\"\n    fi\nelse\n    echo \"[TANGO] Installing packages with given command: $INSTALL_CMD\"\n    eval \"$INSTALL_CMD\"\nfi\n\nPYTHONPATH=\"$(pwd)\"\nexport PYTHONPATH\n\necho \"\nEnvironment info:\n\"\n\necho \"Using $(python --version) from $(which python)\"\necho \"Packages:\"\nif which sed >/dev/null; then\n    pip freeze | sed 's/^/- /'\nelse\n    pip freeze\nfi\n\necho \"\n[TANGO] Setup complete ✓\n\"\n\n# Execute the arguments to this script as commands themselves.\nexec \"$@\"\n"
  },
  {
    "path": "tango/integrations/beaker/executor.py",
    "content": "import json\nimport logging\nimport os\nimport threading\nimport time\nimport uuid\nimport warnings\nfrom abc import abstractmethod\nfrom typing import Dict, List, NamedTuple, Optional, Sequence, Set, Tuple, Union\n\nfrom beaker import (\n    Beaker,\n    DataMount,\n    Dataset,\n    DatasetConflict,\n    DatasetNotFound,\n    Digest,\n    EnvVar,\n    Experiment,\n    ExperimentNotFound,\n    ExperimentSpec,\n    JobFailedError,\n    JobTimeoutError,\n    NodeResources,\n    Priority,\n    TaskResources,\n    TaskSpec,\n    TaskStoppedError,\n)\nfrom git import Git, GitCommandError, InvalidGitRepositoryError, Repo\n\nfrom tango.common.exceptions import (\n    CancellationError,\n    ConfigurationError,\n    ExecutorError,\n    RunCancelled,\n)\nfrom tango.common.logging import cli_logger, log_exception\nfrom tango.common.registrable import Registrable\nfrom tango.executor import ExecutionMetadata, Executor, ExecutorOutput\nfrom tango.step import Step\nfrom tango.step_graph import StepGraph\nfrom tango.step_info import GitMetadata\nfrom tango.version import VERSION\nfrom tango.workspace import Workspace\n\nfrom .common import Constants, get_client\n\nlogger = logging.getLogger(__name__)\n\n\nclass StepFailedError(ExecutorError):\n    def __init__(self, msg: str, experiment_url: str):\n        super().__init__(msg)\n        self.experiment_url = experiment_url\n\n\nclass ResourceAssignmentError(ExecutorError):\n    \"\"\"\n    Raised when a scheduler can't find enough free resources at the moment to run a step.\n    \"\"\"\n\n\nclass UnrecoverableResourceAssignmentError(ExecutorError):\n    \"\"\"\n    An unrecoverable version of :class:`ResourceAssignmentError`. 
Raising this\n    from a :class:`BeakerScheduler` will cause the executor to fail.\n    \"\"\"\n\n\nclass ResourceAssignment(NamedTuple):\n    \"\"\"\n    Resources assigned to a step.\n    \"\"\"\n\n    cluster: Union[str, List[str]]\n    \"\"\"\n    The cluster(s) to use to execute the step.\n    \"\"\"\n\n    resources: TaskResources\n    \"\"\"\n    The compute resources on the cluster to allocate for execution of the step.\n    \"\"\"\n\n    priority: Union[str, Priority]\n    \"\"\"\n    The priority to execute the step with.\n    \"\"\"\n\n\nclass BeakerScheduler(Registrable):\n    \"\"\"\n    A :class:`BeakerScheduler` is responsible for determining which resources and priority to\n    assign to the execution of a step.\n    \"\"\"\n\n    default_implementation = \"simple\"\n    \"\"\"\n    The default implementation is :class:`SimpleBeakerScheduler`.\n    \"\"\"\n\n    def __init__(self):\n        self._beaker: Optional[Beaker] = None\n\n    @property\n    def beaker(self) -> Beaker:\n        if self._beaker is None:\n            raise ValueError(\"'beaker' client has not been assigned to scheduler yet!\")\n        return self._beaker\n\n    @beaker.setter\n    def beaker(self, beaker: Beaker) -> None:\n        self._beaker = beaker\n\n    @abstractmethod\n    def schedule(self, step: Step) -> ResourceAssignment:\n        \"\"\"\n        Determine the :class:`ResourceAssignment` for a step.\n\n        :raises ResourceAssignmentError: If the scheduler can't find enough free\n            resources at the moment to run the step.\n        \"\"\"\n        raise NotImplementedError()\n\n\n@BeakerScheduler.register(\"simple\")\nclass SimpleBeakerScheduler(BeakerScheduler):\n    \"\"\"\n    The :class:`SimpleBeakerScheduler` just searches the given clusters for one\n    with enough resources to match what's specified by the step's required resources.\n    \"\"\"\n\n    def __init__(self, clusters: List[str], priority: Union[str, Priority]):\n        
super().__init__()\n        self.clusters = clusters\n        self.priority = priority\n        self._node_resources: Optional[Dict[str, List[NodeResources]]] = None\n        if not self.clusters:\n            raise ConfigurationError(\"At least one cluster is required in 'clusters'\")\n\n    @property\n    def node_resources(self) -> Dict[str, List[NodeResources]]:\n        if self._node_resources is None:\n            node_resources = {\n                cluster: [node.limits for node in self.beaker.cluster.nodes(cluster)]\n                for cluster in self.clusters\n            }\n            self._node_resources = node_resources\n            return node_resources\n        else:\n            return self._node_resources\n\n    def schedule(self, step: Step) -> ResourceAssignment:\n        step_resources = step.resources\n        task_resources = TaskResources(\n            cpu_count=step_resources.cpu_count,\n            gpu_count=step_resources.gpu_count,\n            memory=step_resources.memory,\n            shared_memory=step_resources.shared_memory,\n        )\n        clusters = self.clusters\n        if step_resources.gpu_type is not None:\n            clusters = [\n                cluster\n                for cluster, nodes in self.node_resources.items()\n                if all([node.gpu_type == step_resources.gpu_type for node in nodes])\n            ]\n            if not clusters:\n                raise UnrecoverableResourceAssignmentError(\n                    f\"Could not find cluster with nodes that have GPU type '{step_resources.gpu_type}'\"\n                )\n        return ResourceAssignment(\n            cluster=clusters, resources=task_resources, priority=self.priority\n        )\n\n\n@Executor.register(\"beaker\")\nclass BeakerExecutor(Executor):\n    \"\"\"\n    This is a :class:`~tango.executor.Executor` that runs steps on `Beaker`_.\n    Each step is run as its own Beaker experiment.\n\n    .. 
tip::\n        Registered as an :class:`~tango.executor.Executor` under the name \"beaker\".\n\n    .. important::\n        The :class:`BeakerExecutor` requires that you run Tango within a GitHub repository and you push\n        all of your changes prior to each ``tango run`` call. It also requires that you have\n        a `GitHub personal access token <https://github.com/settings/tokens/new>`_\n        with at least the \"repo\" scope set to the environment variable ``GITHUB_TOKEN``\n        (you can also set it using the ``github_token`` parameter, see below).\n\n        This is because :class:`BeakerExecutor` has to be able to clone your code from Beaker.\n\n    .. important::\n        The :class:`BeakerExecutor` will try to recreate your Python environment on Beaker\n        every time a step is run, so it's important that you specify all of your dependencies\n        in a PIP ``requirements.txt`` file, ``setup.py`` file, or a conda ``environment.yml`` file.\n        Alternatively you could provide the ``install_cmd`` argument.\n\n    .. important::\n        The :class:`BeakerExecutor` takes no responsibility for saving the results of steps that\n        it runs on Beaker. That's the job of your workspace. So make sure you're using the\n        right type of workspace or your results will be lost.\n\n        For example, any \"remote\" workspace (like the :class:`BeakerWorkspace`) would work,\n        or in some cases you could use a :class:`~tango.workspaces.LocalWorkspace` on an NFS drive.\n\n    .. important::\n        If you're running a step that requires special hardware, e.g. 
a GPU, you should\n        specify that in the ``step_resources`` parameter to the step, or by overriding\n        the step's :meth:`.resources() <tango.step.Step.resources>` property method.\n\n    :param workspace: The :class:`~tango.workspace.Workspace` to use.\n    :param clusters: A list of Beaker clusters that the executor may use to run steps.\n        If ``scheduler`` is specified, this argument is ignored.\n    :param include_package: A list of Python packages to import before running steps.\n    :param beaker_workspace: The name or ID of the Beaker workspace to use.\n    :param github_token: You can use this parameter to set a GitHub personal access token instead of using\n        the ``GITHUB_TOKEN`` environment variable.\n    :param google_token: You can use this parameter to set a Google Cloud token instead of using\n        the ``GOOGLE_TOKEN`` environment variable.\n    :param beaker_image: The name or ID of a Beaker image to use for running steps on Beaker.\n        The image must come with bash and `conda <https://docs.conda.io/en/latest/index.html>`_\n        installed (Miniconda is okay).\n        This is mutually exclusive with the ``docker_image`` parameter. If neither ``beaker_image``\n        nor ``docker_image`` is specified, the :data:`DEFAULT_BEAKER_IMAGE` will be used.\n    :param docker_image: The name of a publicly-available Docker image to use for running\n        steps on Beaker. The image must come with bash and `conda <https://docs.conda.io/en/latest/index.html>`_\n        installed (Miniconda is okay).\n        This is mutually exclusive with the ``beaker_image`` parameter.\n    :param datasets: External data sources to mount into the Beaker job for each step. 
You could use\n        this to mount an NFS drive, for example.\n    :param env_vars: Environment variables to set in the Beaker job for each step.\n    :param venv_name: The name of the conda virtual environment to use or create on the image.\n        If you're using your own image that already has a conda environment you want to use,\n        you should set this variable to the name of that environment.\n        You can also set this to \"base\" to use the base environment.\n    :param parallelism: Control the maximum number of steps run in parallel on Beaker.\n    :param install_cmd: Override the command used to install your code and its dependencies\n        in each Beaker job.\n        For example, you could set ``install_cmd=\"pip install .[dev]\"``.\n    :param priority: The default task priority to assign to jobs run on Beaker.\n        If ``scheduler`` is specified, this argument is ignored.\n    :param scheduler: A :class:`BeakerScheduler` to use for assigning resources to steps.\n        If not specified the :class:`SimpleBeakerScheduler` is used with the given\n        ``clusters`` and ``priority``.\n    :param allow_dirty: By default, the Beaker Executor requires that your git working directory has no uncommitted\n        changes. If you set this to ``True``, we skip this check.\n    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() <beaker.Beaker.from_env()>`.\n\n    .. attention::\n        Certain parameters should not be included in the :data:`~tango.settings.TangoGlobalSettings.executor`\n        part of your ``tango.yml`` file, namely ``workspace`` and ``include_package``.\n        Instead use the top-level :data:`~tango.settings.TangoGlobalSettings.workspace`\n        and :data:`~tango.settings.TangoGlobalSettings.include_package` fields, respectively.\n\n    :examples:\n\n    **Minimal tango.yml file**\n\n    You can use this executor by specifying it in your ``tango.yml`` settings file:\n\n    .. 
code:: yaml\n\n        executor:\n          type: beaker\n          beaker_workspace: ai2/my-workspace\n          clusters:\n            - ai2/general-cirrascale\n\n    **Using GPUs**\n\n    If you have a step that requires a GPU, there are two things you need to do:\n\n    1. First, you'll need to ensure that the :class:`BeakerExecutor` can install your dependencies the right way\n    to support the GPU hardware. There are usually two ways to do this: use a Docker image that comes\n    with a proper installation of your hardware-specific dependencies (e.g. PyTorch), or add a conda\n    ``environment.yml`` file to your project that specifies the proper version of those dependencies.\n\n    If you go with the first option you don't necessarily need to build your own Docker image.\n    If PyTorch is the only hardware-specific dependency you have, you could just use\n    one of AI2's pre-built PyTorch images. Just add these lines to your ``tango.yml`` file:\n\n    .. code:: diff\n\n         executor:\n           type: beaker\n           beaker_workspace: ai2/my-workspace\n        +  docker_image: ghcr.io/allenai/pytorch:1.12.0-cuda11.3-python3.9\n        +  venv_name: base\n           clusters:\n             - ai2/general-cirrascale\n\n    The ``venv_name: base`` line tells the :class:`BeakerExecutor` to use the existing\n    conda environment called \"base\" on the image instead of creating a new one.\n\n    Alternatively, you could use the :data:`default image <DEFAULT_BEAKER_IMAGE>`\n    and just add a conda ``environment.yml`` file to the root of your project\n    that looks like this:\n\n    .. code:: yaml\n\n        name: torch-env\n        channels:\n          - pytorch\n        dependencies:\n          - python=3.9\n          - cudatoolkit=11.3\n          - numpy\n          - pytorch\n          - ...\n\n    2. 
And second, you'll need to specify the GPUs required by each step in the config for that step under\n    the :class:`step_resources <tango.step.StepResources>` parameter. For example,\n\n    .. code:: json\n\n        \"steps\": {\n            \"train\": {\n                \"type\": \"torch::train\",\n                \"step_resources\": {\n                    \"gpu_count\": 1\n                }\n            }\n        }\n\n    \"\"\"\n\n    DEFAULT_BEAKER_IMAGE: str = \"ai2/conda\"\n    \"\"\"\n    The default image. Used if neither ``beaker_image`` nor ``docker_image`` are set.\n    \"\"\"\n\n    DEFAULT_NFS_DRIVE = \"/net/nfs.cirrascale\"\n\n    RESOURCE_ASSIGNMENT_WARNING_INTERVAL = 60 * 5\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        clusters: Optional[List[str]] = None,\n        include_package: Optional[Sequence[str]] = None,\n        beaker_workspace: Optional[str] = None,\n        github_token: Optional[str] = None,\n        google_token: Optional[str] = None,\n        beaker_image: Optional[str] = None,\n        docker_image: Optional[str] = None,\n        datasets: Optional[List[DataMount]] = None,\n        env_vars: Optional[List[EnvVar]] = None,\n        venv_name: Optional[str] = None,\n        parallelism: Optional[int] = None,\n        install_cmd: Optional[str] = None,\n        priority: Optional[Union[str, Priority]] = None,\n        allow_dirty: bool = False,\n        scheduler: Optional[BeakerScheduler] = None,\n        budget: Optional[str] = None,\n        **kwargs,\n    ):\n        # Pre-validate arguments.\n        if beaker_image is None and docker_image is None:\n            beaker_image = self.DEFAULT_BEAKER_IMAGE\n        elif (beaker_image is None) == (docker_image is None):\n            raise ConfigurationError(\n                \"Either 'beaker_image' or 'docker_image' must be specified for BeakerExecutor, but not both.\"\n            )\n\n        if budget is None:\n            raise 
ConfigurationError(\"You must specify a budget to use the beaker executor.\")\n        else:\n            self._budget = budget\n\n        from tango.workspaces import LocalWorkspace, MemoryWorkspace\n\n        if isinstance(workspace, MemoryWorkspace):\n            raise ConfigurationError(\n                \"You cannot use the `MemoryWorkspace` with the `BeakerExecutor`! \"\n                \"Please specify a different workspace.\"\n            )\n        elif isinstance(workspace, LocalWorkspace):\n            if str(workspace.dir).startswith(self.DEFAULT_NFS_DRIVE):\n                # Mount the NFS drive if it is not mounted already.\n                datasets = datasets or []\n                if not datasets or not any(\n                    [\n                        dm.source.host_path is not None\n                        and dm.source.host_path.startswith(self.DEFAULT_NFS_DRIVE)\n                        for dm in datasets\n                    ]\n                ):\n                    nfs_mount = DataMount.new(\n                        self.DEFAULT_NFS_DRIVE, host_path=self.DEFAULT_NFS_DRIVE\n                    )\n                    datasets.append(nfs_mount)\n            else:\n                warnings.warn(\n                    \"It appears that you're using a `LocalWorkspace` on a directory that is not an NFS drive. 
\"\n                    \"If the `BeakerExecutor` cannot access this directory from Beaker, your results will be lost.\",\n                    UserWarning,\n                )\n\n        super().__init__(workspace, include_package=include_package, parallelism=parallelism)\n\n        self.max_thread_workers = self.parallelism or min(32, (os.cpu_count() or 1) + 4)\n        self.beaker = get_client(beaker_workspace=beaker_workspace, **kwargs)\n        self.beaker_image = beaker_image\n        self.docker_image = docker_image\n        self.datasets = datasets\n        self.env_vars = env_vars\n        self.venv_name = venv_name\n        self.install_cmd = install_cmd\n        self.allow_dirty = allow_dirty\n        self.scheduler: BeakerScheduler\n        if scheduler is None:\n            if clusters is None:\n                raise ConfigurationError(\n                    \"Either 'scheduler' or 'clusters' argument to BeakerExecutor is required\"\n                )\n            self.scheduler = SimpleBeakerScheduler(clusters, priority=priority or Priority.normal)\n        else:\n            if clusters is not None:\n                warnings.warn(\n                    \"The 'clusters' parameter will be ignored since you specified a 'scheduler'\",\n                    UserWarning,\n                )\n            if priority is not None:\n                warnings.warn(\n                    \"The 'priority' parameter will be ignored since you specified a 'scheduler'\",\n                    UserWarning,\n                )\n            self.scheduler = scheduler\n        self.scheduler.beaker = self.beaker\n\n        self._is_cancelled = threading.Event()\n        self._logged_git_info = False\n        self._last_resource_assignment_warning: Optional[float] = None\n        self._jobs = 0\n\n        try:\n            self.github_token: str = github_token or os.environ[\"GITHUB_TOKEN\"]\n        except KeyError:\n            raise ConfigurationError(\n                \"A 
GitHub personal access token with the 'repo' scope is required. \"\n                \"This can be set with the 'github_token' argument to the BeakerExecutor, \"\n                \"or as the environment variable 'GITHUB_TOKEN'.\"\n            )\n\n        self.google_token = google_token or os.environ.get(\"GOOGLE_TOKEN\")\n\n        # Check if google auth credentials are in the default location\n        if self.google_token is None and os.path.exists(Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE):\n            self.google_token = Constants.DEFAULT_GOOGLE_CREDENTIALS_FILE\n\n        # If credentials are provided in the form of a file path, load the credentials\n        # so that they can be used in beaker. Do this only if required, i.e., only if GSWorkspace\n        # is being used.\n        if self.google_token is not None and self.google_token.endswith(\".json\"):\n            from tango.integrations.gs import GSWorkspace\n\n            if isinstance(workspace, GSWorkspace):\n                with open(self.google_token) as f:\n                    self.google_token = f.read()\n\n        if self.google_token is None:\n            self.google_token = \"default\"\n\n        # Ensure entrypoint dataset exists.\n        self._ensure_entrypoint_dataset()\n\n        # Get repo info and make sure we're in a GitHub-hosted repository.\n        git = GitMetadata.check_for_repo()\n        if (\n            git is None\n            or git.commit is None\n            or git.remote is None\n            or \"github.com\" not in git.remote\n        ):\n            raise ExecutorError(\n                f\"Missing git data. 
\"\n                f\"BeakerExecutor requires a git repository with a GitHub remote.\"\n            )\n        self._github_account, self._github_repo = self._parse_git_remote(git.remote)\n        self._git_commit = git.commit\n\n    def check_repo_state(self):\n        if not self.allow_dirty:\n            # Make sure repository is clean, if we're in one.\n            try:\n                # Check for uncommitted changes.\n                repo = Repo(\".\")\n                if repo.is_dirty():\n                    raise ExecutorError(\n                        \"You have uncommitted changes! Commit your changes or use the 'allow_dirty' option.\"\n                    )\n\n                # Check for un-pushed commits.\n                remote_name = repo.remote().name\n                git = Git(\".\")\n                if git.log([f\"{remote_name}..HEAD\", \"--not\", \"--remotes\", \"--oneline\"]):\n                    raise ExecutorError(\n                        \"You have unpushed changes! Push your changes or use the 'allow_dirty' option.\"\n                    )\n            except InvalidGitRepositoryError:\n                raise ExecutorError(\n                    \"It appears you're not in a valid git repository. 
\"\n                    \"The Beaker executor requires a git repository.\"\n                )\n            except GitCommandError:\n                pass\n\n    def execute_step_graph(\n        self, step_graph: StepGraph, run_name: Optional[str] = None\n    ) -> ExecutorOutput:\n        import concurrent.futures\n\n        self.check_repo_state()\n\n        self._is_cancelled.clear()\n\n        # These will hold the final results which we'll update along the way.\n        successful: Dict[str, ExecutionMetadata] = {}\n        failed: Dict[str, ExecutionMetadata] = {}\n        not_run: Dict[str, ExecutionMetadata] = {}\n\n        # Keeps track of steps that are next up to run on Beaker.\n        steps_to_run: Set[str] = set()\n        # These are steps that have been submitted to Beaker but haven't completed yet.\n        submitted_steps: Set[str] = set()\n        # Futures for tracking the Beaker runs for each step.\n        step_futures: List[concurrent.futures.Future] = []\n\n        uncacheable_leaf_steps = step_graph.uncacheable_leaf_steps()\n\n        # These are all of the steps that still need to be run at some point.\n        steps_left_to_run = uncacheable_leaf_steps | {\n            step for step in step_graph.values() if step.cache_results\n        }\n\n        def update_steps_to_run():\n            nonlocal steps_to_run, not_run\n            for step_name, step in step_graph.items():\n                if (\n                    step_name in submitted_steps\n                    or step_name in successful\n                    or step_name in failed\n                    or step_name in not_run\n                ):\n                    # Make sure step is no longer in queue.\n                    steps_to_run.discard(step_name)  # This does NOT raise KeyError if not found\n                else:\n                    # Check dependencies.\n                    for dependency in step.dependencies:\n                        if dependency.name not in successful and 
dependency.cache_results:\n                            if dependency.name in failed or dependency.name in not_run:\n                                # A dependency failed or can't be run, so this step can't be run.\n                                not_run[step_name] = ExecutionMetadata()\n                                steps_to_run.discard(step_name)\n                                steps_left_to_run.discard(step)\n                            break\n                    else:\n                        # Dependencies are OK, so we can run this step now.\n                        if step.cache_results or step in uncacheable_leaf_steps:\n                            steps_to_run.add(step_name)\n\n        def make_future_done_callback(step_name: str):\n            def future_done_callback(future: concurrent.futures.Future):\n                nonlocal successful, failed, steps_left_to_run\n\n                self._jobs = max(0, self._jobs - 1)\n                step = step_graph[step_name]\n\n                try:\n                    exc = future.exception()\n                    if exc is None:\n                        successful[step_name] = ExecutionMetadata(\n                            result_location=None\n                            if not step.cache_results\n                            else self.workspace.step_info(step).result_location,\n                            logs_location=future.result(),\n                        )\n                        steps_left_to_run.discard(step)\n                    elif isinstance(exc, ResourceAssignmentError):\n                        submitted_steps.discard(step_name)\n                        self._emit_resource_assignment_warning()\n                    elif isinstance(exc, StepFailedError):\n                        failed[step_name] = ExecutionMetadata(logs_location=exc.experiment_url)\n                        steps_left_to_run.discard(step)\n                    elif isinstance(exc, (ExecutorError, CancellationError)):\n             
           failed[step_name] = ExecutionMetadata()\n                        steps_left_to_run.discard(step)\n                    else:\n                        log_exception(exc, logger)\n                        failed[step_name] = ExecutionMetadata()\n                        steps_left_to_run.discard(step)\n                except concurrent.futures.TimeoutError as exc:\n                    log_exception(exc, logger)\n                    failed[step_name] = ExecutionMetadata()\n                    steps_left_to_run.discard(step)\n\n            return future_done_callback\n\n        last_progress_update = time.monotonic()\n\n        def log_progress():\n            nonlocal last_progress_update\n\n            now = time.monotonic()\n            if now - last_progress_update >= 60 * 2:\n                last_progress_update = now\n\n                waiting_for = [\n                    step_name\n                    for step_name in submitted_steps\n                    if step_name not in failed and step_name not in successful\n                ]\n                if len(waiting_for) > 5:\n                    logger.info(\n                        \"Waiting for %d steps...\",\n                        len(waiting_for),\n                    )\n                elif len(waiting_for) > 1:\n                    logger.info(\n                        \"Waiting for %d steps (%s)...\",\n                        len(waiting_for),\n                        \"'\" + \"', '\".join(waiting_for) + \"'\",\n                    )\n                elif len(waiting_for) == 1:\n                    logger.info(\"Waiting for 1 step ('%s')...\", list(waiting_for)[0])\n\n                still_to_run = [\n                    step.name for step in steps_left_to_run if step.name not in submitted_steps\n                ]\n                if len(still_to_run) > 5:\n                    logger.info(\n                        \"Still waiting to submit %d more steps...\",\n                        
len(still_to_run),\n                    )\n                elif len(still_to_run) > 1:\n                    logger.info(\n                        \"Still waiting to submit %d more steps (%s)...\",\n                        len(still_to_run),\n                        \"'\" + \"', '\".join(still_to_run) + \"'\",\n                    )\n                elif len(still_to_run) == 1:\n                    logger.info(\"Still waiting to submit 1 more step ('%s')...\", still_to_run[0])\n\n        update_steps_to_run()\n\n        try:\n            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_thread_workers) as pool:\n                while steps_left_to_run:\n                    # Submit steps left to run.\n                    for step_name in steps_to_run:\n                        future = pool.submit(\n                            self._execute_sub_graph_for_step, step_graph, step_name, True\n                        )\n                        future.add_done_callback(make_future_done_callback(step_name))\n                        self._jobs += 1\n                        step_futures.append(future)\n                        submitted_steps.add(step_name)\n\n                    if step_futures:\n                        # Wait for something to happen.\n                        _, not_done = concurrent.futures.wait(\n                            step_futures,\n                            return_when=concurrent.futures.FIRST_COMPLETED,\n                            timeout=2.0,\n                        )\n\n                        # Update the list of running futures.\n                        step_futures.clear()\n                        step_futures = list(not_done)\n                    else:\n                        time.sleep(2.0)\n\n                    # Update the step queue.\n                    update_steps_to_run()\n\n                    log_progress()\n        except (KeyboardInterrupt, CancellationError):\n            if step_futures:\n                
cli_logger.warning(\"Received interrupt, canceling steps...\")\n                self._is_cancelled.set()\n                concurrent.futures.wait(step_futures)\n            raise\n        finally:\n            self._is_cancelled.clear()\n\n        # NOTE: The 'done callback' added to each future is executed in a thread,\n        # and so might not complete before the last 'update_steps_to_run()' is called\n        # in the loop above. Therefore we have to call 'update_steps_to_run()'\n        # one last time here to ensure the 'not_run' set is up-to-date.\n        update_steps_to_run()\n\n        return ExecutorOutput(successful=successful, failed=failed, not_run=not_run)\n\n    def _emit_resource_assignment_warning(self):\n        if self._last_resource_assignment_warning is None or (\n            time.monotonic() - self._last_resource_assignment_warning\n            > self.RESOURCE_ASSIGNMENT_WARNING_INTERVAL\n        ):\n            self._last_resource_assignment_warning = time.monotonic()\n            logger.warning(\n                \"Some steps can't be run yet - waiting on more Beaker resources \"\n                \"to become available...\"\n            )\n\n    def _check_if_cancelled(self):\n        if self._is_cancelled.is_set():\n            raise RunCancelled\n\n    def _execute_sub_graph_for_step(\n        self,\n        step_graph: StepGraph,\n        step_name: str,\n        in_thread: bool = False,\n    ) -> Optional[str]:\n        if not in_thread:\n            self._is_cancelled.clear()\n        else:\n            self._check_if_cancelled()\n\n        step = step_graph[step_name]\n\n        if step.cache_results and step in self.workspace.step_cache:\n            cli_logger.info(\n                '[green]\\N{check mark} Found output for step [bold]\"%s\"[/] in cache...[/]',\n                step_name,\n            )\n            return None\n\n        if step.resources.machine == \"local\":\n            if step.cache_results:\n                
step.ensure_result(self.workspace)\n            else:\n                result = step.result(self.workspace)\n                if hasattr(result, \"__next__\"):\n                    from collections import deque\n\n                    deque(result, maxlen=0)\n            return None\n\n        experiment: Optional[Experiment] = None\n        experiment_url: Optional[str] = None\n        ephemeral_datasets: List[Dataset] = []\n\n        # Try to find any existing experiments for this step that are still running.\n        if step.cache_results:\n            for exp in self.beaker.workspace.experiments(\n                match=f\"{Constants.STEP_EXPERIMENT_PREFIX}{step.unique_id}-\"\n            ):\n                self._check_if_cancelled()\n                try:\n                    latest_job = self.beaker.experiment.latest_job(exp)\n                except (ValueError, ExperimentNotFound):\n                    continue\n                if latest_job is not None and not latest_job.is_done:\n                    experiment = exp\n                    experiment_url = self.beaker.experiment.url(exp)\n                    cli_logger.info(\n                        \"[blue]\\N{black rightwards arrow} Found existing Beaker experiment [b]%s[/] for \"\n                        'step [b]\"%s\"[/] that is still running...[/]',\n                        experiment_url,\n                        step_name,\n                    )\n                    break\n\n        # Otherwise we submit a new experiment...\n        if experiment is None:\n            # Initialize experiment and task spec.\n            experiment_name, spec, ephemeral_datasets = self._build_experiment_spec(\n                step_graph, step_name\n            )\n            self._check_if_cancelled()\n\n            step.log_starting()\n\n            # Create experiment.\n            experiment = self.beaker.experiment.create(experiment_name, spec)\n            experiment_url = self.beaker.experiment.url(experiment)\n      
      cli_logger.info(\n                '[blue]\\N{black rightwards arrow} Submitted Beaker experiment [b]%s[/] for step [b]\"%s\"[/]...[/]',\n                experiment_url,\n                step_name,\n            )\n\n        assert experiment is not None\n        assert experiment_url is not None\n\n        # Follow the experiment until it completes.\n        try:\n            while True:\n                poll_interval = min(60, 5 * min(self._jobs, self.max_thread_workers))\n                try:\n                    self._check_if_cancelled()\n                    self.beaker.experiment.wait_for(\n                        experiment,\n                        strict=True,\n                        quiet=True,\n                        timeout=poll_interval + 2,\n                        poll_interval=poll_interval,\n                    )\n                    break\n                except JobTimeoutError:\n                    time.sleep(poll_interval)\n                    continue\n        except (JobFailedError, TaskStoppedError):\n            cli_logger.error(\n                '[red]\\N{ballot x} Step [b]\"%s\"[/] failed. You can check the logs at [b]%s[/][/]',\n                step_name,\n                experiment_url,\n            )\n            raise StepFailedError(\n                f'Beaker job for step \"{step_name}\" failed. 
'\n                f\"You can check the logs at {experiment_url}\",\n                experiment_url,\n            )\n        except (KeyboardInterrupt, CancellationError):\n            cli_logger.warning(\n                'Stopping Beaker experiment [cyan]%s[/] for step [b]\"%s\"[/] (%s)',\n                experiment_url,\n                step_name,\n                step.unique_id,\n            )\n            self.beaker.experiment.stop(experiment)\n            raise\n        else:\n            step.log_finished()\n        finally:\n            # Remove ephemeral datasets.\n            result_dataset = self.beaker.experiment.results(experiment)\n            if result_dataset is not None:\n                ephemeral_datasets.append(result_dataset)\n            for dataset in ephemeral_datasets:\n                try:\n                    self.beaker.dataset.delete(dataset)\n                except DatasetNotFound:\n                    pass\n\n        return experiment_url\n\n    @staticmethod\n    def _parse_git_remote(url: str) -> Tuple[str, str]:\n        \"\"\"\n        Parse a git remote URL into a GitHub (account, repo) pair.\n        \"\"\"\n        account, repo = (\n            url.split(\"https://github.com/\")[-1]\n            .split(\"git@github.com:\")[-1]\n            .split(\".git\")[0]\n            .split(\"/\")\n        )\n        return account, repo\n\n    def _ensure_entrypoint_dataset(self) -> Dataset:\n        import hashlib\n        from importlib.resources import read_binary\n\n        import tango.integrations.beaker\n\n        workspace_id = self.beaker.workspace.get().id\n\n        # Get hash of the local entrypoint source file.\n        sha256_hash = hashlib.sha256()\n        contents = read_binary(tango.integrations.beaker, Constants.ENTRYPOINT_FILENAME)\n        sha256_hash.update(contents)\n\n        entrypoint_dataset_name = (\n            f\"{Constants.ENTRYPOINT_DATASET_PREFIX}{workspace_id}-{sha256_hash.hexdigest()[:6]}\"\n        )\n  
      tmp_entrypoint_dataset_name = (\n            f\"{Constants.ENTRYPOINT_DATASET_PREFIX}{str(uuid.uuid4())}-tmp\"\n        )\n\n        # Ensure entrypoint dataset exists.\n        entrypoint_dataset: Dataset\n        try:\n            entrypoint_dataset = self.beaker.dataset.get(entrypoint_dataset_name)\n        except DatasetNotFound:\n            # Create it.\n            logger.debug(f\"Creating entrypoint dataset '{entrypoint_dataset_name}'\")\n            try:\n                tmp_entrypoint_dataset = self.beaker.dataset.create(\n                    tmp_entrypoint_dataset_name, quiet=True, commit=False\n                )\n                self.beaker.dataset.upload(\n                    tmp_entrypoint_dataset, contents, Constants.ENTRYPOINT_FILENAME, quiet=True\n                )\n                self.beaker.dataset.commit(tmp_entrypoint_dataset)\n                entrypoint_dataset = self.beaker.dataset.rename(\n                    tmp_entrypoint_dataset, entrypoint_dataset_name\n                )\n            except DatasetConflict:  # could be in a race with another `tango` process.\n                time.sleep(1.0)\n                entrypoint_dataset = self.beaker.dataset.get(entrypoint_dataset_name)\n\n        # Verify contents.\n        err_msg = (\n            f\"Checksum failed for entrypoint dataset {self.beaker.dataset.url(entrypoint_dataset)}\\n\"\n            f\"This could be a bug, or it could mean someone has tampered with the dataset.\\n\"\n            f\"If you're sure no one has tampered with it, you can delete the dataset from \"\n            f\"the Beaker dashboard and try again.\"\n        )\n        file_info = self.beaker.dataset.file_info(entrypoint_dataset, Constants.ENTRYPOINT_FILENAME)\n        if file_info.digest is not None and file_info.digest != Digest.from_decoded(\n            sha256_hash.digest(), \"SHA256\"\n        ):\n            raise ExecutorError(err_msg)\n\n        return entrypoint_dataset\n\n    def 
_ensure_step_graph_dataset(self, step_graph: StepGraph) -> Dataset:\n        step_graph_dataset_name = f\"{Constants.STEP_GRAPH_ARTIFACT_PREFIX}{str(uuid.uuid4())}\"\n        try:\n            dataset = self.beaker.dataset.create(step_graph_dataset_name, quiet=True, commit=False)\n            self.beaker.dataset.upload(\n                dataset,\n                json.dumps({\"steps\": step_graph.to_config(include_unique_id=True)}).encode(),\n                Constants.STEP_GRAPH_FILENAME,\n                quiet=True,\n            )\n            self.beaker.dataset.commit(dataset)\n        except DatasetConflict:  # could be in a race with another `tango` process.\n            time.sleep(1.0)\n            dataset = self.beaker.dataset.get(step_graph_dataset_name)\n        return dataset\n\n    def _build_experiment_spec(\n        self, step_graph: StepGraph, step_name: str\n    ) -> Tuple[str, ExperimentSpec, List[Dataset]]:\n        from tango.common.logging import TANGO_LOG_LEVEL\n\n        step = step_graph[step_name]\n        sub_graph = step_graph.sub_graph(step_name)\n        step_info = self.workspace.step_info(step)\n        experiment_name = (\n            f\"{Constants.STEP_EXPERIMENT_PREFIX}{step.unique_id}-{str(uuid.uuid4())[:8]}\"\n        )\n        github_account, github_repo, git_ref = (\n            self._github_account,\n            self._github_repo,\n            self._git_commit,\n        )\n        if not self._logged_git_info:\n            self._logged_git_info = True\n            cli_logger.info(\n                \"[blue]Using source code from \"\n                \"[b]https://github.com/%s/%s/commit/%s[/] to run steps on Beaker[/]\",\n                github_account,\n                github_repo,\n                git_ref,\n            )\n\n        # Get cluster, resources, and priority to use.\n        clusters, task_resources, priority = self.scheduler.schedule(step)\n        self._check_if_cancelled()\n\n        # Ensure dataset with the 
entrypoint script exists and get it.\n        entrypoint_dataset = self._ensure_entrypoint_dataset()\n        self._check_if_cancelled()\n\n        # Create dataset for step graph.\n        step_graph_dataset = self._ensure_step_graph_dataset(sub_graph)\n        self._check_if_cancelled()\n\n        # Write the GitHub token secret.\n        self.beaker.secret.write(Constants.GITHUB_TOKEN_SECRET_NAME, self.github_token)\n        self._check_if_cancelled()\n\n        # Write the Beaker token secret.\n        self.beaker.secret.write(Constants.BEAKER_TOKEN_SECRET_NAME, self.beaker.config.user_token)\n        self._check_if_cancelled()\n\n        # Write the Google Cloud token secret.\n        if self.google_token is not None:\n            self.beaker.secret.write(Constants.GOOGLE_TOKEN_SECRET_NAME, self.google_token)\n            self._check_if_cancelled()\n\n        # Build Tango command to run.\n        command = [\n            \"tango\",\n            \"--log-level\",\n            \"debug\",\n            \"--called-by-executor\",\n            \"beaker-executor-run\",\n            Constants.INPUT_DIR + \"/\" + Constants.STEP_GRAPH_FILENAME,\n            step.name,\n            self.workspace.url,\n        ]\n        if self.include_package is not None:\n            for package in self.include_package:\n                command += [\"-i\", package, \"--log-level\", TANGO_LOG_LEVEL or \"debug\"]\n\n        self._check_if_cancelled()\n\n        # Ignore the patch version.\n        # E.g. 
'3.9.7' -> '3.9'\n        python_version = step_info.environment.python\n        python_version = python_version[: python_version.find(\".\", python_version.find(\".\") + 1)]\n\n        # Build task spec.\n        task_spec = (\n            TaskSpec.new(\n                step.unique_id,\n                beaker_image=self.beaker_image,\n                docker_image=self.docker_image,\n                result_path=Constants.RESULTS_DIR,\n                command=[\"bash\", Constants.ENTRYPOINT_DIR + \"/\" + Constants.ENTRYPOINT_FILENAME],\n                arguments=command,\n                resources=task_resources,\n                datasets=self.datasets,\n                env_vars=self.env_vars,\n                priority=priority,\n            )\n            .with_constraint(cluster=[clusters] if isinstance(clusters, str) else clusters)\n            .with_env_var(name=\"TANGO_VERSION\", value=VERSION)\n            .with_env_var(name=\"GITHUB_TOKEN\", secret=Constants.GITHUB_TOKEN_SECRET_NAME)\n            .with_env_var(name=\"BEAKER_TOKEN\", secret=Constants.BEAKER_TOKEN_SECRET_NAME)\n            .with_env_var(name=\"GOOGLE_TOKEN\", secret=Constants.GOOGLE_TOKEN_SECRET_NAME)\n            .with_env_var(name=\"GITHUB_REPO\", value=f\"{github_account}/{github_repo}\")\n            .with_env_var(name=\"GIT_REF\", value=git_ref)\n            .with_env_var(name=\"PYTHON_VERSION\", value=python_version)\n            .with_env_var(name=\"BEAKER_EXPERIMENT_NAME\", value=experiment_name)\n            .with_dataset(Constants.ENTRYPOINT_DIR, beaker=entrypoint_dataset.id)\n            .with_dataset(Constants.INPUT_DIR, beaker=step_graph_dataset.id)\n        )\n\n        if self.venv_name is not None:\n            task_spec = task_spec.with_env_var(name=\"VENV_NAME\", value=self.venv_name)\n\n        if self.install_cmd is not None:\n            task_spec = task_spec.with_env_var(name=\"INSTALL_CMD\", value=self.install_cmd)\n\n        return (\n            experiment_name,\n       
     ExperimentSpec(\n                tasks=[task_spec],\n                description=f'Tango step \"{step_name}\" ({step.unique_id})',\n                budget=self._budget,\n            ),\n            [step_graph_dataset],\n        )\n"
  },
  {
    "path": "tango/integrations/beaker/step_cache.py",
    "content": "import logging\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom beaker import Beaker\nfrom beaker import Dataset as BeakerDataset\nfrom beaker import DatasetConflict, DatasetNotFound, DatasetWriteError\n\nfrom tango import Step\nfrom tango.common import PathOrStr\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.util import make_safe_filename, tango_cache_dir\nfrom tango.integrations.beaker.common import Constants, get_client\nfrom tango.step_cache import StepCache\nfrom tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\n@StepCache.register(\"beaker\")\nclass BeakerStepCache(RemoteStepCache):\n    \"\"\"\n    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`BeakerWorkspace`.\n    It stores the results of steps on Beaker as datasets.\n\n    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a\n    step's result subsequent times should be fast.\n\n    .. 
tip::\n        Registered as a :class:`~tango.step_cache.StepCache` under the name \"beaker\".\n\n    :param beaker_workspace: The name or ID of the Beaker workspace to use.\n    :param beaker: The Beaker client to use.\n    \"\"\"\n\n    Constants = Constants\n\n    def __init__(self, beaker_workspace: Optional[str] = None, beaker: Optional[Beaker] = None):\n        self.beaker: Beaker\n        if beaker is not None:\n            self.beaker = beaker\n            if beaker_workspace is not None:\n                self.beaker.config.default_workspace = beaker_workspace\n                self.beaker.workspace.ensure(beaker_workspace)\n        else:\n            self.beaker = get_client(beaker_workspace=beaker_workspace)\n        if self.beaker.config.default_workspace is None:\n            raise ConfigurationError(\"Beaker default workspace must be set\")\n        super().__init__(\n            tango_cache_dir()\n            / \"beaker_cache\"\n            / make_safe_filename(self.beaker.config.default_workspace)\n        )\n\n    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[BeakerDataset]:\n        \"\"\"\n        Returns a `BeakerDataset` object containing the details of the step.\n        Returns ``None`` unless the step has been finalized (committed).\n        \"\"\"\n        try:\n            dataset = self.beaker.dataset.get(self.Constants.step_artifact_name(step))\n            return dataset if dataset.committed is not None else None\n        except DatasetNotFound:\n            return None\n\n    def _upload_step_remote(self, step: Step, objects_dir: Path) -> BeakerDataset:\n        \"\"\"\n        Uploads the step's output to the remote location.\n        \"\"\"\n        dataset_name = self.Constants.step_artifact_name(step)\n        try:\n            self.beaker.dataset.create(dataset_name, commit=False)\n        except DatasetConflict:\n            pass\n        try:\n            self.beaker.dataset.sync(dataset_name, objects_dir, quiet=True)\n 
           self.beaker.dataset.commit(dataset_name)\n        except DatasetWriteError:\n            pass\n\n        return self.beaker.dataset.get(dataset_name)\n\n    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:\n        \"\"\"\n        Downloads the step's output from remote location.\n        \"\"\"\n        try:\n            self.beaker.dataset.fetch(step_result, target_dir, quiet=True)\n        except DatasetNotFound:\n            raise RemoteNotFoundError()\n\n    def __len__(self):\n        \"\"\"\n        Returns the number of committed step outputs present in the remote location.\n        \"\"\"\n        # NOTE: lock datasets should not count here.\n        return sum(\n            1\n            for ds in self.beaker.workspace.iter_datasets(\n                match=self.Constants.STEP_ARTIFACT_PREFIX, uncommitted=False, results=False\n            )\n            if ds.name is not None\n            and ds.name.startswith(self.Constants.STEP_ARTIFACT_PREFIX)\n            and not ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX)\n        )\n"
  },
  {
    "path": "tango/integrations/beaker/workspace.py",
    "content": "import json\nimport logging\nimport os\nimport random\nfrom collections import OrderedDict\nfrom pathlib import Path\nfrom typing import Dict, List, Optional, Type, TypeVar, Union, cast\nfrom urllib.parse import ParseResult\n\nimport petname\nfrom beaker import Dataset\nfrom beaker import Dataset as BeakerDataset\nfrom beaker import (\n    DatasetConflict,\n    DatasetNotFound,\n    DatasetSort,\n    Digest,\n    Experiment,\n    ExperimentNotFound,\n)\n\nfrom tango.common.util import make_safe_filename, tango_cache_dir\nfrom tango.step import Step\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace\nfrom tango.workspaces.remote_workspace import RemoteWorkspace\n\nfrom .common import BeakerStepLock, Constants, dataset_url, get_client\nfrom .step_cache import BeakerStepCache\n\nT = TypeVar(\"T\")\nU = TypeVar(\"U\", Run, StepInfo)\n\nlogger = logging.getLogger(__name__)\n\n\n@Workspace.register(\"beaker\")\nclass BeakerWorkspace(RemoteWorkspace):\n    \"\"\"\n    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on `Beaker`_.\n\n    .. 
tip::\n        Registered as a :class:`~tango.workspace.Workspace` under the name \"beaker\".\n\n    :param workspace: The name or ID of the Beaker workspace to use.\n    :param kwargs: Additional keyword arguments passed to :meth:`Beaker.from_env() <beaker.Beaker.from_env()>`.\n    \"\"\"\n\n    STEP_INFO_CACHE_SIZE = 512\n    Constants = Constants\n    NUM_CONCURRENT_WORKERS = 9\n\n    def __init__(self, workspace: str, max_workers: Optional[int] = None, **kwargs):\n        self.beaker = get_client(beaker_workspace=workspace, **kwargs)\n        self._cache = BeakerStepCache(beaker=self.beaker)\n        self._locks: Dict[Step, BeakerStepLock] = {}\n        super().__init__()\n        self.max_workers = max_workers\n        self._disk_cache_dir = tango_cache_dir() / \"beaker_cache\" / \"_objects\"\n        self._mem_cache: \"OrderedDict[Digest, Union[StepInfo, Run]]\" = OrderedDict()\n\n    @property\n    def cache(self):\n        return self._cache\n\n    @property\n    def locks(self):\n        return self._locks\n\n    @property\n    def steps_dir_name(self):\n        return \"beaker_workspace\"\n\n    @property\n    def url(self) -> str:\n        return f\"beaker://{self.beaker.workspace.get().full_name}\"\n\n    def _step_location(self, step: Step) -> str:\n        return dataset_url(self.beaker, self.Constants.step_artifact_name(step))\n\n    @classmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:\n        workspace: str\n        if parsed_url.netloc and parsed_url.path:\n            # e.g. \"beaker://ai2/my-workspace\"\n            workspace = parsed_url.netloc + parsed_url.path\n        elif parsed_url.netloc:\n            # e.g. 
\"beaker://my-workspace\"\n            workspace = parsed_url.netloc\n        else:\n            raise ValueError(f\"Bad URL for Beaker workspace '{parsed_url}'\")\n        return cls(workspace)\n\n    @property\n    def current_beaker_experiment(self) -> Optional[Experiment]:\n        \"\"\"\n        When the workspace is being used within a Beaker experiment that was submitted\n        by the Beaker executor, this will return the `Experiment` object.\n        \"\"\"\n        experiment_name = os.environ.get(\"BEAKER_EXPERIMENT_NAME\")\n        if experiment_name is not None:\n            try:\n                return self.beaker.experiment.get(experiment_name)\n            except ExperimentNotFound:\n                return None\n        else:\n            return None\n\n    def _remote_lock(self, step: Step) -> BeakerStepLock:\n        return BeakerStepLock(\n            self.beaker, step, current_beaker_experiment=self.current_beaker_experiment\n        )\n\n    def _get_object_from_cache(self, digest: Digest, o_type: Type[U]) -> Optional[U]:\n        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))\n        if digest in self._mem_cache:\n            cached = self._mem_cache.pop(digest)\n            # Move to end.\n            self._mem_cache[digest] = cached\n            return cached if isinstance(cached, o_type) else None\n        elif cache_path.is_file():\n            try:\n                with cache_path.open(\"r+t\") as f:\n                    json_dict = json.load(f)\n                    cached = o_type.from_json_dict(json_dict)\n            except Exception as exc:\n                logger.warning(\"Error while loading object from workspace cache: %s\", str(exc))\n                try:\n                    os.remove(cache_path)\n                except FileNotFoundError:\n                    pass\n                return None\n            # Add to in-memory cache.\n            self._mem_cache[digest] = cached\n            while 
len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:\n                self._mem_cache.popitem(last=False)\n            return cached  # type: ignore\n        else:\n            return None\n\n    def _add_object_to_cache(self, digest: Digest, o: U):\n        self._disk_cache_dir.mkdir(parents=True, exist_ok=True)\n        cache_path = self._disk_cache_dir / make_safe_filename(str(digest))\n        self._mem_cache[digest] = o\n        with cache_path.open(\"w+t\") as f:\n            json.dump(o.to_json_dict(), f)\n        while len(self._mem_cache) > self.STEP_INFO_CACHE_SIZE:\n            self._mem_cache.popitem(last=False)\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        try:\n            dataset = self.beaker.dataset.get(self.Constants.step_artifact_name(step_or_unique_id))\n            return self._get_step_info_from_dataset(dataset)\n        except (DatasetNotFound, FileNotFoundError):\n            if not isinstance(step_or_unique_id, Step):\n                raise KeyError(step_or_unique_id)\n            step_info = StepInfo.new_from_step(step_or_unique_id)\n            self._update_step_info(step_info)\n            return step_info\n\n    def _get_step_info_from_dataset(self, dataset: Dataset) -> StepInfo:\n        file_info = self.beaker.dataset.file_info(dataset, Constants.STEP_INFO_FNAME)\n        step_info: StepInfo\n        cached = (\n            None\n            if file_info.digest is None\n            else self._get_object_from_cache(file_info.digest, StepInfo)\n        )\n        if cached is not None:\n            step_info = cached\n        else:\n            step_info_bytes = self.beaker.dataset.get_file(dataset, file_info, quiet=True)\n            step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))\n            if file_info.digest is not None:\n                self._add_object_to_cache(file_info.digest, step_info)\n        return step_info\n\n    def _save_run(\n        self, steps: Dict[str, 
StepInfo], run_data: Dict[str, str], name: Optional[str] = None\n    ) -> Run:\n        # Create a remote dataset that represents this run. The dataset will just contain\n        # a JSON file that maps step names to step unique IDs.\n        run_dataset: BeakerDataset\n        if name is None:\n            # Find a unique name to use.\n            while True:\n                name = petname.generate() + str(random.randint(0, 100))\n                try:\n                    run_dataset = self.beaker.dataset.create(\n                        self.Constants.run_artifact_name(cast(str, name)), commit=False\n                    )\n                except DatasetConflict:\n                    continue\n                else:\n                    break\n        else:\n            try:\n                run_dataset = self.beaker.dataset.create(\n                    self.Constants.run_artifact_name(name), commit=False\n                )\n            except DatasetConflict:\n                raise ValueError(f\"Run name '{name}' is already in use\")\n\n        # Upload run data to dataset.\n        # NOTE: We don't commit the dataset here since we'll need to upload the logs file\n        # after the run.\n        self.beaker.dataset.upload(\n            run_dataset, json.dumps(run_data).encode(), self.Constants.RUN_DATA_FNAME, quiet=True\n        )\n\n        return Run(name=cast(str, name), steps=steps, start_date=run_dataset.created)\n\n    def registered_runs(self) -> Dict[str, Run]:\n        import concurrent.futures\n\n        runs: Dict[str, Run] = {}\n\n        with concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.NUM_CONCURRENT_WORKERS,\n            thread_name_prefix=\"BeakerWorkspace.registered_runs()-\",\n        ) as executor:\n            run_futures = []\n            for dataset in self.beaker.workspace.iter_datasets(\n                match=self.Constants.RUN_ARTIFACT_PREFIX, uncommitted=True, results=False\n            ):\n                
run_futures.append(executor.submit(self._get_run_from_dataset, dataset))\n            for future in concurrent.futures.as_completed(run_futures):\n                run = future.result()\n                if run is not None:\n                    runs[run.name] = run\n\n        return runs\n\n    def search_registered_runs(\n        self,\n        *,\n        sort_by: Optional[RunSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        start: Optional[int] = None,\n        stop: Optional[int] = None,\n    ) -> List[RunInfo]:\n        if match is None:\n            match = Constants.RUN_ARTIFACT_PREFIX\n        else:\n            match = Constants.RUN_ARTIFACT_PREFIX + match\n\n        if sort_by is None or sort_by == RunSort.START_DATE:\n            sort = DatasetSort.created\n        elif sort_by == RunSort.NAME:\n            sort = DatasetSort.dataset_name\n        else:\n            raise NotImplementedError\n\n        runs = []\n        for dataset in self.beaker.workspace.iter_datasets(\n            match=match,\n            results=False,\n            cursor=start or 0,\n            limit=None if stop is None else stop - (start or 0),\n            sort_by=sort,\n            descending=sort_descending,\n        ):\n            if dataset.name is not None and dataset.name.startswith(\n                self.Constants.RUN_ARTIFACT_PREFIX\n            ):\n                run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]\n                runs.append(RunInfo(name=run_name, start_date=dataset.created))\n\n        return runs\n\n    def num_registered_runs(self, *, match: Optional[str] = None) -> int:\n        if match is None:\n            match = Constants.RUN_ARTIFACT_PREFIX\n        else:\n            match = Constants.RUN_ARTIFACT_PREFIX + match\n\n        count = 0\n        for dataset in self.beaker.workspace.iter_datasets(\n            match=match,\n            results=False,\n        ):\n            if 
dataset.name is not None and dataset.name.startswith(Constants.RUN_ARTIFACT_PREFIX):\n                count += 1\n\n        return count\n\n    def search_step_info(\n        self,\n        *,\n        sort_by: Optional[StepInfoSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        state: Optional[StepState] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[StepInfo]:\n        if state is not None:\n            raise NotImplementedError(\n                f\"{self.__class__.__name__} cannot filter steps efficiently by state\"\n            )\n\n        if match is None:\n            match = Constants.STEP_ARTIFACT_PREFIX\n        else:\n            match = Constants.STEP_ARTIFACT_PREFIX + match\n\n        sort: Optional[DatasetSort] = None\n        if sort_by is None or sort_by == StepInfoSort.START_TIME:\n            sort = DatasetSort.created\n        elif sort_by == StepInfoSort.UNIQUE_ID:\n            sort = DatasetSort.dataset_name\n        elif sort_by is not None:\n            raise NotImplementedError\n\n        steps = []\n        for dataset in self.beaker.workspace.iter_datasets(\n            match=match,\n            results=False,\n            cursor=start or 0,\n            limit=None if stop is None else stop - (start or 0),\n            sort_by=sort or DatasetSort.created,\n            descending=sort_descending,\n        ):\n            try:\n                steps.append(self._get_step_info_from_dataset(dataset))\n            except (DatasetNotFound, FileNotFoundError):\n                continue\n\n        return steps\n\n    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:\n        if state is not None:\n            raise NotImplementedError(\n                f\"{self.__class__.__name__} cannot filter steps efficiently by state\"\n            )\n\n        if match is None:\n            match = Constants.STEP_ARTIFACT_PREFIX\n 
       else:\n            match = Constants.STEP_ARTIFACT_PREFIX + match\n\n        count = 0\n        for dataset in self.beaker.workspace.iter_datasets(\n            match=match,\n            results=False,\n        ):\n            if dataset.name is not None and dataset.name.startswith(Constants.STEP_ARTIFACT_PREFIX):\n                count += 1\n\n        return count\n\n    def registered_run(self, name: str) -> Run:\n        err_msg = f\"Run '{name}' not found in workspace\"\n\n        try:\n            dataset_for_run = self.beaker.dataset.get(self.Constants.run_artifact_name(name))\n            # Make sure the run is in our workspace.\n            if dataset_for_run.workspace_ref.id != self.beaker.workspace.get().id:  # type: ignore # TODO\n                raise DatasetNotFound\n        except DatasetNotFound:\n            raise KeyError(err_msg)\n\n        run = self._get_run_from_dataset(dataset_for_run)\n        if run is None:\n            raise KeyError(err_msg)\n        else:\n            return run\n\n    def _save_run_log(self, name: str, log_file: Path):\n        run_dataset = self.Constants.run_artifact_name(name)\n        self.beaker.dataset.sync(run_dataset, log_file, quiet=True)\n        self.beaker.dataset.commit(run_dataset)\n\n    def _get_run_from_dataset(self, dataset: BeakerDataset) -> Optional[Run]:\n        if dataset.name is None:\n            return None\n        if not dataset.name.startswith(self.Constants.RUN_ARTIFACT_PREFIX):\n            return None\n\n        try:\n            run_name = dataset.name[len(self.Constants.RUN_ARTIFACT_PREFIX) :]\n            steps_info_bytes = self.beaker.dataset.get_file(\n                dataset, self.Constants.RUN_DATA_FNAME, quiet=True\n            )\n            steps_info = json.loads(steps_info_bytes)\n        except (DatasetNotFound, FileNotFoundError):\n            return None\n\n        steps: Dict[str, StepInfo] = {}\n        import concurrent.futures\n\n        with 
concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.NUM_CONCURRENT_WORKERS,\n            thread_name_prefix=\"BeakerWorkspace._get_run_from_dataset()-\",\n        ) as executor:\n            step_info_futures = []\n            for unique_id in steps_info.values():\n                step_info_futures.append(executor.submit(self.step_info, unique_id))\n            for future in concurrent.futures.as_completed(step_info_futures):\n                step_info = future.result()\n                assert step_info.step_name is not None\n                steps[step_info.step_name] = step_info\n\n        return Run(name=run_name, start_date=dataset.created, steps=steps)\n\n    def _update_step_info(self, step_info: StepInfo):\n        dataset_name = self.Constants.step_artifact_name(step_info)\n\n        step_info_dataset: BeakerDataset\n        try:\n            self.beaker.dataset.create(dataset_name, commit=False)\n        except DatasetConflict:\n            pass\n        step_info_dataset = self.beaker.dataset.get(dataset_name)\n\n        self.beaker.dataset.upload(\n            step_info_dataset,  # folder name\n            json.dumps(step_info.to_json_dict()).encode(),  # step info dict.\n            self.Constants.STEP_INFO_FNAME,  # step info filename\n            quiet=True,\n        )\n\n    def _remove_step_info(self, step_info: StepInfo) -> None:\n        # remove dir from beaker workspace\n        dataset_name = self.Constants.step_artifact_name(step_info)\n        step_dataset = self.beaker.dataset.get(dataset_name)\n        if step_dataset is not None:\n            self.beaker.dataset.delete(step_dataset)\n"
  },
  {
    "path": "tango/integrations/datasets/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"datasets\" extra\n    (e.g. ``pip install tango[datasets]``) or just install the ``datasets`` library after the fact\n    (e.g. ``pip install datasets``).\n\nComponents for Tango integration with `🤗 Datasets <https://huggingface.co/docs/datasets/>`_.\n\nExample: loading and combining\n------------------------------\n\nHere's an example config that uses the built-in steps from this integration to load,\nconcatenate, and interleave datasets from HuggingFace:\n\n.. literalinclude:: ../../../../test_fixtures/integrations/datasets/config.json\n\nYou could run this with:\n\n.. code-block::\n\n    tango run config.json\n\n\"\"\"\n\n\nimport re\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, TypeVar, Union, overload\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.dataset_dict import DatasetDict, IterableDatasetDict\nfrom tango.common.exceptions import ConfigurationError, IntegrationMissingError\nfrom tango.format import Format\nfrom tango.step import Step\n\ntry:\n    import datasets as ds\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"datasets\")\n\n__all__ = [\n    \"LoadDataset\",\n    \"LoadStreamingDataset\",\n    \"DatasetsFormat\",\n    \"convert_to_tango_dataset_dict\",\n    \"InterleaveDatasets\",\n    \"ConcatenateDatasets\",\n    \"DatasetRemixStep\",\n]\n\n\n@overload\ndef convert_to_tango_dataset_dict(hf_dataset_dict: ds.DatasetDict) -> DatasetDict:\n    ...\n\n\n@overload\ndef convert_to_tango_dataset_dict(hf_dataset_dict: ds.IterableDatasetDict) -> IterableDatasetDict:  # type: ignore\n    ...\n\n\ndef convert_to_tango_dataset_dict(hf_dataset_dict):\n    \"\"\"\n    A helper function that can be used to convert a HuggingFace :class:`~datasets.DatasetDict`\n    or :class:`~datasets.IterableDatasetDict` into a native Tango\n    :class:`~tango.common.DatasetDict` or 
:class:`~tango.common.IterableDatasetDict`.\n\n    This is important to do when your dataset dict is input to another step for caching\n    reasons.\n    \"\"\"\n    if isinstance(hf_dataset_dict, ds.IterableDatasetDict):\n        return IterableDatasetDict(splits=hf_dataset_dict)\n    else:\n        return DatasetDict(splits=hf_dataset_dict)\n\n\nT = Union[ds.Dataset, ds.DatasetDict]\n\n\n@Format.register(\"datasets\")\nclass DatasetsFormat(Format[T]):\n    \"\"\"\n    This format writes a :class:`datasets.Dataset` or :class:`datasets.DatasetDict` to disk\n    using :meth:`datasets.Dataset.save_to_disk()`.\n\n    It is the default :class:`~tango.format.Format` for the :class:`LoadDataset` step.\n    \"\"\"\n\n    VERSION = \"001\"\n\n    def write(self, artifact: T, dir: PathOrStr):\n        dataset_path = Path(dir) / \"data\"\n        artifact.save_to_disk(str(dataset_path))\n\n    def read(self, dir: PathOrStr) -> T:\n        dataset_path = Path(dir) / \"data\"\n        return ds.load_from_disk(str(dataset_path))\n\n\n@Step.register(\"datasets::load\")\nclass LoadDataset(Step):\n    \"\"\"\n    This step loads a `HuggingFace dataset <https://huggingface.co/datasets>`_.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"datasets::load\".\n\n    .. important::\n\n        If you are loading an :class:`~datasets.IterableDataset` or :class:`~datasets.IterableDatasetDict`\n        you need to use the :class:`LoadStreamingDataset` step instead.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    VERSION = \"001\"\n    CACHEABLE = True\n    # Even though HuggingFace datasets has its own caching mechanism, it can still be worth caching\n    # this step with tango's mechanism since some datasets take a really long time to query from HuggingFace\n    # (\"bigscience/P3\", for example). 
Tango's caching mechanism circumvents that issue.\n    FORMAT = DatasetsFormat()\n\n    def run(self, path: str, **kwargs) -> Union[ds.DatasetDict, ds.Dataset]:  # type: ignore\n        \"\"\"\n        Load the HuggingFace dataset specified by ``path``.\n\n        ``path`` is the canonical name or path to the dataset. Additional key word arguments\n        are passed as-is to :func:`datasets.load_dataset()`.\n        \"\"\"\n        dataset = ds.load_dataset(path, **kwargs)\n        if not isinstance(dataset, (ds.Dataset, ds.DatasetDict)):\n            raise ConfigurationError(\n                f\"{self.__class__.__name__} can only be used with non-streaming datasets. \"\n                f\"For streaming datasets, use the 'LoadStreamingDataset' ('datasets::load_streaming') step instead.\"\n            )\n        return dataset\n\n\n@Step.register(\"datasets::load_streaming\")\nclass LoadStreamingDataset(Step):\n    \"\"\"\n    This step loads an iterable/streaming `HuggingFace dataset <https://huggingface.co/datasets>`_.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"datasets::load_streaming\".\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    VERSION = \"001\"\n    CACHEABLE = (\n        False  # can't be cached with `DatasetsFormat`, and might be really inefficient anyway.\n    )\n\n    def run(  # type: ignore\n        self, path: str, **kwargs\n    ) -> Union[ds.IterableDatasetDict, ds.IterableDataset]:\n        \"\"\"\n        Load the HuggingFace streaming dataset specified by ``path``.\n\n        ``path`` is the canonical name or path to the dataset. Additional key word arguments\n        are passed as-is to :func:`datasets.load_dataset()`.\n        \"\"\"\n        dataset = ds.load_dataset(path, **kwargs)\n        if not isinstance(dataset, (ds.IterableDataset, ds.IterableDatasetDict)):\n            raise ConfigurationError(\n                f\"{self.__class__.__name__} can only be used with streaming datasets. 
\"\n                f\"For non-streaming datasets, use the 'LoadDataset' ('datasets::load') step instead.\"\n            )\n        return dataset\n\n\nDatasetType = TypeVar(\"DatasetType\", ds.Dataset, ds.IterableDataset)\n\n\n@Step.register(\"datasets::interleave\")\nclass InterleaveDatasets(Step):\n    \"\"\"\n    This steps interleaves multiple datasets using :func:`~datasets.interleave_datasets()`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"datasets::interleave\".\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    VERSION = \"001\"\n    CACHEABLE = False  # Not worth caching\n\n    def run(  # type: ignore[override]\n        self,\n        datasets: List[DatasetType],\n        probabilities: Optional[List[float]] = None,\n        seed: Optional[int] = None,\n    ) -> DatasetType:\n        \"\"\"\n        Interleave the list of datasets.\n        \"\"\"\n        return ds.interleave_datasets(datasets, probabilities=probabilities, seed=seed)\n\n\n@Step.register(\"datasets::concatenate\")\nclass ConcatenateDatasets(Step):\n    \"\"\"\n    This step concatenates multiple datasets using :func:`~datasets.concatenate_datasets()`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"datasets::concatenate\".\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    VERSION = \"001\"\n    CACHEABLE = False  # Not worth caching\n\n    def run(  # type: ignore[override]\n        self,\n        datasets: List[ds.Dataset],\n        info: Optional[Any] = None,\n        split: Optional[Any] = None,\n        axis: int = 0,\n    ) -> ds.Dataset:\n        \"\"\"\n        Concatenate the list of datasets.\n        \"\"\"\n        return ds.concatenate_datasets(datasets, info=info, split=split, axis=axis)\n\n\n@Step.register(\"datasets::dataset_remix\")\nclass DatasetRemixStep(Step):\n    \"\"\"\n    This step can remix splits in a :class:`~datasets.DatasetDict` into new splits.\n\n    .. 
tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"datasets::dataset_remix\".\n\n    Examples\n    --------\n\n    .. testcode::\n        :hide:\n\n        from tango.common.logging import initialize_logging\n        initialize_logging(enable_cli_logs=True)\n        import datasets\n\n    .. testcode::\n\n        input = datasets.load_dataset(\"lhoestq/test\")\n        new_splits = {\n            \"all\": \"train + validation\",\n            \"crossval_train\": \"train[:1] + validation[1:]\",\n            \"crossval_test\": \"train[1:] + validation[:1]\",\n        }\n        step = DatasetRemixStep()\n        remixed_dataset = step.run(input=input, new_splits=new_splits)\n\n    .. testoutput::\n        :hide:\n        :options: +ELLIPSIS\n\n        ...\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    VERSION = \"001\"\n\n    def run(  # type: ignore\n        self,\n        input: ds.DatasetDict,\n        new_splits: Dict[str, str],\n        keep_old_splits: bool = True,\n        shuffle_before: bool = False,\n        shuffle_after: bool = False,\n        random_seed: int = 1532637578,\n    ) -> ds.DatasetDict:\n        \"\"\"\n        Remixes and shuffles a dataset. This is done eagerly with native 🤗 Datasets features.\n\n        :param input:\n            The input dataset that will be remixed.\n        :param new_splits:\n            Specifies the new splits that the output dataset should have. Keys are the name of the new\n            splits. Values refer to the original splits. You can refer to original splits in the following ways:\n\n            * Mention the original split name to copy it to a new name.\n            * Mention the original split name with Python's slicing syntax to select part of the original\n              split's instances. 
For example, ``\"train[:1000]\"`` selects the first 1000 instances from the\n              ``\"train\"`` split.\n            * ``\"instances + instances\"`` concatenates the instances into one split.\n\n            You can combine these possibilities.\n        :param keep_old_splits:\n            Whether to keep the splits from the input dataset in addition to the new ones given by\n            ``new_splits``.\n        :param shuffle_before:\n            Whether to shuffle the input splits before creating the new ones.\n\n            If you need shuffled instances and you're not sure the input is properly shuffled, use this.\n        :param shuffle_after:\n            Whether to shuffle the resulting splits after creating the new ones.\n\n            If you need shuffled instances and you're slicing or concatenating splits, use this.\n\n            If you want to be on the safe side, shuffle both before and after.\n        :param random_seed:\n            Random seed; affects shuffling.\n\n        :returns:\n            Returns a new dataset that is appropriately remixed.\n        \"\"\"\n\n        if shuffle_before:\n            input = input.shuffle(random_seed)\n\n        def get_slice(split_name: str) -> ds.Dataset:\n            slice_match = re.match(r\"(.*)\\[(-?[0-9]*:-?[0-9]*)\\]\", split_name)\n            if slice_match is None:\n                return input[split_name]\n            else:\n                split_name = slice_match[1]\n                slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(\":\")]\n                slice_indices = range(*slice(*slice_args).indices(len(input[split_name])))\n                return input[split_name].select(slice_indices)\n\n        def parse_split_spec(split_spec: str):\n            parts = [get_slice(name.strip()) for name in split_spec.split(\"+\")]\n            if len(parts) == 1:\n                return parts[0]\n            else:\n                return ds.concatenate_datasets(parts)\n\n      
  if keep_old_splits:\n            result = ds.DatasetDict(input.items())\n        else:\n            result = ds.DatasetDict()\n        result.update(\n            {\n                new_split_name: parse_split_spec(new_split_spec)\n                for new_split_name, new_split_spec in new_splits.items()\n            }\n        )\n\n        if shuffle_after:\n            result = result.shuffle(random_seed)\n\n        return result\n"
  },
  {
    "path": "tango/integrations/fairscale/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"fairscale\" extra\n    (e.g. ``pip install tango[fairscale]``) or just install FairScale after the fact.\n\n    This integration also depends on `PyTorch <https://pytorch.org/>`_, so make sure you\n    install the correct version of torch *first* given your operating system and supported\n    CUDA version. Check `pytorch.org/get-started/locally/ <https://pytorch.org/get-started/locally/>`_\n    for more details.\n\nComponents for Tango integration with `FairScale <https://github.com/facebookresearch/fairscale>`_.\n\nOverview\n--------\n\nFairScale is a PyTorch library for large scale training. Among other things, it implements\nthe main memory-savings techniques for distributed data-parallel training (DDP) that came from the paper\n`ZeRO: Memory Optimization Towards Training A Trillion Parameter Models\n<https://api.semanticscholar.org/CorpusID:203736482>`_.\n\nThe main part of this Tango integration is the :class:`FairScaleTrainingEngine`.\nThis is a :class:`~tango.integrations.torch.TrainingEngine` implementation that utilizes\nFairScale's :class:`~fairscale.nn.FullyShardedDataParallel` (FSDP) for substantial memory savings\nduring distributed training.\n\nFor the best performance you should also use :func:`with_wrapped_modules()` to wrap the inner modules\nof your :class:`~tango.integrations.torch.Model`. When used with FSDP this will dramatically reduce\nthe memory required to load your model.\n\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    import fairscale\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"fairscale\")\n\n__all__ = [\n    \"FairScaleTrainingEngine\",\n    \"FSDPConfig\",\n    \"with_wrapped_modules\",\n]\n\nfrom .fsdp_config import FSDPConfig\nfrom .module_wrapper import with_wrapped_modules\nfrom .training_engine import FairScaleTrainingEngine\n"
  },
  {
    "path": "tango/integrations/fairscale/fsdp_config.py",
    "content": "from dataclasses import asdict, dataclass\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP\n\nfrom tango.common import FromParams\n\n\n@dataclass\nclass FSDPConfig(FromParams):\n    \"\"\"\n    Defines all of the configurable options for FairScale's :class:`~fairscale.nn.FullyShardedDataParallel`.\n\n    .. seealso::\n        `Best practices for FullyShardedDataParallel <https://fairscale.readthedocs.io/en/latest/deep_dive/oss_sdp_fsdp.html#best-practices-for-fairscale-nn-fullyshardeddataparallel>`_\n        from the FairScale docs.\n\n    \"\"\"  # noqa: E501\n\n    reshard_after_forward: bool = True\n    \"\"\"\n    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.\n    \"\"\"\n\n    move_params_to_cpu: bool = False\n    \"\"\"\n    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.\n    \"\"\"\n\n    move_grads_to_cpu: Optional[bool] = None\n    \"\"\"\n    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.\n\n    .. seealso::\n        :data:`move_params_to_cpu`\n\n    .. warning::\n        At the moment we recommend that you don't mess with this parameter, or only explicitly\n        set it to the same value as :data:`move_params_to_cpu`. If you leave it as ``None``\n        (the default), it will automatically be set to match :data:`move_params_to_cpu` by FairScale.\n\n        Currently training seems to crash if you set this ``False`` while :data:`move_params_to_cpu` is ``True``.\n        We're tracking `fairscale#918 <https://github.com/facebookresearch/fairscale/issues/918>`_,\n        which may be related.\n    \"\"\"\n\n    mixed_precision: bool = False\n    \"\"\"\n    See the docstring for :class:`~fairscale.nn.FullyShardedDataParallel`.\n\n    .. 
important::\n        We recommend setting this to the same value as the ``amp`` parameter in\n        :class:`FairScaleTrainingEngine`.\n\n        Based on our experiments, if you're training with AMP enabled (``amp=True``)\n        you might see a small additional speedup in training time along with a small\n        additional decrease in GPU memory utilization without any performance penalty\n        (with respect to convergence) by setting this to ``True``.\n        But if you're *not* training with AMP, setting this ``True`` could impact the\n        model's ability to converge.\n\n    \"\"\"\n\n    def as_kwargs(self) -> Dict[str, Any]:\n        \"\"\"\n        Convert to the appropriate ``kwargs`` for :class:`~fairscale.nn.FullyShardedDataParallel`.\n        \"\"\"\n        return asdict(self)\n\n    def wrap(self, module: torch.nn.Module):\n        \"\"\"\n        A convenience method for wrapping a module in :class:`~fairscale.nn.FullyShardedDataParallel`\n        with all of the options defined in this class.\n\n        .. seealso::\n            Internally this is what :func:`with_wrapped_modules()` calls.\n\n        \"\"\"\n        return FSDP(module, **self.as_kwargs())\n"
  },
  {
    "path": "tango/integrations/fairscale/module_wrapper.py",
    "content": "import re\nfrom typing import Optional, Set\n\nimport torch\nimport torch.nn as nn\nfrom fairscale.nn.checkpoint import checkpoint_wrapper\n\nfrom tango.integrations.torch import Model\n\nfrom .fsdp_config import FSDPConfig\n\n\n@Model.register(\"fairscale::with_wrapped_modules\")  # type: ignore[arg-type]\ndef with_wrapped_modules(\n    model: Model,\n    modules_to_wrap: Set[str],\n    fsdp_config: Optional[FSDPConfig] = None,\n    activation_checkpointing: bool = False,\n) -> Model:\n    \"\"\"\n    A :class:`~tango.integrations.torch.Model` wrapper that can be used to easily wrap\n    inner modules of a model with FairScale's :class:`~fairscale.nn.FullyShardedDataParallel` wrapper\n    and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.\n\n    .. tip::\n        Registered as a :class:`~tango.integrations.torch.Model` constructor under the name\n        \"fairscale::with_wrapped_modules\".\n\n    .. important::\n        This is meant to be used with the :class:`FairScaleTrainingEngine`.\n\n    :param model:\n        The model to wrap.\n    :param modules_to_wrap:\n        The names of submodule to wrap. Can be regular expressions.\n    :param fsdp_config:\n        The ``FullyShardedDataParallel`` configuration to use when wrapping the modules.\n        If not specified, the modules will NOT be wrapped with FSDP.\n    :param activation_checkpointing:\n        Whether to wrap the modules with FairScale's\n        :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`.\n\n    Examples\n    --------\n\n    You can use this as a :class:`~tango.integrations.torch.Model` constructor from a config/params\n    like this:\n\n    .. 
testcode::\n\n        import torch.nn as nn\n        from tango.integrations.torch import Model\n\n\n        class FeedForward(nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.linear = nn.Linear(4, 4)\n                self.activation = nn.ReLU()\n\n            def forward(self, x):\n                return self.activation(self.linear(x))\n\n        @Model.register(\"simple_regression_model\")\n        class SimpleRegressionModel(Model):\n            def __init__(self):\n                super().__init__()\n                self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])\n                self.regression_head = nn.Linear(4, 1)\n                self.loss_fcn = nn.MSELoss()\n\n            def forward(self, x, y):\n                output = self.blocks(x)\n                output = self.regression_head(output)\n                loss = self.loss_fcn(output, y)\n                return {\"loss\": loss}\n\n\n        model = Model.from_params({\n            \"type\": \"fairscale::with_wrapped_modules\",\n            \"model\": {\n                \"type\": \"simple_regression_model\",\n            },\n            \"modules_to_wrap\": [r\"blocks\\\\.[0-9]+\", \"regression_head\"],\n            \"activation_checkpointing\": True,\n        })\n\n    \"\"\"\n\n    def wrap_module(\n        module: nn.Module,\n    ) -> nn.Module:\n        if activation_checkpointing:\n            module = checkpoint_wrapper(module, offload_to_cpu=True)\n        if fsdp_config is not None and torch.distributed.is_initialized():\n            module = fsdp_config.wrap(module)\n        return module\n\n    all_module_names: Set[str] = set([name for name, _ in model.named_modules() if name])\n    actual_modules_to_wrap: Set[str] = set()\n    unmatched_patterns: Set[str] = modules_to_wrap.copy()\n    for module_name in all_module_names:\n        for pattern in modules_to_wrap:\n            if re.fullmatch(pattern, module_name):\n            
    actual_modules_to_wrap.add(module_name)\n                if pattern in unmatched_patterns:\n                    unmatched_patterns.remove(pattern)\n\n    if unmatched_patterns:\n        raise ValueError(\n            f\"Some patterns in 'modules_to_wrap' did not match actual module names ({unmatched_patterns})\"\n        )\n\n    for module_name in actual_modules_to_wrap:\n        if \".\" in module_name:\n            *parent_parts, module_name = module_name.split(\".\")\n            parent_module = model.get_submodule(\".\".join(parent_parts))\n        else:\n            parent_module = model\n        module = parent_module.get_submodule(module_name)\n        module = wrap_module(module)\n        parent_module.add_module(module_name, module)\n\n    return model\n"
  },
  {
    "path": "tango/integrations/fairscale/training_engine.py",
    "content": "import logging\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\nfrom fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP\nfrom fairscale.optim.grad_scaler import ShardedGradScaler\n\nfrom tango.common import Lazy\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.integrations.torch import (\n    LRScheduler,\n    Model,\n    Optimizer,\n    TorchTrainingEngine,\n    TrainConfig,\n    TrainingEngine,\n)\n\nfrom .fsdp_config import FSDPConfig\n\n\n@TrainingEngine.register(\"fairscale\")\nclass FairScaleTrainingEngine(TorchTrainingEngine):\n    \"\"\"\n    A :class:`~tango.integrations.torch.TrainingEngine` that leverages FairScale's\n    :class:`~fairscale.nn.FullyShardedDataParallel` for use within\n    :class:`~tango.integrations.torch.TorchTrainStep`.\n\n    .. tip::\n        Registered as an :class:`~tango.integrations.torch.TrainingEngine` under the name\n        \"fairscale\".\n\n    .. tip::\n        To get the best performance out of :class:`FairScaleTrainingEngine` you should\n        wrap individual layers of your model with :class:`~fairscale.nn.FullyShardedDataParallel`\n        and/or :class:`~fairscale.nn.checkpoint.checkpoint_wrapper`\n        while instantiating them. You can use :class:`with_wrapped_modules()` to accomplish this.\n\n    .. important::\n        Only the parameters listed below should be defined in a configuration\n        file. The other parameters will be automatically passed to the constructor\n        within :class:`~tango.integrations.torch.TorchTrainStep`.\n\n    .. warning::\n        :class:`~FairScaleTrainingEngine` can only be used in distributed training, i.e.\n        when ``device_count > 1`` in the :class:`~tango.integrations.torch.TorchTrainStep`.\n\n    For maximum memory savings, we recommend training with AMP enabled and the following\n    :class:`FSDPConfig`:\n\n    .. 
testcode::\n\n        from tango.integrations.fairscale import FSDPConfig\n\n        fsdp_config = FSDPConfig(\n            reshard_after_forward=True,\n            move_params_to_cpu=True,\n            move_grads_to_cpu=True,\n            mixed_precision=True,\n        )\n\n    For maximum training *speed*, we recommend training with AMP enabled and the following\n    :class:`FSDPConfig`:\n\n    .. testcode::\n\n        from tango.integrations.fairscale import FSDPConfig\n\n        fsdp_config = FSDPConfig(\n            reshard_after_forward=False,\n            move_params_to_cpu=False,\n            move_grads_to_cpu=False,\n            mixed_precision=True,\n        )\n\n    :param amp:\n        Use automatic mixed precision (AMP). Default is ``False``.\n    :param max_grad_norm:\n        If set, gradients will be clipped to have this max norm. Default is ``None``.\n    :param amp_use_bfloat16:\n        Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.\n        Only applicable when ``amp=True``. 
If not specified, the default behavior will be\n        to use ``bfloat16`` when training with AMP on CPU, otherwise not.\n    :param fsdp_config:\n        The options for :class:`~fairscale.nn.FullyShardedDataParallel`.\n        If not specified, the default options will be used.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        train_config: TrainConfig,\n        model: Lazy[Model],\n        optimizer: Lazy[Optimizer],\n        *,\n        lr_scheduler: Optional[Lazy[LRScheduler]] = None,\n        amp: bool = False,\n        max_grad_norm: Optional[float] = None,\n        amp_use_bfloat16: Optional[bool] = None,\n        fsdp_config: Optional[FSDPConfig] = None,\n    ) -> None:\n        if not train_config.is_distributed:\n            raise ConfigurationError(\n                f\"{self.__class__.__name__} can only be used with distributed training\"\n            )\n\n        self.fsdp_config = fsdp_config or FSDPConfig()\n        self.logger = logging.getLogger(self.__class__.__name__)\n\n        super().__init__(\n            train_config,\n            model,\n            optimizer,\n            lr_scheduler=lr_scheduler,\n            amp=amp,\n            max_grad_norm=max_grad_norm,\n            amp_use_bfloat16=amp_use_bfloat16,\n        )\n        if amp:\n            self.grad_scaler = ShardedGradScaler()\n\n    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:\n        if isinstance(model, Lazy):\n            model = model.construct()\n        if not self.fsdp_config.move_params_to_cpu:\n            model.to(self.train_config.worker_local_default_device)\n        return FSDP(model, **self.fsdp_config.as_kwargs())\n\n    def clip_grad_norm(self) -> None:\n        if self.max_grad_norm is not None:\n            self.model.clip_grad_norm_(self.max_grad_norm)  # type: ignore\n\n    def get_model_state(self) -> Dict[str, torch.Tensor]:\n        return {\n            \"weights\": self.model.local_state_dict(),  # type: ignore\n      
      \"metadata\": self.model.local_metadata_dict(),  # type: ignore\n        }\n\n    def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:\n        self.model.load_local_state_dict(state_dict[\"weights\"])  # type: ignore\n\n    def save_complete_weights_from_checkpoint(\n        self, checkpoint_dir: Path, weights_path: Path\n    ) -> None:\n        self.logger.info(\"Consolidating sharded checkpoint weights...\")\n        sharded_weights: List[Dict[str, torch.Tensor]] = []\n        sharded_metadata: List[Dict[str, Any]] = []\n        for path in checkpoint_dir.resolve().glob(\"worker*_model.pt\"):\n            sharded_state = torch.load(path, map_location=\"cpu\")\n            sharded_weights.append(sharded_state[\"weights\"])\n            sharded_metadata.append(sharded_state[\"metadata\"])\n        full_state = FSDP.consolidate_shard_weights(sharded_weights, sharded_metadata)\n        del sharded_weights\n        del sharded_metadata\n        torch.save(full_state, weights_path)\n"
  },
  {
    "path": "tango/integrations/flax/__init__.py",
    "content": "from tango.common.exceptions import IntegrationMissingError\n\ntry:\n    import flax\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"flax\")\n\n__all__ = [\n    \"DataLoader\",\n    \"FlaxDataLoader\",\n    \"LRScheduler\",\n    \"Model\",\n    \"Optimizer\",\n    \"FlaxTrainStep\",\n    \"FlaxFormat\",\n    \"TrainCallback\",\n    \"EvalCallback\",\n    \"FlaxWrapper\",\n    \"TrainConfig\",\n    \"FlaxEvalStep\",\n]\n\nfrom .data import DataLoader, FlaxDataLoader\nfrom .eval import FlaxEvalStep\nfrom .eval_callback import EvalCallback\nfrom .format import FlaxFormat\nfrom .model import Model\nfrom .optim import LRScheduler, Optimizer\nfrom .train import FlaxTrainStep\nfrom .train_callback import TrainCallback\nfrom .train_config import TrainConfig\nfrom .wrapper import FlaxWrapper\n"
  },
  {
    "path": "tango/integrations/flax/data.py",
    "content": "import logging\nfrom typing import Generic, TypeVar\n\nimport jax.random\nimport numpy as np\nfrom datasets import Dataset\nfrom flax.training.common_utils import shard\n\nfrom tango.common.registrable import Registrable\n\nT = TypeVar(\"T\")\n\n\nclass DataLoader(Generic[T], Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a ``Flax DataLoader``.\n    ``Flax DataLoader`` accepts Dataset object. The class yields a numpy batch.\n    \"\"\"\n\n\n@DataLoader.register(\"flax::dataloader\")\nclass FlaxDataLoader(DataLoader):\n    def __init__(\n        self,\n        dataset: Dataset,\n        batch_size: int = 8,\n        drop_last: bool = True,\n        shuffle: bool = True,\n    ):\n        self.dataset = dataset\n        self.dataset_size = dataset.num_rows\n        self.batch_size = batch_size\n        self.drop_last = drop_last\n        if not drop_last:\n            raise NotImplementedError(\n                \"With Jax you have to drop the last incomplete batch, because the batch size is compiled into the \"\n                \"model.\"\n            )\n        self.shuffle = shuffle\n\n        self.logger = logging.getLogger(FlaxDataLoader.__name__)\n\n    def __call__(self, rng: jax._src.random.KeyArrayLike, do_distributed: bool):\n        steps_per_epoch = self.dataset_size // self.batch_size\n\n        if self.shuffle:\n            perms = jax.random.permutation(rng, self.dataset_size)\n            perms = np.asarray(perms)  # using jax arrays for indexing is a bottleneck on TPUs.\n        else:\n            perms = np.arange(self.dataset_size)\n\n        self.logger.info(\"Skipping last incomplete batch\")\n        perms = perms[: steps_per_epoch * self.batch_size]  # Skip incomplete batch.\n        perms = perms.reshape((steps_per_epoch, self.batch_size))\n\n        for perm in perms:\n            batch = self.dataset[perm]\n            if do_distributed:\n                batch = shard(batch)\n            yield 
batch\n"
  },
  {
    "path": "tango/integrations/flax/eval.py",
    "content": "import logging\nfrom collections import defaultdict\nfrom itertools import islice\nfrom typing import Dict, List, Optional, Sequence\n\nimport jax\nfrom flax import jax_utils\nfrom flax.training.train_state import TrainState\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.lazy import Lazy\nfrom tango.common.tqdm import Tqdm\nfrom tango.format import Format, JsonFormat\nfrom tango.step import Step\n\nfrom .data import FlaxDataLoader\nfrom .eval_callback import EvalCallback\nfrom .util import get_PRNGkey\nfrom .wrapper import FlaxWrapper\n\n\n@Step.register(\"flax::eval\")\nclass FlaxEvalStep(Step):\n    \"\"\"\n    A Flax evaluation loop that pairs well with :class:`FlaxTrainStep`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"flax::eval\".\n\n    .. important::\n\n        The evaluation loop will use a GPU/TPU automatically if one is available.\n        You can control which GPU it uses with the environment variable ``CUDA_VISIBLE_DEVICES``.\n        For example, set ``CUDA_VISIBLE_DEVICES=1`` to force ``FlaxEvalStep`` to only use\n        the GPU with ID 1.\n\n    .. 
warning::\n\n        By default the metrics specified by the ``metric_names`` parameter\n        are aggregated by simply averaging across batches.\n        This behavior is usually correct for metrics like \"loss\" or \"accuracy\",\n        for example, but may not be correct for other metrics like \"F1\".\n\n        If this is not correct for your metric you will need to handle the aggregation\n        internally in your model or with an :class:`EvalCallback`\n        using the :meth:`EvalCallback.post_batch()` method.\n        Then set the parameter ``auto_aggregate_metrics`` to ``False``.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n    SKIP_ID_ARGUMENTS = {\"log_every\"}\n\n    def run(  # type: ignore[override]\n        self,\n        state: TrainState,\n        dataset: DatasetDictBase,\n        dataloader: Lazy[FlaxDataLoader],\n        wrapper: FlaxWrapper,\n        test_split: str = \"test\",\n        seed: int = 42,\n        log_every: int = 1,\n        do_distributed: bool = False,\n        eval_steps: Optional[int] = None,\n        metric_names: Sequence[str] = (\"loss\",),\n        auto_aggregate_metrics: bool = True,\n        callbacks: Optional[List[Lazy[EvalCallback]]] = None,\n    ) -> Dict[str, float]:\n        \"\"\"\n        Evaluate the ``model``.\n\n        :param state:\n            The state of the model to evaluate. This contains the parameters.\n        :param dataset:\n            Should contain the test data.\n        :param dataloader:\n            The data loader that generates test batches. 
The batches should be :class:`dict`\n            objects.\n        :param wrapper:\n            The wrapper should define :meth:`eval_metrics`.\n        :param test_split:\n            The name of the data split used for evaluation in the ``dataset``.\n            Default is \"test\".\n        :param seed:\n            Used to set the PRNG states at the beginning of the evaluation loop.\n        :param log_every:\n            Log every this many steps. Default is ``1``.\n        :param do_distributed:\n            Whether to run evaluation in parallel across multiple devices. Default is ``False``.\n            This is switched on automatically when more than one device is available.\n        :param eval_steps:\n            The number of steps to evaluate for. If not specified evaluation will\n            stop after a complete iteration through the ``dataloader``.\n        :param metric_names:\n            The names of the metrics to track and aggregate. Default is ``(\"loss\",)``.\n        :param auto_aggregate_metrics:\n            If ``True`` (the default), the metrics will be averaged across batches.\n            This may not be the correct behavior for some metrics (such as F1),\n            in which case you should set this to ``False`` and handle the aggregation\n            internally in your model or with an :class:`EvalCallback`\n            (using :meth:`EvalCallback.post_batch()`).\n        :param callbacks:\n            A list of :class:`EvalCallback`.\n\n        \"\"\"\n\n        logger = logging.getLogger(FlaxEvalStep.__name__)\n        # construct dataloader\n        # ``with_format`` returns a new dataset; ``set_format`` works in place and returns ``None``.\n        dataloader: FlaxDataLoader = dataloader.construct(\n            dataset=dataset[test_split].with_format(\"numpy\")\n        )\n\n        steps: int\n        try:\n            dataloader_len = dataloader.dataset_size\n            steps = dataloader_len if eval_steps is None else min(dataloader_len, eval_steps)\n        except TypeError:\n            if eval_steps is None:\n                raise ConfigurationError(\n                    \"You must set 'eval_steps' for streaming/iterable datasets\"\n    
            )\n            else:\n                steps = eval_steps\n\n        if do_distributed:\n            devices = jax.devices()\n            if len(devices) <= 1:\n                raise ConfigurationError(\n                    \"You have set do_distributed=True but only one device is available.\"\n                )\n\n        # Initialize callbacks\n        callbacks: List[EvalCallback] = [\n            callback.construct(\n                workspace=self.workspace,\n                step_id=self.unique_id,\n                work_dir=self.work_dir,\n                dataset_dict=dataset,\n                dataloader=dataloader,\n            )\n            for callback in (callbacks or [])\n        ]\n\n        for callback in callbacks:\n            callback.pre_eval_loop()\n\n        rng = get_PRNGkey(seed)\n        devices = jax.devices()\n        if len(devices) > 1:\n            do_distributed = True\n\n        def eval_step(state, batch):\n            labels = batch.pop(\"labels\")\n            logits = state.apply_fn(**batch, params=state.params, train=False)[0]\n            metrics = wrapper.eval_metrics(batch, logits, labels)\n            if do_distributed:\n                metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n            return logits, metrics\n\n        if do_distributed:\n            state = jax_utils.replicate(state)\n            parallel_eval_step = jax.pmap(eval_step, axis_name=\"batch\")\n\n        eval_batches = enumerate(islice(dataloader(rng, do_distributed), steps))\n\n        running_metrics: Dict[str, float] = defaultdict(float)\n        aggregated_metrics: Dict[str, float] = defaultdict(float)\n\n        with Tqdm.tqdm(eval_batches, desc=\"Evaluating\", total=steps) as batch_iter:\n            for step, batch in batch_iter:\n                should_log_this_step = step % log_every == 0 or step == steps - 1\n                for callback in callbacks:\n                    callback.pre_batch(step, batch)\n\n                if do_distributed:\n                    
logits, metrics = parallel_eval_step(state, batch)\n                    metrics = jax_utils.unreplicate(metrics)\n                else:\n                    logits, metrics = eval_step(state, batch)\n\n                for callback in callbacks:\n                    callback.post_batch(step, logits)\n\n                if auto_aggregate_metrics:\n                    for key, val in metrics.items():\n                        if key in metric_names:\n                            running_metrics[key] += metrics[key].item()\n                            aggregated_metrics[key] = running_metrics[key] / (step + 1)\n                else:\n                    aggregated_metrics.update(metrics)\n\n                if should_log_this_step:\n                    batch_iter.set_postfix(**aggregated_metrics)\n                del batch\n\n        logger.info(\"Evaluation Metrics:\")\n        for key, val in aggregated_metrics.items():\n            # Pass the values as lazy logging arguments; extra positional args need format placeholders.\n            logger.info(\"%s: %s\", key, val)\n\n        for callback in callbacks:\n            callback.post_eval_loop(aggregated_metrics)\n\n        return aggregated_metrics\n"
  },
  {
    "path": "tango/integrations/flax/eval_callback.py",
    "content": "from pathlib import Path\nfrom typing import Any, Dict\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.registrable import Registrable\nfrom tango.workspace import Workspace\n\nfrom .data import FlaxDataLoader\n\n\nclass EvalCallback(Registrable):\n    \"\"\"\n    An ``EvalCallback`` is a :class:`~tango.common.Registrable` class that can be used\n    within :class:`FlaxEvalStep` to customize the behavior of the evaluation loop,\n    similar to how :class:`TrainCallback` is used to customize the behavior of the training\n    loop.\n\n    .. tip::\n        All of the parameters to this base class will be automatically set within\n        the training loop, so you shouldn't include them in your config for your callbacks.\n\n    :ivar Workspace workspace: The tango workspace being used.\n    :ivar str step_id: The unique ID of the step.\n    :ivar pathlib.Path work_dir: The working directory of the step\n    :ivar DatasetDictBase dataset_dict: The dataset dict containing the evaluation split.\n    :ivar DataLoader dataloader: The data loader used to load the evaluation split data.\n    \"\"\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        step_id: str,\n        work_dir: Path,\n        dataset_dict: DatasetDictBase,\n        dataloader: FlaxDataLoader,\n    ) -> None:\n        self.workspace = workspace\n        self.step_id = step_id\n        self.work_dir = work_dir\n        self.dataset_dict = dataset_dict\n        self.dataloader = dataloader\n\n    def pre_eval_loop(self) -> None:\n        \"\"\"\n        Called right before the first batch is processed\n        \"\"\"\n        pass\n\n    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:\n        \"\"\"\n        Called after the evaluation loop completes with the final aggregated metrics.\n\n        This is the last method that is called, so any cleanup can be done in this method.\n        \"\"\"\n        pass\n\n    def 
pre_batch(self, step: int, batch: Dict[str, Any]) -> None:\n        \"\"\"\n        Called directly before processing a batch.\n        \"\"\"\n        pass\n\n    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:\n        \"\"\"\n        Called directly after processing a batch with the outputs of the batch.\n\n        .. tip::\n            This method can be used to modify ``batch_outputs`` in place, which is useful\n            in scenarios where you might need to aggregate metrics\n            in a special way other than a simple average. If that's the case, make sure\n            to set ``auto_aggregate_metrics`` to ``False`` in :class:`FlaxEvalStep`.\n\n        \"\"\"\n        pass\n"
  },
  {
    "path": "tango/integrations/flax/format.py",
    "content": "from pathlib import Path\nfrom typing import Generic, TypeVar\n\nfrom flax.training import checkpoints\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.format import Format\n\nT = TypeVar(\"T\")\n\n\n@Format.register(\"flax\")\nclass FlaxFormat(Format[T], Generic[T]):\n    \"\"\"\n    This format writes the artifact.\n\n    .. tip::\n\n        Registered as a :class:`~tango.format.Format` under the name \"flax\".\n    \"\"\"\n\n    VERSION = \"002\"\n\n    def write(self, artifact: T, dir: PathOrStr) -> None:\n        checkpoints.save_checkpoint(Path(dir), artifact, step=0)\n\n    def read(self, dir: PathOrStr) -> T:\n        # will return a dict\n        return checkpoints.restore_checkpoint(dir, target=None)\n"
  },
  {
    "path": "tango/integrations/flax/model.py",
    "content": "from flax import linen as nn\n\nfrom tango.common.registrable import Registrable\n\n\nclass Model(nn.Module, Registrable):\n    \"\"\"\n    This is a :class:`~tango.common.Registrable` mixin class that inherits from\n    :class:`flax.linen.Module`.\n    Its :meth:`~flax.linen.Module.setup()` can be used to register submodules,\n    variables, parameters you will need in your model.\n    Its :meth:`~flax.linen.Module.__call__()` returns the output of the model\n    for a given input.\n    \"\"\"\n"
  },
  {
    "path": "tango/integrations/flax/optim.py",
    "content": "from inspect import isfunction\nfrom typing import Callable, Type\n\nimport optax\n\nfrom tango.common.registrable import Registrable\n\n\nclass Optimizer(Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of Optax optimizers.\n\n    All `built-in Optax optimizers\n    <https://optax.readthedocs.io/en/latest/api.html#>`_\n    are registered according to their class name (e.g. \"optax::adam\").\n\n    .. tip::\n\n        You can see a list of all available optimizers by running\n\n        .. testcode::\n\n            from tango.integrations.flax import Optimizer\n            for name in sorted(Optimizer.list_available()):\n                print(name)\n\n        .. testoutput::\n            :options: +ELLIPSIS\n\n            optax::adabelief\n            optax::adadelta\n            optax::adafactor\n            optax::adagrad\n            optax::adam\n            ...\n\n    \"\"\"\n\n    def __init__(self, optimizer: Callable) -> None:\n        self.optimizer = optimizer\n\n    def __call__(self, **kwargs) -> optax.GradientTransformation:\n        return self.optimizer(**kwargs)\n\n\nclass LRScheduler(Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of an Optax learning\n    rate scheduler.\n\n    All `built-in Optax learning rate schedulers\n    <https://optax.readthedocs.io/en/latest/api.html#schedules>`_\n    are registered according to their class name (e.g. \"optax::linear_schedule\").\n\n    .. tip::\n\n        You can see a list of all available schedulers by running\n\n        .. testcode::\n\n            from tango.integrations.flax import LRScheduler\n            for name in sorted(LRScheduler.list_available()):\n                print(name)\n\n        .. 
testoutput::\n            :options: +ELLIPSIS\n\n            optax::constant_schedule\n            optax::cosine_decay_schedule\n            optax::cosine_onecycle_schedule\n            optax::exponential_decay\n            ...\n\n    \"\"\"\n\n    def __init__(self, scheduler: Callable) -> None:\n        self.scheduler = scheduler\n\n    def __call__(self, **kwargs):\n        return self.scheduler(**kwargs)\n\n\ndef optimizer_factory(optim_method: Callable) -> Type[Callable]:\n    def factory_func():\n        return Optimizer(optim_method)\n\n    return factory_func()\n\n\ndef scheduler_factory(scheduler_method: Callable) -> Type[Callable]:\n    def factory_func():\n        return LRScheduler(scheduler_method)\n\n    return factory_func()\n\n\n# Register all optimizers.\nfor name, cls in optax._src.alias.__dict__.items():\n    if isfunction(cls) and not name.startswith(\"_\") and cls.__annotations__:\n        factory_func = optimizer_factory(cls)\n        Optimizer.register(\"optax::\" + name)(factory_func)\n\n# Register all learning rate schedulers.\nfor name, cls in optax.schedules.__dict__.items():\n    if isfunction(cls) and not name.startswith(\"_\") and cls.__annotations__:\n        factory_func = scheduler_factory(cls)\n        LRScheduler.register(\"optax::\" + name)(factory_func)\n\n# TODO: Handle inject_hyperparams.\n# Refer: https://optax.readthedocs.io/en/latest/api.html?highlight=inject%20hyperparam\n"
  },
  {
    "path": "tango/integrations/flax/train.py",
    "content": "import logging\nimport time\nfrom collections import defaultdict\nfrom pathlib import Path\nfrom typing import Any, DefaultDict, Dict, List, Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import jax_utils\nfrom flax.training import checkpoints\nfrom flax.training.train_state import TrainState\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.lazy import Lazy\nfrom tango.common.tqdm import Tqdm\nfrom tango.format import Format\nfrom tango.step import Step\nfrom tango.workspace import Workspace\n\nfrom .data import FlaxDataLoader\nfrom .format import FlaxFormat\nfrom .model import Model\nfrom .optim import LRScheduler, Optimizer\nfrom .train_callback import TrainCallback\nfrom .train_config import TrainConfig\nfrom .util import get_multiple_keys, get_PRNGkey\nfrom .wrapper import FlaxWrapper\n\nPyTree = Any\n\n\n@Step.register(\"flax::train\")\nclass FlaxTrainStep(Step):\n    \"\"\"\n    A Flax training step that supports distributed training with configurable dataloaders, callbacks,\n    optimizer.\n\n    .. tip::\n        Registered as a :class:`~tango.step.Step` under the name \"flax::train\".\n\n    .. important::\n        To train on GPUs and TPUs, installation of jax[cuda] or jax[tpu] is required. Follow the\n        instructions here: https://github.com/google/jax to set up jax for GPUs and TPUs.\n        Note: CUDA and cuDNN installation is required to run jax on NVidia GPUs. It is recommended to\n        install cuDNN in your conda environment using: ``conda install -c anaconda cudnn``.\n\n        Distributed data parallel training is activated when the ``device_count`` is greater than 1.\n        You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.\n        For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``\n        (and ``device_count`` to 2).\n\n    .. 
warning::\n        During validation, the validation metric (specified by the ``val_metric_name`` parameter)\n        is aggregated by simply averaging across validation batches and distributed processes.\n        This behavior is usually correct when your validation metric is \"loss\" or \"accuracy\",\n        for example, but may not be correct for other metrics like \"F1\".\n        If this is not correct for your metric you will need to handle the aggregation\n        internally in your model or with a :class:`TrainCallback`\n        using the :meth:`TrainCallback.post_val_batch()` method.\n        Then set the parameter ``auto_aggregate_val_metric`` to ``False``.\n\n        Jax pre-allocates 90% of GPU memory. If you run into out-of-memory (OOM) issues, please refer\n        to this: https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = FlaxFormat()\n    SKIP_ID_ARGUMENTS = {\"log_every\"}\n    METADATA = {\"artifact_kind\": \"model\"}\n\n    def run(  # type: ignore[override]\n        self,\n        model: Model,\n        dataset: DatasetDictBase,\n        optimizer: Lazy[Optimizer],\n        train_dataloader: Lazy[FlaxDataLoader],\n        *,\n        wrapper: FlaxWrapper,\n        seed: int = 42,\n        keep_checkpoints: int = 5,\n        lr_scheduler: Optional[Lazy[LRScheduler]] = None,\n        train_split: str = \"train\",\n        validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,\n        validation_split: Optional[str] = None,\n        train_steps: Optional[int] = None,\n        train_epoch: Optional[int] = None,\n        validation_steps: Optional[int] = None,\n        log_every: int = 10,\n        checkpoint_every: int = 100,\n        validate_every: Optional[int] = None,\n        val_metric_name: str = \"loss\",\n        minimize_val_metric: bool = True,\n        auto_aggregate_val_metric: bool = True,\n        callbacks: 
Optional[List[Lazy[TrainCallback]]] = None,\n        remove_stale_checkpoints: bool = True,\n    ) -> PyTree:\n        \"\"\"\n        Run a basic training loop to train the ``model``.\n\n        :param model:\n            The flax model to train. It should define ``__call__()``. Defining ``setup()`` is Optional.\n        :param dataset:\n            The train and optional validation dataset.\n        :param optimizer:\n            The name of the optax Optimizer to use for training.\n        :param train_dataloader:\n            The dataloader object that generates training batches.\n        :param wrapper:\n            A Wrapper class that defines ``loss_fn()``, ``eval_fn()`` and ``compute_metrics()``\n        :param seed:\n            Used to set the PRNG state. By default, ``seed=42``\n        :param keep_checkpoints:\n            An integer which denotes how many previous checkpoints should be stored while training.\n            By default, ``keep_checkpoints=5``\n        :param lr_scheduler:\n            The name of the learning rate scheduler.\n        :param train_split:\n            The name of the data split used for training in the ``dataset_dict``.\n            Default is \"train\".\n        :param validation_dataloader:\n            An optional data loader for generating validation batches. The batches should be\n            :class:`dict` objects. If not specified, but ``validation_split`` is given,\n            the validation ``DataLoader`` will be constructed from the same parameters\n            as the train ``DataLoader``.\n        :param validation_split:\n            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,\n            which means no validation.\n        :param train_steps:\n            The number of steps to train for. If not specified training will\n            stop after a complete iteration through the ``train_dataloader``.\n        :param train_epoch:\n            The number of epochs to train for. 
You cannot specify ``train_steps`` and ``train_epoch``\n            at the same time.\n        :param validation_steps:\n            The number of steps to validate for. If not specified, validation\n            will stop after a complete iteration through the ``validation_dataloader``.\n        :param log_every:\n            Log every this many steps.\n        :param checkpoint_every:\n            Save a checkpoint every this many steps.\n        :param validate_every:\n            Run the validation loop every this many steps.\n        :param val_metric_name:\n            The name of the validation metric, i.e. the key of the metric in the dictionary\n            returned by the forward pass of the model. Default is \"loss\".\n        :param minimize_val_metric:\n            Whether the validation metric is meant to be minimized (such as the loss).\n            Default is ``True``. When using a metric such as accuracy, you should set\n            this to ``False``.\n        :param auto_aggregate_val_metric:\n            If ``True`` (the default), the validation metric will be averaged across\n            validation batches and distributed processes. 
This may not be the correct\n            behavior for some metrics (such as F1), in which case you should set this to\n            ``False`` and handle the aggregation internally in your model\n            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).\n        :param callbacks:\n            A list of :class:`TrainCallback`.\n        :param remove_stale_checkpoints:\n            If ``True`` (the default), stale checkpoints will be removed throughout training so that\n            only the latest and best checkpoints are kept.\n\n        :returns:\n            The trained model with the last checkpoint loaded.\n        \"\"\"\n\n        return self._train(\n            dataset=dataset,\n            model=model,\n            optimizer=optimizer,\n            train_dataloader=train_dataloader,\n            train_wrapper=wrapper,\n            seed=seed,\n            keep_checkpoints=keep_checkpoints,\n            lr_scheduler=lr_scheduler,\n            train_split=train_split,\n            validation_split=validation_split,\n            validation_dataloader=validation_dataloader,\n            train_steps=train_steps,\n            train_epochs=train_epoch,\n            validation_steps=validation_steps,\n            log_every=log_every,\n            checkpoint_every=checkpoint_every,\n            validate_every=validate_every,\n            val_metric_name=val_metric_name,\n            minimize_val_metric=minimize_val_metric,\n            auto_aggregate_val_metric=auto_aggregate_val_metric,\n            callbacks=callbacks,\n            remove_stale_checkpoints=remove_stale_checkpoints,\n        )\n\n    def _train(\n        self,\n        model: Model,\n        optimizer: Lazy[Optimizer],\n        dataset: DatasetDictBase,\n        train_dataloader: Lazy[FlaxDataLoader],\n        *,\n        train_wrapper: FlaxWrapper,\n        seed: int = 42,\n        keep_checkpoints: int = 5,\n        lr_scheduler: Optional[Lazy[LRScheduler]],\n        
train_split: str = \"train\",\n        validation_split: Optional[str] = None,\n        validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,\n        train_steps: Optional[int] = None,\n        train_epochs: Optional[int] = None,\n        validation_steps: Optional[int] = None,\n        log_every: int = 10,\n        checkpoint_every: int = 100,\n        validate_every: Optional[int] = None,\n        val_metric_name: str = \"loss\",\n        minimize_val_metric: bool = True,\n        auto_aggregate_val_metric: bool = True,\n        callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n        remove_stale_checkpoints: bool = True,\n    ) -> PyTree:\n        if validate_every is not None and validation_split is None:\n            raise ConfigurationError(\n                \"You have set a validation interval, but no validation split. \"\n                \"That's probably unintentional.\"\n            )\n\n        if (train_steps is not None) and (train_epochs is not None):\n            raise ConfigurationError(\n                \"One of 'train_steps' or 'train_epochs' needs to be specified, but not both.\"\n            )\n\n        if isinstance(dataset, DatasetDictBase) and train_split is None:\n            raise ConfigurationError(\"Specify the train split for Datasets object.\")\n\n        config = TrainConfig(\n            self.unique_id,\n            self.work_dir,\n            step_name=self.name,\n            train_split=train_split,\n            validation_split=validation_split,\n            seed=seed,\n            train_steps=train_steps,\n            train_epochs=train_epochs,\n            log_every=log_every,\n            checkpoint_every=checkpoint_every,\n            validate_every=validate_every,\n            validation_steps=validation_steps,\n            val_metric_name=val_metric_name,\n            minimize_val_metric=minimize_val_metric,\n            auto_aggregate_val_metric=auto_aggregate_val_metric,\n            
remove_stale_checkpoints=remove_stale_checkpoints,\n        )\n\n        optimizer = self._construct_optimizer(optimizer)\n\n        lr_scheduler_: Optional[LRScheduler] = None\n        if lr_scheduler is not None:\n            lr_scheduler_ = self._construct_lr_scheduler(lr_scheduler)\n        lr_scheduler = lr_scheduler_\n\n        final_model: Model\n\n        final_model = self.train_helper(\n            self.workspace,\n            config,\n            model,\n            optimizer,\n            keep_checkpoints,\n            lr_scheduler,\n            train_wrapper,\n            dataset,\n            train_dataloader,\n            validation_dataloader,\n            callbacks,\n        )\n        assert final_model is not None\n\n        return final_model\n\n    def train_helper(\n        self,\n        workspace: Workspace,\n        config: TrainConfig,\n        model: Model,\n        optimizer: Optimizer,\n        keep_checkpoints: int,\n        lr_scheduler: Optional[LRScheduler],\n        train_wrapper: FlaxWrapper,\n        dataset: DatasetDictBase,\n        train_dataloader: Lazy[FlaxDataLoader],\n        validation_dataloader: Optional[Lazy[FlaxDataLoader]] = None,\n        callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n    ) -> PyTree:\n        if lr_scheduler is not None:\n            raise NotImplementedError(\n                \"Learning rate scheduling is not supported by the flax trainer. 
\"\n                \"Please voice your support for this feature at \"\n                \"https://github.com/allenai/tango/issues/477.\"\n            )\n\n        logger = logging.getLogger(FlaxTrainStep.__name__)\n\n        # construct data loaders\n        validation_dataloader_: Optional[FlaxDataLoader] = None\n        if config.validation_split is not None:\n            validation_dataset = dataset[config.validation_split]\n            validation_dataset.set_format(\"numpy\")\n            if validation_dataloader is not None:\n                validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset)\n            else:\n                validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset)\n\n        validation_dataloader = validation_dataloader_\n\n        train_dataset = dataset[config.train_split]\n        train_dataset.set_format(\"numpy\")  # type:ignore\n        train_dataloader: FlaxDataLoader = train_dataloader.construct(dataset=train_dataset)\n\n        devices = self._get_devices()\n        do_distributed: bool = False\n        if len(devices) > 1:\n            do_distributed = True\n\n        if validation_dataloader is not None:\n            validation_dataloader.batch_size *= len(devices)\n        train_dataloader.batch_size *= len(devices)\n\n        rng = get_PRNGkey(config.seed)\n\n        if hasattr(model, \"params\"):\n            params = model.params\n        else:\n            # TODO: Find better way to init the shape\n            shape = list(train_dataset[\"x\"].shape)\n            shape[0] = 1\n            x = jnp.ones(shape)\n\n            params = model.init(rng, x)[\"params\"]\n\n        state = TrainState.create(apply_fn=model.__call__, params=params, tx=optimizer)\n\n        initial_state: Optional[Dict[str, Any]] = None\n        if config.state_path.exists():\n            logger.info(\"Recovering from previous run at %s\" % config.state_path)\n            state = 
self.load_checkpoint(config.state_path, state)\n\n        if config.train_epochs is None:\n            assert config.train_steps is not None\n            try:\n                train_epochs = len(train_dataloader.dataset) // train_dataloader.batch_size\n            except TypeError:\n                raise ConfigurationError(\n                    \"You must set train_epochs for streaming/iterable datasets\"\n                )\n\n            config.train_epochs = train_epochs\n\n        assert config.train_epochs is not None\n\n        if validation_dataloader is not None:\n            if config.validation_steps is None:\n                try:\n                    config.validation_steps = len(validation_dataloader.dataset)\n                except TypeError:\n                    raise ConfigurationError(\n                        \"You must set 'validation_steps' for streaming/iterable datasets\"\n                    )\n\n        val_metric: Optional[float] = None\n        best_val_metric: Optional[float] = None\n        start_step: int = 0\n\n        if initial_state is not None:\n            val_metric = initial_state[f\"val_{config.val_metric_name}\"]\n            best_val_metric = initial_state[f\"best_{config.val_metric_name}\"]\n            start_step = initial_state[\"training_epochs\"]\n\n        # Initialize callbacks\n        callbacks: List[TrainCallback] = [\n            callback.construct(\n                workspace=workspace,\n                train_config=config,\n                dataset=dataset,\n                train_dataloader=train_dataloader,\n                model=model,\n                optimizer=optimizer,\n                validation_dataloader=validation_dataloader,\n            )\n            for callback in (callbacks or [])\n        ]\n\n        if initial_state:\n            for callback, state in zip(callbacks, initial_state[\"callbacks\"]):\n                callback.load_state_dict(state)\n\n        del initial_state\n\n        if start_step 
> 0:\n            with Tqdm.tqdm(\n                train_dataloader,\n                desc=f\"Catching dataloader up to step {start_step}\",\n                total=start_step - 1,\n            ) as batch_iter:\n                for step, batch in enumerate(batch_iter):\n                    del batch\n                    if step >= start_step - 1:\n                        break\n\n        def train_step(state, batch, dropout_rng):\n            # if transformer model\n            labels = batch.pop(\"labels\")\n            dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)\n            grad_fn = jax.value_and_grad(train_wrapper.train_loss)\n            loss, grad = grad_fn(state.params, state, batch, dropout_rng, labels)\n            if do_distributed:\n                grad = jax.lax.pmean(grad, \"batch\")\n            new_state = state.apply_gradients(grads=grad)\n            other_metrics = train_wrapper.train_metrics(state, batch, labels=labels)\n            metrics = {\"loss\": loss}\n            metrics.update(other_metrics)\n            if do_distributed:\n                metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n\n            return new_state, metrics, new_dropout_rng\n\n        def val_step(state, batch):\n            labels = batch.pop(\"labels\")\n            logits = state.apply_fn(**batch, params=state.params, train=False)[0]\n            metrics = train_wrapper.val_metrics(batch, logits, labels)\n            if do_distributed:\n                metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n            return metrics\n\n        if do_distributed:\n            # NOTE: The trainer currently handles only data parallelism.\n            state = jax_utils.replicate(state)\n            dropout_rngs = get_multiple_keys(rng, jax.local_device_count())\n            parallel_train_step = jax.pmap(train_step, axis_name=\"batch\")\n            parallel_val_step = jax.pmap(val_step, axis_name=\"batch\")\n\n        step_per_epoch = 
train_dataloader.dataset_size // train_dataloader.batch_size\n        config.train_steps = step_per_epoch * config.train_epochs\n\n        assert config.train_steps is not None  # for mypy\n\n        for callback in callbacks:\n            callback.pre_train_loop()\n\n        logger.info(\"***** Running training *****\")\n        logger.info(f\"  Num examples = {train_dataloader.dataset_size}\")\n        logger.info(f\"  Num Epochs = {config.train_epochs}\")\n        logger.info(\n            f\"  Total train batch size (w. parallel & distributed) = {train_dataloader.batch_size}\"\n        )\n        logger.info(f\"  Total optimization steps = {config.train_steps}\")\n\n        step = start_step\n\n        epochs = Tqdm.tqdm(\n            range(config.train_epochs), desc=f\"Epoch (1/{config.train_epochs})\", position=0\n        )\n\n        for epoch in epochs:\n            start = time.time()\n            train_metrics = []\n\n            for callback in callbacks:\n                callback.pre_epoch(step, epoch)\n\n            train_loader = train_dataloader(rng, do_distributed)\n            for _ in Tqdm.tqdm(range(step_per_epoch), desc=\"Training\", position=1):\n                batch = next(train_loader)\n                for callback in callbacks:\n                    callback.pre_batch(step, epoch, batch)\n\n                if do_distributed:\n                    state, train_metric, dropout_rngs = parallel_train_step(\n                        state, batch, dropout_rngs\n                    )\n                else:\n                    state, train_metric, rng = train_step(state, batch, rng)\n\n                train_metrics.append(train_metric)\n\n                for callback in callbacks:\n                    callback.post_batch(step, epoch, train_metric)\n\n                if config.should_log_this_step(step):\n                    for callback in callbacks:\n                        callback.log_batch(step, epoch, train_metric)\n\n                if 
config.should_checkpoint_this_step(step):\n                    self.save_checkpoint(config.state_path, state, step, keep_checkpoints)\n                step += 1\n\n                # check if we need to do validation\n                if config.validation_split is None:\n                    # If we can't validate, we don't.\n                    should_validate = False\n                elif step == config.train_steps - 1:\n                    # If we're at the end of the training run, we always validate.\n                    should_validate = True\n                elif config.validate_every is not None and step % config.validate_every == 0:\n                    # If validate_every is given, we use that to decide.\n                    should_validate = True\n                else:\n                    # Otherwise, we don't validate.\n                    should_validate = False\n\n                if should_validate:\n                    assert validation_dataloader is not None\n                    assert config.validation_steps is not None\n\n                    val_metrics: DefaultDict = defaultdict(list)\n                    epoch_eval_metrics: DefaultDict = defaultdict(float)\n\n                    val_dataloader = validation_dataloader(rng, do_distributed)\n\n                    valid_step = 0\n                    total_val_steps = len(validation_dataset) // validation_dataloader.batch_size\n                    for callback in callbacks:\n                        callback.pre_val_loop(step, valid_step, state)\n\n                    for _ in Tqdm.tqdm(range(total_val_steps), desc=\"Evaluating\", position=2):\n                        batch = next(val_dataloader)\n                        for callback in callbacks:\n                            callback.pre_val_batch(step, valid_step, epoch, batch)\n\n                        if do_distributed:\n                            metrics = parallel_val_step(state, batch)\n                            metrics = 
jax_utils.unreplicate(metrics)\n                        else:\n                            metrics = val_step(state, batch)\n\n                        for key, value in metrics.items():\n                            val_metrics[key].append(value.item())\n\n                        for callback in callbacks:\n                            callback.post_val_batch(step, valid_step, epoch, val_metrics)\n\n                        valid_step += 1\n\n                    for key, value in val_metrics.items():\n                        if config.auto_aggregate_val_metric:\n                            epoch_eval_metrics[key] = jax.tree_map(\n                                jnp.mean, jnp.array(value)\n                            ).item()\n                        else:\n                            epoch_eval_metrics[key] = metrics[key].item()\n\n                    for key, value in epoch_eval_metrics.items():\n                        print(\"Validation %s : %.5f\" % (key, value))\n\n                    val_metric = epoch_eval_metrics[config.val_metric_name]\n\n                    assert val_metric is not None\n\n                    if best_val_metric is None:\n                        best_val_metric = val_metric\n                    elif config.minimize_val_metric and val_metric <= best_val_metric:\n                        best_val_metric = val_metric\n                    elif not config.minimize_val_metric and val_metric >= best_val_metric:\n                        best_val_metric = val_metric\n\n                    for callback in callbacks:\n                        callback.post_val_loop(step, epoch, val_metric, best_val_metric)\n\n            if do_distributed:\n                train_metric = jax_utils.unreplicate(train_metric)\n\n            for key, value in train_metric.items():\n                print(\"Train %s : %.2f\" % (key, value))\n\n            for callback in callbacks:\n                callback.post_epoch(step, epoch)\n\n            end = time.time()\n            
train_time = (end - start) / 60\n\n            desc = f\"Epoch... ({epoch + 1}/{config.train_epochs} | Time taken (mins): {train_time})\"\n            epochs.write(desc)\n            epochs.desc = desc\n\n        for callback in callbacks:\n            callback.post_train_loop(step, epoch)\n\n        if do_distributed:\n            state = jax_utils.unreplicate(state)\n        return state\n\n    def save_checkpoint(self, dir: Path, target: PyTree, step: int, keep_checkpoints: int):\n        return checkpoints.save_checkpoint(\n            dir, target, step, prefix=\"checkpoint_\", keep=keep_checkpoints, overwrite=True\n        )\n\n    def load_checkpoint(self, dir: Path, target: PyTree):\n        return checkpoints.restore_checkpoint(dir, target, prefix=\"checkpoint_\")\n\n    def _construct_optimizer(self, optimizer):\n        self.optimizer = optimizer.construct()\n        return self.optimizer\n\n    def _construct_lr_scheduler(self, scheduler):\n        self.lr_scheduler = scheduler.construct()\n        return self.lr_scheduler\n\n    def _get_devices(self) -> List[Any]:\n        device_type = jax.default_backend()\n        self.devices = jax.devices()\n        device_count = len(self.devices)\n        print(\"Training on %d %s\" % (device_count, device_type))\n        return self.devices\n"
  },
  {
    "path": "tango/integrations/flax/train_callback.py",
    "content": "import logging\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.registrable import Registrable\nfrom tango.workspace import Workspace\n\nfrom .data import DataLoader\nfrom .model import Model\nfrom .optim import Optimizer\nfrom .train_config import TrainConfig\n\n\nclass TrainCallback(Registrable):\n    \"\"\"\n    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class\n    that can be used within :class:`FlaxTrainStep` to customize behavior in the training\n    loop. You can set the training callbacks with the ``callbacks`` parameter to :class:`FlaxTrainStep`.\n\n    .. tip::\n        All of the parameters to this base class will be automatically set within\n        the training loop, so you shouldn't include them in your config for your callbacks.\n\n    .. tip::\n        You can access the model being trained through :attr:`self.model <model>`.\n\n    .. important::\n        The ``step`` argument to callback methods is the total/overall number of training steps\n        so far, independent of the current epoch.\n\n    .. 
seealso::\n        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example\n        implementation.\n\n    :ivar Workspace workspace: The tango workspace being used.\n    :ivar TrainConfig train_config: The training config.\n    :ivar tango.common.DatasetDictBase dataset: The dataset dict containing train and\n        optional validation splits.\n    :ivar DataLoader train_dataloader: The dataloader used for the training split.\n    :ivar Model model: The flax model being trained.\n    :ivar Optimizer optimizer: The optimizer being used for training.\n    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.\n    \"\"\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        train_config: TrainConfig,\n        dataset: DatasetDictBase,\n        train_dataloader: DataLoader,\n        model: Model,\n        optimizer: Optimizer,\n        validation_dataloader: Optional[DataLoader] = None,\n    ) -> None:\n        self.workspace = workspace\n        self.train_config = train_config\n        self.dataset = dataset\n        self.train_dataloader = train_dataloader\n        self.model = model\n        self.optimizer = optimizer\n        self.validation_dataloader = validation_dataloader\n        self.logger = logging.getLogger(self.__class__.__name__)\n\n    @property\n    def step_id(self) -> str:\n        \"\"\"\n        The unique ID of the current :class:`~tango.Step`.\n        \"\"\"\n        return self.train_config.step_id\n\n    @property\n    def step_name(self) -> Optional[str]:\n        \"\"\"\n        The name of the current :class:`~tango.Step`.\n        \"\"\"\n        return self.train_config.step_name\n\n    @property\n    def work_dir(self) -> Path:\n        \"\"\"\n        The working directory of the current train step.\n        \"\"\"\n        return self.train_config.work_dir\n\n    def state_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Return any state that needs 
to be kept after a restart.\n\n        Some callbacks need to maintain state across restarts. This is the callback's opportunity to\n        save its state. It will be restored using :meth:`load_state_dict`.\n        \"\"\"\n        return {}\n\n    def load_state_dict(self, state_dict: Dict[str, Any]):\n        \"\"\"\n        Load the state on a restart.\n\n        Some callbacks need to maintain state across restarts. This is the callback's opportunity to\n        restore its state. It gets saved using :meth:`state_dict`.\n        \"\"\"\n        pass\n\n    def pre_train_loop(self) -> None:\n        \"\"\"\n        Called right before the first batch is processed, or after a restart.\n        \"\"\"\n        pass\n\n    def post_train_loop(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called after the training loop completes.\n\n        This is the last method that is called, so any cleanup can be done in this method.\n        \"\"\"\n        pass\n\n    def pre_epoch(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called before the start of an epoch. Epochs start at 0.\n        \"\"\"\n        pass\n\n    def post_epoch(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called after an epoch is completed. Epochs start at 0.\n        \"\"\"\n        pass\n\n    def pre_batch(self, step: int, epoch: int, batch) -> None:\n        \"\"\"\n        Called directly before processing a batch.\n        \"\"\"\n\n    def post_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:\n        \"\"\"\n        Called directly after processing a batch, once the gradients have been\n        applied to the train state.\n\n        .. note::\n            The ``train_metrics`` here is the dictionary with train metrics of the\n            current batch. 
If doing distributed training, use ``jax_utils.unreplicate(train_metrics)``\n            before using ``train_metrics``.\n\n            If you need the average loss, use :meth:`log_batch()`.\n        \"\"\"\n        pass\n\n    def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:\n        \"\"\"\n        Called after the optimizer step. Here ``train_metrics`` holds the metrics averaged across\n        all distributed workers. If doing distributed training, use\n        ``jax_utils.unreplicate(train_metrics)`` before using ``train_metrics``.\n\n        .. note::\n            This callback method is not necessarily called on every step.\n            The frequency depends on the value of the ``log_every`` parameter of\n            :class:`FlaxTrainStep`.\n\n        \"\"\"\n        pass\n\n    def pre_val_loop(self, step: int, val_step: int, state) -> None:\n        \"\"\"\n        Called right before the validation loop starts.\n        \"\"\"\n        pass\n\n    def pre_val_batch(self, step: int, val_step: int, epoch: int, val_batch) -> None:\n        \"\"\"\n        Called right before a validation batch is processed.\n        \"\"\"\n        pass\n\n    def post_val_batch(self, step: int, val_step: int, epoch: int, val_metrics: Dict) -> None:\n        \"\"\"\n        Called right after a validation batch is processed with the outputs of the batch.\n\n        .. tip::\n            This method can be used to modify ``val_metrics`` in place, which is useful\n            in scenarios like distributed training where you might need to aggregate metrics\n            in a special way other than a simple average. 
If that's the case, make sure\n            to set ``auto_aggregate_val_metric`` to ``False`` in :class:`FlaxTrainStep`.\n\n        \"\"\"\n        pass\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]\n    ) -> None:\n        \"\"\"\n        Called right after the evaluation loop finishes\n        \"\"\"\n        pass\n"
  },
  {
    "path": "tango/integrations/flax/train_config.py",
    "content": "from dataclasses import asdict, dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional\n\n\n@dataclass\nclass TrainConfig:\n    \"\"\"\n    Encapsulates the parameters of :class:`FlaxTrainStep`. This is used to pass all the training\n    options to :class:`TrainCallback`.\n    \"\"\"\n\n    step_id: str\n    \"\"\"\n    The unique ID of the current step.\n    \"\"\"\n\n    work_dir: Path\n    \"\"\"\n    The working directory for the training run.\n    \"\"\"\n\n    step_name: Optional[str] = None\n    \"\"\"\n    The name of the current step.\n    \"\"\"\n\n    train_split: str = \"train\"\n    \"\"\"\n    The name of the training split.\n    \"\"\"\n\n    validation_split: Optional[str] = None\n    \"\"\"\n    The name of the validation split.\n    \"\"\"\n\n    seed: int = 42\n    \"\"\"\n    The random seed used to generate the pseudo-random number generator (PRNG) key.\n    \"\"\"\n\n    train_steps: Optional[int] = None\n    \"\"\"\n    The number of steps to train for.\n    \"\"\"\n\n    train_epochs: Optional[int] = None\n    \"\"\"\n    The number of epochs to train for.\n\n    You cannot specify `train_steps` and `train_epochs` at the same time.\n    \"\"\"\n\n    validation_steps: Optional[int] = None\n    \"\"\"\n    The number of validation steps.\n    \"\"\"\n\n    log_every: int = 10\n    \"\"\"\n    Controls the frequency of log updates.\n    \"\"\"\n\n    checkpoint_every: int = 100\n    \"\"\"\n    Controls the frequency of checkpoints.\n    \"\"\"\n\n    validate_every: Optional[int] = None\n    \"\"\"\n    Controls the frequency of the validation loop.\n    \"\"\"\n\n    is_distributed: bool = False\n    \"\"\"\n    Whether or not the training job is distributed.\n    \"\"\"\n\n    val_metric_name: str = \"loss\"\n    \"\"\"\n    The name of the validation metric to track.\n    \"\"\"\n\n    minimize_val_metric: bool = True\n    \"\"\"\n    Should be ``True`` when the validation metric being tracked should be minimized.\n    \"\"\"\n\n    
auto_aggregate_val_metric: bool = True\n    \"\"\"\n    Controls automatic aggregation of validation metric.\n    \"\"\"\n\n    remove_stale_checkpoints: bool = True\n    \"\"\"\n    Controls removal of stale checkpoints.\n    \"\"\"\n\n    @property\n    def state_path(self) -> Path:\n        \"\"\"\n        The path to the latest state checkpoint file.\n        \"\"\"\n        return self.work_dir / \"checkpoint_state_latest\"\n\n    @property\n    def best_state_path(self) -> Path:\n        \"\"\"\n        The path to the best state checkpoint file according to the validation metric or training\n        loss (if no validation split is given).\n        \"\"\"\n        return self.work_dir / \"checkpoint_state_best\"\n\n    def should_log_this_step(self, step: int) -> bool:\n        assert self.train_steps is not None\n        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1\n\n    def should_checkpoint_this_step(self, step: int) -> bool:\n        assert self.train_steps is not None\n        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1\n\n    def should_log_this_val_step(self, val_step: int) -> bool:\n        assert self.validation_steps is not None\n        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1\n\n    def as_dict(self) -> Dict[str, Any]:\n        return {k: v for k, v in asdict(self).items() if not k.startswith(\"_\")}\n"
  },
  {
    "path": "tango/integrations/flax/util.py",
    "content": "from typing import Any, Union\n\nimport jax\n\n\ndef get_PRNGkey(seed: int = 42) -> Union[Any, jax._src.random.KeyArray]:\n    \"\"\"\n    Utility function to create a pseudo-random number generator key\n    given a seed.\n    \"\"\"\n    return jax.random.PRNGKey(seed)\n\n\ndef get_multiple_keys(key, multiple: int = 1) -> Union[Any, jax._src.random.KeyArray]:\n    \"\"\"\n    Utility function to split a PRNG key into multiple new keys.\n    Used in distributed training.\n    \"\"\"\n    return jax.random.split(key, multiple)\n"
  },
  {
    "path": "tango/integrations/flax/wrapper.py",
    "content": "from abc import abstractmethod\nfrom typing import Dict\n\nfrom tango.common.registrable import Registrable\n\n\nclass FlaxWrapper(Registrable):\n    \"\"\"\n    A wrapper class which contains functions that need to be defined by the user\n    for using the ``flax::train`` and ``flax::eval`` steps.\n    \"\"\"\n\n    def train_metrics(self, state, batch, labels) -> Dict:\n        \"\"\"\n        Returns the train metrics other than loss as a Dict.\n        \"\"\"\n        # return empty dict if no other metrics to compute\n        return {}\n\n    @abstractmethod\n    def train_loss(self, params, state, batch, dropout_rng, labels):\n        \"\"\"\n        This function performs the forward pass and computes the loss. The function\n        should return the loss for the batch as a jax device array. The gradient\n        of this function is used for training.\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def val_metrics(self, batch, logits, labels) -> Dict:\n        \"\"\"\n        Returns the validation metrics as a Dict.\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def eval_metrics(self, batch, logits, labels) -> Dict:\n        \"\"\"\n        Returns the evaluation metrics as a Dict.\n        \"\"\"\n        raise NotImplementedError()\n"
  },
  {
    "path": "tango/integrations/gs/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"gs\" extra\n    (e.g. ``pip install tango[gs]``) or just install the `gcsfs <https://gcsfs.readthedocs.io/>`_\n    library after the fact (e.g. ``pip install gcsfs``).\n\nComponents for Tango integration with `GS <https://cloud.google.com/storage/>`_.\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    from google.cloud import datastore, storage\nexcept (ModuleNotFoundError, ImportError):\n    raise IntegrationMissingError(\"gs\", dependencies={\"google-cloud-storage\"})\n\nfrom .step_cache import GSStepCache\nfrom .workspace import GSWorkspace\n\n__all__ = [\n    \"GSStepCache\",\n    \"GSWorkspace\",\n]\n"
  },
  {
    "path": "tango/integrations/gs/common.py",
    "content": "\"\"\"\nClasses and utility functions for GSWorkspace and GSStepCache.\n\"\"\"\nimport atexit\nimport datetime\nimport json\nimport logging\nimport os\nimport time\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import List, Optional, Tuple, Union\n\nimport google.auth\nfrom google.api_core import exceptions\nfrom google.auth.credentials import Credentials\nfrom google.cloud import storage\nfrom google.oauth2.credentials import Credentials as OAuth2Credentials\nfrom google.oauth2.service_account import Credentials as ServiceAccountCredentials\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.exceptions import TangoError\nfrom tango.common.remote_utils import RemoteConstants\nfrom tango.step import Step\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_bucket_and_prefix(folder_name: str) -> Tuple[str, str]:\n    \"\"\"\n    Split bucket name and subfolder name, if present.\n    \"\"\"\n    split = folder_name.split(\"/\")\n    return split[0], \"/\".join(split[1:])\n\n\ndef empty_bucket_folder(folder_name: str):\n    \"\"\"\n    Removes all the tango-related blobs from the specified bucket folder.\n    Used for testing.\n    \"\"\"\n    credentials, project = google.auth.default()\n    client = storage.Client(project=project, credentials=credentials)\n    bucket_name, prefix = get_bucket_and_prefix(folder_name)\n\n    prefix = prefix + \"/tango-\" if prefix else \"tango-\"\n\n    bucket = client.bucket(bucket_name)\n    try:\n        bucket.delete_blobs(list(bucket.list_blobs(prefix=prefix)))\n    except exceptions.NotFound:\n        pass\n\n\ndef empty_datastore(folder_name: str):\n    \"\"\"\n    Removes all the tango-related entities from the specified namespace subfolder in datastore.\n    Used for testing.\n    \"\"\"\n    from google.cloud import datastore\n\n    credentials, project = google.auth.default()\n    namespace, prefix = 
get_bucket_and_prefix(folder_name)\n\n    run_kind = prefix + \"/run\" if prefix else \"run\"\n    stepinfo_kind = prefix + \"/stepinfo\" if prefix else \"stepinfo\"\n\n    client = datastore.Client(project=project, credentials=credentials, namespace=namespace)\n    run_query = client.query(kind=run_kind)\n    run_query.keys_only()\n    keys = [entity.key for entity in run_query.fetch()]\n    stepinfo_query = client.query(kind=stepinfo_kind)\n    stepinfo_query.keys_only()\n    keys += [entity.key for entity in stepinfo_query.fetch()]\n    client.delete_multi(keys)\n\n\n@dataclass\nclass GSArtifact:\n    \"\"\"\n    A GSArtifact object is used for representing storage objects in google cloud storage.\n    \"\"\"\n\n    name: str\n    \"\"\"\n    Name of the artifact.\n    \"\"\"\n    artifact_path: str\n    \"\"\"\n    Remote location url for the artifact.\n    \"\"\"\n    created: datetime.datetime\n    \"\"\"\n    Time of creation.\n    \"\"\"\n    committed: bool\n    \"\"\"\n    If set to True, no further changes to the remote artifact are allowed.\n    If set to False, it means that the artifact is under construction.\n    \"\"\"\n\n\nclass GSArtifactConflict(TangoError):\n    \"\"\"\n    Error denoting that the storage artifact already exists.\n    \"\"\"\n\n    pass\n\n\nclass GSArtifactNotFound(TangoError):\n    \"\"\"\n    Error denoting that the storage artifact does not exist.\n    \"\"\"\n\n    pass\n\n\nclass GSArtifactWriteError(TangoError):\n    \"\"\"\n    Error denoting that there was an issue writing the artifact to the remote storage.\n    \"\"\"\n\n    pass\n\n\ndef join_path(*args) -> str:\n    \"\"\"\n    We use this since we cannot use `os.path.join` for cloud storage paths.\n    \"\"\"\n    return \"/\".join(args).strip(\"/\")\n\n\nclass GSClient:\n    \"\"\"\n    A client for interacting with Google Cloud Storage. 
The authorization works by\n    providing OAuth2 credentials.\n\n    :param folder_name: The name of the Google Cloud bucket folder to use.\n    :param credentials: OAuth2 credentials can be provided. If not provided, default\n        gcloud credentials are inferred.\n    :param project: Optionally, the project ID can be provided. This is not essential\n        for `google.cloud.storage` API, since buckets are at the account level, rather\n        than the project level.\n    \"\"\"\n\n    placeholder_file = \".placeholder\"\n    \"\"\"\n    The placeholder file is used for creation of a folder in the cloud bucket folder,\n    as empty folders are not allowed. It is also used as a marker for the creation\n    time of the folder, hence we use a separate file to mark the artifact as\n    uncommitted.\n    \"\"\"\n\n    uncommitted_file = \".uncommitted\"\n    \"\"\"\n    The uncommitted file is used to denote an artifact under construction.\n    \"\"\"\n\n    settings_file = \"settings.json\"\n    \"\"\"\n    This file is for storing metadata like version information, etc.\n    \"\"\"\n\n    NUM_CONCURRENT_WORKERS: int = 9\n\n    def __init__(\n        self,\n        folder_name: str,\n        credentials: Optional[Credentials] = None,\n        project: Optional[str] = None,\n    ):\n        if not credentials:\n            credentials, project = google.auth.default()\n\n        self.storage = storage.Client(project=project, credentials=credentials)\n        self.folder_name = folder_name\n\n        self.bucket_name, self.prefix = get_bucket_and_prefix(folder_name)\n        settings_file = self._gs_path(self.settings_file)\n\n        blob = self.storage.bucket(self.bucket_name).blob(settings_file)  # no HTTP request yet\n        try:\n            with blob.open(\"r\") as file_ref:\n                json.load(file_ref)\n        except exceptions.NotFound:\n            settings = {\"version\": 1}\n            with blob.open(\"w\") as file_ref:\n                
json.dump(settings, file_ref)\n\n    def url(self, artifact: Optional[str] = None):\n        \"\"\"\n        Returns the remote url of the storage artifact.\n        \"\"\"\n        path = f\"gs://{self.folder_name}\"\n        if artifact is not None:\n            path = f\"{path}/{artifact}\"\n        return path\n\n    def _convert_blobs_to_artifact(self, blobs: List[storage.Blob]) -> GSArtifact:\n        \"\"\"\n        Converts a list of `google.cloud.storage.Blob` to a `GSArtifact`.\n        \"\"\"\n        name: str\n        artifact_path: str\n        created: datetime.datetime\n        committed: bool = True\n\n        for blob in blobs:\n            if blob.name.endswith(self.placeholder_file):\n                created = blob.time_created\n                name = blob.name.replace(\"/\" + self.placeholder_file, \"\")\n                if self.prefix:\n                    name = name.replace(self.prefix + \"/\", \"\")\n                artifact_path = name  # does not contain bucket info here.\n            elif blob.name.endswith(self.uncommitted_file):\n                committed = False\n\n        assert name is not None, \"Folder is not a GSArtifact, should not have happened.\"\n        return GSArtifact(name, artifact_path, created, committed)\n\n    @classmethod\n    def from_env(cls, folder_name: str):\n        \"\"\"\n        Constructs the client object from the environment, using default credentials.\n        \"\"\"\n        return cls(folder_name)\n\n    def get(self, artifact: Union[str, GSArtifact]) -> GSArtifact:\n        \"\"\"\n        Returns a `GSArtifact` object created by fetching the artifact's information\n        from remote location.\n        \"\"\"\n        if isinstance(artifact, str):\n            path = artifact\n        else:\n            # We have an artifact, and we recreate it with refreshed info.\n            path = artifact.artifact_path\n\n        prefix = self._gs_path(path)\n        blobs = 
list(self.storage.bucket(self.bucket_name).list_blobs(prefix=prefix))\n        if len(blobs) > 0:\n            return self._convert_blobs_to_artifact(blobs)\n        else:\n            raise GSArtifactNotFound()\n\n    def _gs_path(self, *args):\n        \"\"\"\n        Returns path within google cloud storage bucket.\n        \"\"\"\n        return join_path(self.prefix, *args)\n\n    def create(self, artifact: str):\n        \"\"\"\n        Creates a new artifact in the remote location. By default, it is uncommitted.\n        \"\"\"\n        bucket = self.storage.bucket(self.bucket_name)\n        # gives refreshed information\n\n        artifact_path = self._gs_path(artifact, self.placeholder_file)\n        if bucket.blob(artifact_path).exists():\n            raise GSArtifactConflict(f\"{artifact} already exists!\")\n        else:\n            # Additional safety check\n            if bucket.blob(self._gs_path(artifact, self.uncommitted_file)).exists():\n                raise GSArtifactConflict(f\"{artifact} already exists!\")\n            bucket.blob(self._gs_path(artifact, self.placeholder_file)).upload_from_string(\"\")\n            bucket.blob(self._gs_path(artifact, self.uncommitted_file)).upload_from_string(\"\")\n        return self._convert_blobs_to_artifact(\n            list(bucket.list_blobs(prefix=self._gs_path(artifact)))\n        )\n\n    def delete(self, artifact: GSArtifact):\n        \"\"\"\n        Removes the artifact from the remote location.\n        \"\"\"\n        bucket = self.storage.bucket(self.bucket_name)\n        prefix = self._gs_path(artifact.artifact_path)\n        blobs = list(bucket.list_blobs(prefix=prefix))\n        bucket.delete_blobs(blobs)\n\n    def upload(self, artifact: Union[str, GSArtifact], objects_dir: Path):\n        \"\"\"\n        Writes the contents of objects_dir to the remote artifact location.\n        \"\"\"\n        if isinstance(artifact, str):\n            folder_path = artifact\n        else:\n            
folder_path = artifact.artifact_path\n\n        source_path = str(objects_dir)\n\n        def _sync_blob(source_file_path: str, target_file_path: str):\n            blob = self.storage.bucket(self.bucket_name).blob(self._gs_path(target_file_path))\n            blob.upload_from_filename(source_file_path)\n\n        import concurrent.futures\n\n        try:\n            # TODO: google-cloud-storage==2.7.0 has added a preview feature called `transfer_manager`\n            # which allows for concurrent uploads and downloads. We should upgrade to this once\n            # it is more robustly supported. Also update in `download()`.\n            if os.path.isdir(source_path):\n                with concurrent.futures.ThreadPoolExecutor(\n                    max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix=\"GSClient.upload()-\"\n                ) as executor:\n                    upload_futures = []\n                    for dirpath, _, filenames in os.walk(source_path):\n                        for filename in filenames:\n                            source_file_path = os.path.join(dirpath, filename)\n                            target_file_path = join_path(\n                                folder_path, source_file_path.replace(source_path + \"/\", \"\")\n                            )\n                            # Submit one upload per file, not one per directory.\n                            upload_futures.append(\n                                executor.submit(_sync_blob, source_file_path, target_file_path)\n                            )\n                    for future in concurrent.futures.as_completed(upload_futures):\n                        future.result()\n            else:\n                source_file_path = source_path\n                target_file_path = join_path(folder_path, os.path.basename(source_file_path))\n                _sync_blob(source_file_path, target_file_path)\n        except Exception:\n            raise GSArtifactWriteError()\n\n    def commit(self, artifact: Union[str, GSArtifact]):\n        \"\"\"\n        Marks the artifact 
as committed. No further changes to the artifact are allowed.\n        \"\"\"\n        if isinstance(artifact, str):\n            folder_path = artifact\n        else:\n            folder_path = artifact.artifact_path\n        bucket = self.storage.bucket(self.bucket_name)\n        try:\n            bucket.delete_blob(self._gs_path(folder_path, self.uncommitted_file))\n        except exceptions.NotFound:\n            if not bucket.blob(self._gs_path(folder_path, self.placeholder_file)).exists():\n                raise GSArtifactNotFound()\n            # Otherwise, already committed. No change.\n\n    def download(self, artifact: GSArtifact, target_dir: PathOrStr):\n        \"\"\"\n        Writes the contents of the remote artifact to the `target_dir`.\n        \"\"\"\n        assert (\n            self.storage.bucket(self.bucket_name)\n            .blob(self._gs_path(artifact.artifact_path, self.placeholder_file))\n            .exists()\n        )\n\n        def _fetch_blob(blob: storage.Blob):\n            source_path = blob.name.replace(artifact.artifact_path + \"/\", \"\")\n            target_path = os.path.join(target_dir, source_path)\n            if not os.path.exists(os.path.dirname(target_path)):\n                os.mkdir(os.path.dirname(target_path))\n            blob.download_to_filename(target_path)\n\n        import concurrent.futures\n\n        bucket = self.storage.bucket(self.bucket_name)\n        # We may not need updates that frequently, with list_blobs(prefix).\n        # bucket.update()\n\n        try:\n            with concurrent.futures.ThreadPoolExecutor(\n                max_workers=self.NUM_CONCURRENT_WORKERS, thread_name_prefix=\"GSClient.download()-\"\n            ) as executor:\n                download_futures = []\n                prefix = self._gs_path(artifact.artifact_path)\n                for blob in bucket.list_blobs(prefix=prefix):\n                    download_futures.append(executor.submit(_fetch_blob, blob))\n                
for future in concurrent.futures.as_completed(download_futures):\n                    future.result()\n        except exceptions.NotFound:\n            raise GSArtifactWriteError()\n\n    def artifacts(self, prefix: str, uncommitted: bool = True) -> List[GSArtifact]:\n        \"\"\"\n        Lists all the artifacts within the remote storage, based on the\n        `prefix` and `uncommitted` criteria. These can include steps and runs.\n        \"\"\"\n        list_of_artifacts = []\n        prefix = self._gs_path(prefix)\n        for folder_name in self.storage.list_blobs(\n            self.bucket_name, prefix=prefix, delimiter=\"/\"\n        )._get_next_page_response()[\"prefixes\"]:\n            artifact = self._convert_blobs_to_artifact(\n                list(self.storage.list_blobs(self.bucket_name, prefix=folder_name))\n            )\n            if not uncommitted:\n                if not artifact.committed:\n                    continue\n            list_of_artifacts.append(artifact)\n        return list_of_artifacts\n\n\ndef get_credentials(credentials: Optional[Union[str, Credentials]] = None) -> Credentials:\n    \"\"\"\n    :param credentials:\n        * if OAuth2 credentials are provided, they are returned.\n        * if `str`, it can be either a file path or a json string of credentials dict.\n        * if `None`, credentials are inferred from the environment.\n\n    More details on Google Cloud credentials can be found here:\n    https://googleapis.dev/python/google-auth/latest/user-guide.html#service-account-private-key-files,\n    and https://googleapis.dev/python/google-api-core/latest/auth.html\n    \"\"\"\n\n    # BeakerExecutor uses GOOGLE_TOKEN\n    credentials = os.environ.get(\"GOOGLE_TOKEN\", credentials)\n    if credentials is not None:\n        # Path to the credentials file has been provided\n        if isinstance(credentials, str) and credentials.endswith(\".json\"):\n            with open(credentials) as file_ref:\n                credentials 
= file_ref.read()\n        try:\n            # If credentials dict has been passed as a json string\n            credentials_dict = json.loads(credentials)\n            if credentials_dict.pop(\"type\", None) == \"service_account\":\n                credentials = ServiceAccountCredentials.from_service_account_info(credentials_dict)\n            else:\n                # sometimes the credentials dict may not contain `token` and `token_uri` keys,\n                # but `Credentials()` needs the parameter.\n                token = credentials_dict.pop(\"token\", None)\n                token_uri = credentials_dict.pop(\"token_uri\", \"https://oauth2.googleapis.com/token\")\n                credentials = OAuth2Credentials(\n                    token=token, token_uri=token_uri, **credentials_dict\n                )\n        except (json.decoder.JSONDecodeError, TypeError, ValueError):\n            # It is not a json string.\n            # We use this string because BeakerExecutor cannot write a None secret.\n            if credentials == \"default\":\n                credentials = None\n    if not credentials:\n        # Infer default credentials\n        credentials, _ = google.auth.default()\n    return credentials\n\n\ndef get_client(\n    folder_name: str,\n    credentials: Optional[Union[str, Credentials]] = None,\n    project: Optional[str] = None,\n) -> GSClient:\n    \"\"\"\n    Returns a `GSClient` object for a google cloud bucket folder.\n    \"\"\"\n    credentials = get_credentials(credentials)\n    return GSClient(folder_name, credentials=credentials, project=project)\n\n\nclass Constants(RemoteConstants):\n    pass\n\n\nclass GCSStepLock:\n    \"\"\"\n    Google Cloud offers consistency https://cloud.google.com/storage/docs/consistency,\n    so we can use lock files.\n    \"\"\"\n\n    def __init__(\n        self,\n        client: GSClient,\n        step: Union[str, StepInfo, Step],\n    ):\n        self._client = client\n        self._step_id = step if 
isinstance(step, str) else step.unique_id\n        self._lock_artifact_name = RemoteConstants.step_lock_artifact_name(step)\n        self._lock_artifact: Optional[GSArtifact] = None\n        self.lock_artifact_url = self._client.url(self._lock_artifact_name)\n\n    def acquire(self, timeout=None, poll_interval: float = 2.0, log_interval: float = 30.0) -> None:\n        if self._lock_artifact is not None:\n            return\n        start = time.monotonic()\n        last_logged = None\n        while timeout is None or (time.monotonic() - start < timeout):\n            try:\n                self._lock_artifact = self._client.create(self._lock_artifact_name)\n                atexit.register(self.release)\n\n            except GSArtifactConflict:\n                # Log on the first conflict, then at most once every `log_interval` seconds.\n                if last_logged is None or time.monotonic() - last_logged >= log_interval:\n                    logger.warning(\n                        \"Waiting to acquire lock artifact for step '%s':\\n\\n%s\\n\\n\"\n                        \"This probably means the step is being run elsewhere, but if you're sure it isn't \"\n                        \"you can just delete the lock artifact, using the command: \\n`gsutil rm -r %s`\",\n                        self._step_id,\n                        self.lock_artifact_url,\n                        self.lock_artifact_url,\n                    )\n                    last_logged = time.monotonic()\n                time.sleep(poll_interval)\n                continue\n            else:\n                break\n        else:\n            raise TimeoutError(\n                f\"Timeout error occurred while waiting to acquire artifact lock for step '{self._step_id}':\\n\\n\"\n                f\"{self.lock_artifact_url}\\n\\n\"\n                f\"This probably means the step is being run elsewhere, but if you're sure it isn't you can \"\n                f\"just delete the lock, using the command: \\n`gsutil rm -r {self.lock_artifact_url}`\"\n            )\n\n    def release(self):\n        if 
self._lock_artifact is not None:\n            try:\n                self._client.delete(self._lock_artifact)\n            except GSArtifactNotFound:\n                # Artifact must have been manually deleted.\n                pass\n            self._lock_artifact = None\n            atexit.unregister(self.release)\n\n    def __del__(self):\n        self.release()\n"
  },
  {
    "path": "tango/integrations/gs/step_cache.py",
    "content": "import logging\nfrom pathlib import Path\nfrom typing import Optional, Union\n\nfrom tango.common import PathOrStr\nfrom tango.common.util import make_safe_filename, tango_cache_dir\nfrom tango.integrations.gs.common import (\n    Constants,\n    GSArtifact,\n    GSArtifactConflict,\n    GSArtifactNotFound,\n    GSArtifactWriteError,\n    GSClient,\n    get_bucket_and_prefix,\n)\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\n@StepCache.register(\"gs\")\nclass GSStepCache(RemoteStepCache):\n    \"\"\"\n    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`GSWorkspace`.\n    It stores the results of steps on Google cloud buckets as blobs.\n\n    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a\n    step's result on subsequent calls should be fast.\n\n    .. 
tip::\n        Registered as a :class:`~tango.step_cache.StepCache` under the name \"gs\".\n\n    :param folder_name: The name of the google cloud bucket folder to use.\n    :param client: The google cloud storage client to use.\n    \"\"\"\n\n    Constants = Constants\n\n    def __init__(self, folder_name: str, client: Optional[GSClient] = None):\n        if client is not None:\n            bucket_name, _ = get_bucket_and_prefix(folder_name)\n            assert (\n                bucket_name == client.bucket_name\n            ), \"Assert that bucket name is same as client bucket until we do better\"\n            self.folder_name = folder_name\n            self._client = client\n        else:\n            self._client = GSClient(folder_name)\n        super().__init__(tango_cache_dir() / \"gs_cache\" / make_safe_filename(folder_name))\n\n    @property\n    def client(self):\n        return self._client\n\n    def _step_result_remote(self, step: Union[Step, StepInfo]) -> Optional[GSArtifact]:\n        \"\"\"\n        Returns a `GSArtifact` object containing the details of the step.\n        This only returns if the step has been finalized (committed).\n        \"\"\"\n        try:\n            artifact = self.client.get(self.Constants.step_artifact_name(step))\n            return artifact if artifact.committed else None\n        except GSArtifactNotFound:\n            return None\n\n    def _upload_step_remote(self, step: Step, objects_dir: Path) -> GSArtifact:\n        \"\"\"\n        Uploads the step's output to remote location.\n        \"\"\"\n        artifact_name = self.Constants.step_artifact_name(step)\n        try:\n            self.client.create(artifact_name)\n        except GSArtifactConflict:\n            pass\n        try:\n            self.client.upload(artifact_name, objects_dir)\n            self.client.commit(artifact_name)\n        except GSArtifactWriteError:\n            pass\n\n        return self.client.get(artifact_name)\n\n    def 
_download_step_remote(self, step_result, target_dir: PathOrStr) -> None:\n        \"\"\"\n        Downloads the step's output from remote location.\n        \"\"\"\n        try:\n            self.client.download(step_result, target_dir)\n        except GSArtifactNotFound:\n            raise RemoteNotFoundError()\n\n    def __len__(self):\n        \"\"\"\n        Returns the number of committed step outputs present in the remote location.\n        \"\"\"\n        # NOTE: lock files should not count here.\n        return sum(\n            1\n            for ds in self.client.artifacts(\n                prefix=self.Constants.STEP_ARTIFACT_PREFIX, uncommitted=False\n            )\n            if ds.name is not None\n            and ds.name.startswith(self.Constants.STEP_ARTIFACT_PREFIX)\n            and not ds.name.endswith(self.Constants.LOCK_ARTIFACT_SUFFIX)\n        )\n"
  },
  {
    "path": "tango/integrations/gs/workspace.py",
    "content": "import json\nimport random\nfrom pathlib import Path\nfrom typing import (\n    Dict,\n    Generator,\n    Iterable,\n    List,\n    Optional,\n    Tuple,\n    TypeVar,\n    Union,\n    cast,\n)\nfrom urllib.parse import ParseResult\n\nimport petname\nfrom google.auth.credentials import Credentials\nfrom google.cloud import datastore\n\nfrom tango.common.util import utc_now_datetime\nfrom tango.integrations.gs.common import (\n    Constants,\n    GCSStepLock,\n    get_bucket_and_prefix,\n    get_client,\n    get_credentials,\n)\nfrom tango.integrations.gs.step_cache import GSStepCache\nfrom tango.step import Step\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, RunInfo, RunSort, StepInfoSort, Workspace\nfrom tango.workspaces.remote_workspace import RemoteWorkspace\n\nT = TypeVar(\"T\")\n\n\n@Workspace.register(\"gs\")\nclass GSWorkspace(RemoteWorkspace):\n    \"\"\"\n    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on Google Cloud Storage.\n\n    .. tip::\n        Registered as a :class:`~tango.workspace.Workspace` under the name \"gs\".\n\n    :param workspace: The name or ID of the Google Cloud bucket folder to use.\n    :param project: The Google project ID. This is required for the datastore. If not provided,\n        it will be inferred from the Google cloud credentials.\n\n    .. important::\n        Credentials can be provided in the following ways:\n\n        - Using the `credentials` keyword argument:\n            - You can specify the path to the credentials json file.\n            - You can specify the `google.oauth2.credentials.Credentials()` object.\n            - You can specify the json string of credentials dict.\n\n        - Using the default credentials: You can use your default google cloud credentials by running\n          `gcloud auth application-default login`. 
If you are using `GSWorkspace` with\n          :class:`~tango.integrations.beaker.BeakerExecutor`, you will need to set the environment variable\n          `GOOGLE_TOKEN` to the credentials json file. The default location is usually\n          `~/.config/gcloud/application_default_credentials.json`.\n\n    \"\"\"\n\n    Constants = Constants\n    NUM_CONCURRENT_WORKERS = 32\n\n    def __init__(\n        self,\n        workspace: str,\n        project: Optional[str] = None,\n        credentials: Optional[Union[str, Credentials]] = None,\n    ):\n        credentials = get_credentials(credentials)\n        self.client = get_client(folder_name=workspace, credentials=credentials, project=project)\n\n        self.client.NUM_CONCURRENT_WORKERS = self.NUM_CONCURRENT_WORKERS\n        self._cache = GSStepCache(workspace, client=self.client)\n        self._locks: Dict[Step, GCSStepLock] = {}\n\n        super().__init__()\n\n        project = project or self.client.storage.project or credentials.quota_project_id\n\n        self.bucket_name, self.prefix = get_bucket_and_prefix(workspace)\n        self._ds = datastore.Client(\n            namespace=self.bucket_name, project=project, credentials=credentials\n        )\n\n    @property\n    def cache(self):\n        return self._cache\n\n    @property\n    def locks(self):\n        return self._locks\n\n    @property\n    def steps_dir_name(self):\n        return \"gs_workspace\"\n\n    @classmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:\n        workspace: str\n        if parsed_url.netloc and parsed_url.path:\n            # e.g. \"gs://ai2/my-workspace\"\n            workspace = parsed_url.netloc + parsed_url.path\n        elif parsed_url.netloc:\n            # e.g. 
\"gs://my-workspace\"\n            workspace = parsed_url.netloc\n        else:\n            raise ValueError(f\"Bad URL for GS workspace '{parsed_url}'\")\n        return cls(workspace)\n\n    @property\n    def url(self) -> str:\n        return self.client.url()\n\n    def _remote_lock(self, step: Step) -> GCSStepLock:\n        return GCSStepLock(self.client, step)\n\n    def _step_location(self, step: Step) -> str:\n        return self.client.url(self.Constants.step_artifact_name(step))\n\n    @property\n    def _run_key(self):\n        return self.client._gs_path(\"run\")\n\n    @property\n    def _stepinfo_key(self):\n        return self.client._gs_path(\"stepinfo\")\n\n    def _save_run(\n        self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None\n    ) -> Run:\n        if name is None:\n            while True:\n                name = petname.generate() + str(random.randint(0, 100))\n                if not self._ds.get(self._ds.key(self._run_key, name)):\n                    break\n        else:\n            if self._ds.get(self._ds.key(self._run_key, name)):\n                raise ValueError(f\"Run name '{name}' is already in use\")\n\n        run_entity = self._ds.entity(\n            key=self._ds.key(self._run_key, name), exclude_from_indexes=(\"steps\",)\n        )\n        # Even though the run's name is part of the key, we add this as a\n        # field so we can index on it and order asc/desc (indices on the key field don't allow ordering).\n        run_entity[\"name\"] = name\n        run_entity[\"start_date\"] = utc_now_datetime()\n        run_entity[\"steps\"] = json.dumps(run_data).encode()\n        self._ds.put(run_entity)\n\n        return Run(name=cast(str, name), steps=steps, start_date=run_entity[\"start_date\"])\n\n    def _get_run_from_entity(self, run_entity: datastore.Entity) -> Optional[Run]:\n        try:\n            steps_info_bytes = run_entity[\"steps\"]\n            steps_info = 
json.loads(steps_info_bytes)\n        except KeyError:\n            return None\n\n        import concurrent.futures\n\n        steps: Dict[str, StepInfo] = {}\n        with concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.NUM_CONCURRENT_WORKERS,\n            thread_name_prefix=\"GSWorkspace._get_run_from_dataset()-\",\n        ) as executor:\n            step_info_futures = []\n            for unique_id in steps_info.values():\n                step_info_futures.append(executor.submit(self.step_info, unique_id))\n            for future in concurrent.futures.as_completed(step_info_futures):\n                step_info = future.result()\n                assert step_info.step_name is not None\n                steps[step_info.step_name] = step_info\n\n        return Run(name=run_entity.key.name, start_date=run_entity[\"start_date\"], steps=steps)\n\n    def registered_runs(self) -> Dict[str, Run]:\n        import concurrent.futures\n\n        runs: Dict[str, Run] = {}\n\n        with concurrent.futures.ThreadPoolExecutor(\n            max_workers=self.NUM_CONCURRENT_WORKERS,\n            thread_name_prefix=\"GSWorkspace.registered_runs()-\",\n        ) as executor:\n            run_futures = []\n            for run_entity in self._ds.query(kind=self._run_key).fetch():\n                run_futures.append(executor.submit(self._get_run_from_entity, run_entity))\n            for future in concurrent.futures.as_completed(run_futures):\n                run = future.result()\n                if run is not None:\n                    runs[run.name] = run\n\n        return runs\n\n    def search_registered_runs(\n        self,\n        *,\n        sort_by: Optional[RunSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[RunInfo]:\n        run_entities = self._fetch_run_entities(\n            sort_by=sort_by, sort_descending=sort_descending, 
match=match, start=start, stop=stop\n        )\n        return [\n            RunInfo(name=e.key.name, start_date=e[\"start_date\"], steps=json.loads(e[\"steps\"]))\n            for e in run_entities\n        ]\n\n    def num_registered_runs(self, *, match: Optional[str] = None) -> int:\n        count = 0\n        for _ in self._fetch_run_entities(match=match):\n            count += 1\n        return count\n\n    def _fetch_run_entities(\n        self,\n        *,\n        sort_by: Optional[RunSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> Generator[datastore.Entity, None, None]:\n        from itertools import islice\n\n        # Note: we can't query or order by multiple fields without a suitable\n        # composite index. So in that case we have to apply remaining filters\n        # or slice and order locally. We'll default to using 'match' in the query.\n        # But if 'match' is null we can sort with the query.\n        sort_locally = bool(match)\n\n        sort_field: Optional[str] = None\n        if sort_by == RunSort.START_DATE:\n            sort_field = \"start_date\"\n        elif sort_by == RunSort.NAME:\n            sort_field = \"name\"\n        elif sort_by is not None:\n            raise NotImplementedError(sort_by)\n\n        order: List[str] = []\n        if sort_field is not None and not sort_locally:\n            order = [sort_field if not sort_descending else f\"-{sort_field}\"]\n\n        query = self._ds.query(kind=self._run_key, order=order)\n        if match:\n            # HACK: Datastore has no direct string matching functionality,\n            # but this comparison is equivalent to checking if 'name' starts with 'match'.\n            query.add_filter(\"name\", \">=\", match)\n            query.add_filter(\"name\", \"<=\", match[:-1] + chr(ord(match[-1]) + 1))\n\n        entity_iter: Iterable[datastore.Entity] = 
query.fetch(\n            offset=0 if sort_locally else start,\n            limit=None if (stop is None or sort_locally) else stop - start,\n        )\n\n        if sort_field is not None and sort_locally:\n            entity_iter = sorted(\n                entity_iter, key=lambda entity: entity[sort_field], reverse=sort_descending\n            )\n\n        if sort_locally:\n            entity_iter = islice(entity_iter, start, stop)\n\n        for entity in entity_iter:\n            yield entity\n\n    def search_step_info(\n        self,\n        *,\n        sort_by: Optional[StepInfoSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        state: Optional[StepState] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[StepInfo]:\n        step_info_entities = self._fetch_step_info_entities(\n            sort_by=sort_by,\n            sort_descending=sort_descending,\n            match=match,\n            state=state,\n            start=start,\n            stop=stop,\n        )\n        return [\n            StepInfo.from_json_dict(json.loads(e[\"step_info_dict\"])) for e in step_info_entities\n        ]\n\n    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:\n        count = 0\n        for _ in self._fetch_step_info_entities(match=match, state=state):\n            count += 1\n        return count\n\n    def _fetch_step_info_entities(\n        self,\n        *,\n        sort_by: Optional[StepInfoSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        state: Optional[StepState] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> Generator[datastore.Entity, None, None]:\n        from itertools import islice\n\n        # Note: we can't query or order by multiple fields without a suitable\n        # composite index. 
So in that case we have to apply remaining filters\n        # or slice and order locally. We'll default to using 'match' in the query.\n        # But if 'match' is null, we'll use 'state' to filter in the query.\n        # If 'state' is also null, we can sort with the query.\n        sort_locally = sort_by is not None and (match is not None or state is not None)\n        filter_locally = state is not None and match is not None\n        slice_locally = sort_locally or filter_locally\n\n        sort_field: Optional[str] = None\n        if sort_by == StepInfoSort.START_TIME:\n            sort_field = \"start_time\"\n        elif sort_by == StepInfoSort.UNIQUE_ID:\n            sort_field = \"step_id\"\n        elif sort_by is not None:\n            raise NotImplementedError(sort_by)\n\n        order: List[str] = []\n        if sort_field is not None and not sort_locally:\n            order = [sort_field if not sort_descending else f\"-{sort_field}\"]\n\n        query = self._ds.query(kind=self._stepinfo_key, order=order)\n\n        if match is not None:\n            # HACK: Datastore has no direct string matching functionality,\n            # but this comparison is equivalent to checking if 'step_id' starts with 'match'.\n            query.add_filter(\"step_id\", \">=\", match)\n            query.add_filter(\"step_id\", \"<=\", match[:-1] + chr(ord(match[-1]) + 1))\n        elif state is not None and not filter_locally:\n            query.add_filter(\"state\", \"=\", str(state.value))\n\n        entity_iter: Iterable[datastore.Entity] = query.fetch(\n            offset=0 if slice_locally else start,\n            limit=None if (stop is None or slice_locally) else stop - start,\n        )\n\n        if state is not None and filter_locally:\n            entity_iter = filter(lambda entity: entity[\"state\"] == state, entity_iter)\n\n        if sort_field is not None and sort_locally:\n            entity_iter = sorted(\n                entity_iter, key=lambda entity: 
entity[sort_field], reverse=sort_descending\n            )\n\n        if slice_locally:\n            entity_iter = islice(entity_iter, start, stop)\n\n        for entity in entity_iter:\n            yield entity\n\n    def registered_run(self, name: str) -> Run:\n        err_msg = f\"Run '{name}' not found in workspace\"\n\n        run_entity = self._ds.get(key=self._ds.key(self._run_key, name))\n        if not run_entity:\n            raise KeyError(err_msg)\n\n        run = self._get_run_from_entity(run_entity)\n        if run is None:\n            raise KeyError(err_msg)\n        else:\n            return run\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        unique_id = (\n            step_or_unique_id if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id\n        )\n        step_info_entity = self._ds.get(key=self._ds.key(self._stepinfo_key, unique_id))\n        if step_info_entity is not None:\n            step_info_bytes = step_info_entity[\"step_info_dict\"]\n            step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))\n            return step_info\n        else:\n            if not isinstance(step_or_unique_id, Step):\n                raise KeyError(step_or_unique_id)\n            step_info = StepInfo.new_from_step(step_or_unique_id)\n            self._update_step_info(step_info)\n            return step_info\n\n    def _step_info_multiple(\n        self, step_or_unique_ids: Union[List[Step], List[str]]\n    ) -> List[StepInfo]:\n        \"\"\"\n        This method is to combine all calls to the datastore api in a single transaction.\n        \"\"\"\n        all_unique_id_keys = []\n        for step_or_unique_id in step_or_unique_ids:\n            unique_id = (\n                step_or_unique_id\n                if isinstance(step_or_unique_id, str)\n                else step_or_unique_id.unique_id\n            )\n            key = self._ds.key(self._stepinfo_key, unique_id)\n            
all_unique_id_keys.append(key)\n\n        missing: List = []\n        step_info_entities = self._ds.get_multi(keys=all_unique_id_keys, missing=missing)\n        missing_steps = [entity.key.name for entity in missing]\n\n        step_infos = []\n        for step_info_entity in step_info_entities:\n            step_info_bytes = step_info_entity[\"step_info_dict\"]\n            step_info = StepInfo.from_json_dict(json.loads(step_info_bytes))\n            step_infos.append(step_info)\n\n        for step_or_unique_id in step_or_unique_ids:\n            step_id = (\n                step_or_unique_id\n                if isinstance(step_or_unique_id, str)\n                else step_or_unique_id.unique_id\n            )\n            if step_id in missing_steps:\n                if not isinstance(step_or_unique_id, Step):\n                    raise KeyError(step_or_unique_id)\n                step_info = StepInfo.new_from_step(step_or_unique_id)\n                self._update_step_info(step_info)\n                step_infos.append(step_info)\n        return step_infos\n\n    def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]:\n        all_steps = set(targets)\n        for step in targets:\n            all_steps |= step.recursive_dependencies\n\n        steps: Dict[str, StepInfo] = {}\n        run_data: Dict[str, str] = {}\n\n        all_valid_steps = [step for step in all_steps if step.name is not None]\n        step_infos = self._step_info_multiple(all_valid_steps)\n\n        for step_info in step_infos:\n            assert step_info.step_name is not None\n            steps[step_info.step_name] = step_info\n            run_data[step_info.step_name] = step_info.unique_id\n\n        return steps, run_data\n\n    def _update_step_info(self, step_info: StepInfo):\n        step_info_entity = self._ds.entity(\n            key=self._ds.key(self._stepinfo_key, step_info.unique_id),\n            exclude_from_indexes=(\"step_info_dict\",),\n        )\n\n        
# Even though the step's unique ID is part of the key, we add this as a\n        # field so we can index on it and order asc/desc (indices on the key field don't allow ordering).\n        step_info_entity[\"step_id\"] = step_info.unique_id\n        step_info_entity[\"step_name\"] = step_info.step_name\n        step_info_entity[\"start_time\"] = step_info.start_time\n        step_info_entity[\"end_time\"] = step_info.end_time\n        step_info_entity[\"state\"] = str(step_info.state.value)\n        step_info_entity[\"updated\"] = utc_now_datetime()\n        step_info_entity[\"step_info_dict\"] = json.dumps(step_info.to_json_dict()).encode()\n\n        self._ds.put(step_info_entity)\n\n    def _remove_step_info(self, step_info: StepInfo) -> None:\n        # remove dir from bucket\n        step_artifact = self.client.get(self.Constants.step_artifact_name(step_info))\n        if step_artifact is not None:\n            self.client.delete(step_artifact)\n\n        # remove datastore entities (use the same kind the entity was stored under)\n        self._ds.delete(key=self._ds.key(self._stepinfo_key, step_info.unique_id))\n\n    def _save_run_log(self, name: str, log_file: Path):\n        \"\"\"\n        The logs are stored in the bucket. The Run object details are stored in\n        the remote database.\n        \"\"\"\n        run_dataset = self.Constants.run_artifact_name(name)\n        self.client.upload(run_dataset, log_file)\n"
  },
  {
    "path": "tango/integrations/torch/__init__.py",
    "content": "# -*- coding: UTF-8 -*-\n\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"torch\" extra\n    (e.g. ``pip install tango[torch]``) or just install PyTorch after the fact.\n\n    Make sure you install the correct version of torch given your operating system\n    and supported CUDA version. Check\n    `pytorch.org/get-started/locally/ <https://pytorch.org/get-started/locally/>`_\n    for more details.\n\nComponents for Tango integration with `PyTorch <https://pytorch.org/>`_.\n\nThese include a training loop :class:`~tango.step.Step` and registrable versions\nof many ``torch`` classes, such as :class:`torch.optim.Optimizer` and :class:`torch.utils.data.DataLoader`.\n\nExample: training a model\n-------------------------\n\nLet's look at a simple example of training a model.\n\nWe'll make a basic regression model and generate some fake data to train on.\nFirst, the setup:\n\n.. testcode::\n\n    import torch\n    import torch.nn as nn\n\n    from tango.common.dataset_dict import DatasetDict\n    from tango.step import Step\n    from tango.integrations.torch import Model\n\nNow let's build and register our model:\n\n.. testcode::\n\n    @Model.register(\"basic_regression\")\n    class BasicRegression(Model):\n        def __init__(self):\n            super().__init__()\n            self.linear = nn.Linear(10, 1)\n            self.sigmoid = nn.Sigmoid()\n            self.mse = nn.MSELoss()\n\n        def forward(self, x, y=None):\n            pred = self.sigmoid(self.linear(x))\n            out = {\"pred\": pred}\n            if y is not None:\n                out[\"loss\"] = self.mse(pred, y)\n            return out\n\n        def _to_params(self):\n            return {}\n\nLastly, we'll need a step to generate data:\n\n.. 
testcode::\n\n    @Step.register(\"generate_data\")\n    class GenerateData(Step):\n        DETERMINISTIC = True\n        CACHEABLE = False\n\n        def run(self) -> DatasetDict:\n            torch.manual_seed(1)\n            return DatasetDict(\n                {\n                    \"train\": [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(64)],\n                    \"validation\": [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(32)],\n                }\n            )\n\nYou could then run this experiment with a config that looks like this:\n\n.. literalinclude:: ../../../../test_fixtures/integrations/torch/train.jsonnet\n\n.. testcode::\n    :hide:\n\n    from tango.common.testing import run_experiment\n    from tango.common.registrable import Registrable\n\n    # Pickling the model fails because the class is defined ad hoc, not in a module.\n    # So we put in this hack to pickle a 0 instead of the Model.\n    def _return_zero(self):\n        return (int, (0,))\n    BasicRegression.__reduce__ = _return_zero\n\n    with run_experiment(\n        \"test_fixtures/integrations/torch/train.jsonnet\", name=\"boss-alien\"\n    ) as run_dir:\n        assert (run_dir / \"train\").is_dir(), \"Output for the 'train' step was not produced.\"\n    # Restore state of registry.\n    del Registrable._registry[Step][\"generate_data\"]\n    del Registrable._registry[Model][\"basic_regression\"]\n\nFor example,\n\n.. code-block::\n\n    tango run train.jsonnet -i my_package -d /tmp/train\n\nwould produce the following output:\n\n.. 
testoutput::\n    :options: +ELLIPSIS\n\n    Starting new run boss-alien\n    ● Starting step \"data\" (needed by \"train\")...\n    ✓ Finished step \"data\"\n    ● Starting step \"train\"...\n    ✓ Finished step \"train\"\n    ✓ Finished run boss-alien\n    ...\n\nTips\n----\n\nDebugging\n~~~~~~~~~\n\nWhen debugging a training loop that's causing errors on a GPU, you should set the environment variable\n``CUDA_LAUNCH_BLOCKING=1``. This will ensure that the stack trace shows where the error actually happened.\n\nYou could also use a custom :class:`TrainCallback` to log each batch before it is passed into the model\nso that you can see the exact inputs that are causing the issue.\n\nStopping early\n~~~~~~~~~~~~~~\n\nYou can stop the \"torch::train\" step early using a custom :class:`TrainCallback`. Your callback just\nneeds to raise the :class:`StopEarly` exception.\n\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    import torch\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"torch\")\n\n__all__ = [\n    \"TorchFormat\",\n    \"TorchTrainStep\",\n    \"TorchEvalStep\",\n    \"Optimizer\",\n    \"LRScheduler\",\n    \"Model\",\n    \"DataLoader\",\n    \"DataCollator\",\n    \"Sampler\",\n    \"ConcatTensorDictsCollator\",\n    \"TrainCallback\",\n    \"EvalCallback\",\n    \"TrainConfig\",\n    \"StopEarlyCallback\",\n    \"StopEarly\",\n    \"TrainingEngine\",\n    \"TorchTrainingEngine\",\n]\n\nfrom .data import ConcatTensorDictsCollator, DataCollator, DataLoader, Sampler\nfrom .eval import TorchEvalStep\nfrom .eval_callback import EvalCallback\nfrom .exceptions import StopEarly\nfrom .format import TorchFormat\nfrom .model import Model\nfrom .optim import LRScheduler, Optimizer\nfrom .train import TorchTrainStep\nfrom .train_callback import StopEarlyCallback, TrainCallback\nfrom .train_config import TrainConfig\nfrom .training_engine import TorchTrainingEngine, TrainingEngine\n"
  },
  {
    "path": "tango/integrations/torch/data.py",
    "content": "from typing import Any, Dict, Generic, List, Optional, TypeVar, Union\n\nimport torch\n\nfrom tango.common.lazy import Lazy\nfrom tango.common.registrable import Registrable\n\nT = TypeVar(\"T\")\n\n\nclass DataCollator(Generic[T], Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a ``collate_fn``\n    for a ``DataLoader``.\n\n    Subclasses just need to implement :meth:`__call__()`.\n    \"\"\"\n\n    default_implementation = \"concat_tensor_dicts\"\n    \"\"\"\n    The default implementation is :class:`ConcatTensorDictsCollator`.\n    \"\"\"\n\n    def __call__(self, items: List[T]) -> Dict[str, Any]:\n        \"\"\"\n        Takes a list of items from a dataset and combines them into a batch.\n        \"\"\"\n        raise NotImplementedError\n\n\n@DataCollator.register(\"concat_tensor_dicts\")\nclass ConcatTensorDictsCollator(DataCollator[Dict[str, Any]]):\n    \"\"\"\n    A simple ``collate_fn`` that expects items to be dictionaries of tensors.\n    The tensors are just concatenated together.\n\n    .. 
tip::\n\n        Registered as a :class:`DataCollator` under the name \"concat_tensor_dicts\".\n    \"\"\"\n\n    def __call__(self, items: List[Dict[str, Any]]) -> Dict[str, Any]:\n        out = {}\n        keys = items[0].keys()\n        for key in keys:\n            if isinstance(items[0][key], torch.Tensor):\n                out[key] = torch.cat([item[key].unsqueeze(0) for item in items])\n            elif isinstance(items[0][key], (int, float)):\n                out[key] = torch.tensor([item[key] for item in items])\n            else:\n                out[key] = [item[key] for item in items]  # type: ignore[assignment]\n        return out\n\n\nclass Sampler(torch.utils.data.Sampler, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a PyTorch\n    :class:`~torch.utils.data.Sampler`.\n\n    All `built-in PyTorch samplers\n    <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`_\n    are registered under their corresponding class name (e.g. 
\"RandomSampler\").\n    \"\"\"\n\n\n@Sampler.register(\"torch::BatchSampler\")\nclass BatchSampler(torch.utils.data.BatchSampler, Sampler):\n    def __init__(\n        self,\n        dataset: torch.utils.data.Dataset,\n        sampler: Union[Lazy[Sampler], Sampler],\n        batch_size: int,\n        drop_last: bool,\n    ) -> None:\n        super().__init__(\n            sampler.construct(data_source=dataset, dataset=dataset)\n            if isinstance(sampler, Lazy)\n            else sampler,\n            batch_size,\n            drop_last,\n        )\n\n\n# Register all remaining samplers.\nfor name, cls in torch.utils.data.__dict__.items():\n    registered_name = \"torch::\" + name\n    if (\n        isinstance(cls, type)\n        and issubclass(cls, torch.utils.data.Sampler)\n        and not cls == torch.utils.data.Sampler\n        and registered_name not in Sampler.list_available()\n    ):\n        Sampler.register(registered_name)(cls)\n\n\nclass DataLoader(torch.utils.data.DataLoader, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a PyTorch\n    :class:`~torch.utils.data.DataLoader`.\n    \"\"\"\n\n    default_implementation = \"default\"\n\n    def __init__(\n        self,\n        dataset: torch.utils.data.Dataset,\n        collate_fn: Optional[DataCollator] = ConcatTensorDictsCollator(),\n        sampler: Optional[Union[Lazy[Sampler], Sampler]] = None,\n        **kwargs,\n    ):\n        super().__init__(\n            dataset,\n            collate_fn=collate_fn,\n            sampler=sampler.construct(data_source=dataset, dataset=dataset)\n            if isinstance(sampler, Lazy)\n            else sampler,\n            **kwargs,\n        )\n\n\nDataLoader.register(\"default\")(DataLoader)\n"
  },
  {
    "path": "tango/integrations/torch/eval.py",
    "content": "from collections import defaultdict\nfrom itertools import islice\nfrom typing import Dict, List, Optional, Sequence\n\nimport torch\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.lazy import Lazy\nfrom tango.common.tqdm import Tqdm\nfrom tango.format import Format, JsonFormat\nfrom tango.step import Step, StepResources\n\nfrom .data import DataLoader\nfrom .eval_callback import EvalCallback\nfrom .model import Model\nfrom .util import check_dataset, move_to_device, resolve_device, set_seed_all\n\n\n@Step.register(\"torch::eval\")\nclass TorchEvalStep(Step):\n    \"\"\"\n    A PyTorch evaluation loop that pairs well with :class:`TorchTrainStep`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"torch::eval\".\n\n    .. important::\n\n        The evaluation loop will use a GPU automatically if one is available.\n        You can control which GPU it uses with the environment variable ``CUDA_VISIBLE_DEVICES``.\n        For example, set ``CUDA_VISIBLE_DEVICES=1`` to force ``TorchEvalStep`` to only use\n        the GPU with ID 1.\n\n    .. 
warning::\n\n        By default the metrics specified by the ``metric_names`` parameter\n        are aggregated by simply averaging across batches.\n        This behavior is usually correct for metrics like \"loss\" or \"accuracy\",\n        for example, but may not be correct for other metrics like \"F1\".\n\n        If this is not correct for your metric you will need to handle the aggregation\n        internally in your model or with an :class:`EvalCallback`\n        using the :meth:`EvalCallback.post_batch()` method.\n        Then set the parameter ``auto_aggregate_metrics`` to ``False``.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n    SKIP_ID_ARGUMENTS = {\"log_every\"}\n\n    @property\n    def resources(self) -> StepResources:\n        return self.step_resources or StepResources(gpu_count=1)\n\n    def run(  # type: ignore[override]\n        self,\n        model: Model,\n        dataset_dict: DatasetDictBase,\n        dataloader: Lazy[DataLoader],\n        test_split: str = \"test\",\n        seed: int = 42,\n        eval_steps: Optional[int] = None,\n        log_every: int = 1,\n        metric_names: Sequence[str] = (\"loss\",),\n        auto_aggregate_metrics: bool = True,\n        callbacks: Optional[List[Lazy[EvalCallback]]] = None,\n    ) -> Dict[str, float]:\n        \"\"\"\n        Evaluate the ``model``.\n\n        :param model:\n            The model to evaluate. It should return a ``dict`` from its ``forward()`` method\n            that includes all of the metrics in ``metric_names`` .\n        :param dataset_dict:\n            Should contain the test data.\n        :param dataloader:\n            The data loader that generates test batches. 
The batches should be :class:`dict`\n            objects.\n        :param test_split:\n            The name of the data split used for evaluation in the ``dataset_dict``.\n            Default is \"test\".\n        :param seed:\n            Used to set the RNG states at the beginning of the evaluation loop.\n        :param eval_steps:\n            The number of steps to evaluate for. If not specified evaluation will\n            stop after a complete iteration through the ``dataloader``.\n        :param log_every:\n            Log every this many steps. Default is ``1``.\n        :param metric_names:\n            The names of the metrics to track and aggregate. Default is ``(\"loss\",)``.\n        :param auto_aggregate_metrics:\n            If ``True`` (the default), the metrics will be averaged across batches.\n            This may not be the correct behavior for some metrics (such as F1),\n            in which case you should set this to ``False`` and handle the aggregation\n            internally in your model or with an :class:`EvalCallback`\n            (using :meth:`EvalCallback.post_batch()`).\n        :param callbacks:\n            A list of :class:`EvalCallback`.\n\n        \"\"\"\n        set_seed_all(seed)\n\n        check_dataset(dataset_dict, test_split)\n\n        # Resolve device.\n        device = resolve_device()\n\n        # Prep model.\n        model = model.eval().to(device)\n\n        # Construct dataloader.\n        dataloader: DataLoader = dataloader.construct(dataset=dataset_dict[test_split])\n\n        steps: int\n        try:\n            dataloader_len = len(dataloader)\n            steps = dataloader_len if eval_steps is None else min(dataloader_len, eval_steps)\n        except TypeError:\n            if eval_steps is None:\n                raise ConfigurationError(\n                    \"You must set 'eval_steps' for streaming/iterable datasets\"\n                )\n            else:\n                steps = eval_steps\n\n        # Initialize 
callbacks.\n        callbacks: List[EvalCallback] = [\n            callback.construct(\n                workspace=self.workspace,\n                step_id=self.unique_id,\n                work_dir=self.work_dir,\n                model=model,\n                dataset_dict=dataset_dict,\n                dataloader=dataloader,\n            )\n            for callback in (callbacks or [])\n        ]\n        for callback in callbacks:\n            callback.pre_eval_loop()\n\n        eval_batches = enumerate(islice(dataloader, steps))\n\n        running_metrics: Dict[str, float] = defaultdict(float)\n        aggregated_metrics: Dict[str, float] = {}\n\n        with Tqdm.tqdm(eval_batches, desc=\"Evaluating\", total=steps) as batch_iter:\n            for step, batch in batch_iter:\n                should_log_this_step = step % log_every == 0 or step == steps - 1\n\n                for callback in callbacks:\n                    callback.pre_batch(step, batch)\n\n                batch = move_to_device(batch, device)\n                with torch.inference_mode():\n                    outputs = model(**batch)\n\n                for callback in callbacks:\n                    callback.post_batch(step, outputs)\n\n                # Gather metrics we want to track.\n                batch_metrics = {\n                    k: outputs[k].item() if isinstance(outputs[k], torch.Tensor) else outputs[k]\n                    for k in metric_names\n                }\n\n                # Aggregate metrics.\n                if auto_aggregate_metrics:\n                    for k in batch_metrics:\n                        running_metrics[k] += batch_metrics[k]\n                        aggregated_metrics[k] = running_metrics[k] / (step + 1)\n                else:\n                    aggregated_metrics.update(batch_metrics)\n\n                # Update progress bar.\n                if should_log_this_step:\n                    batch_iter.set_postfix(**aggregated_metrics)\n\n                # 
Clean up to help garbage collector. Hopefully this saves memory.\n                del batch\n                del outputs\n                del batch_metrics\n\n        for callback in callbacks:\n            callback.post_eval_loop(aggregated_metrics)\n\n        return aggregated_metrics\n"
  },
  {
    "path": "tango/integrations/torch/eval_callback.py",
    "content": "from pathlib import Path\nfrom typing import Any, Dict\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.registrable import Registrable\nfrom tango.workspace import Workspace\n\nfrom .data import DataLoader\nfrom .model import Model\n\n\nclass EvalCallback(Registrable):\n    \"\"\"\n    An ``EvalCallback`` is a :class:`~tango.common.Registrable` class that can be used\n    within :class:`TorchEvalStep` to customize the behavior of the evaluation loop,\n    similar to how :class:`TrainCallback` is used to customize the behavior of the training\n    loop.\n\n    .. tip::\n        All of the parameters to this base class will be automatically set within\n        the training loop, so you shouldn't include them in your config for your callbacks.\n\n    :ivar Workspace workspace: The tango workspace being used.\n    :ivar str step_id: The unique ID of the step.\n    :ivar pathlib.Path work_dir: The working directory of the step\n    :ivar Model model: The model being evaluated.\n    :ivar DatasetDictBase dataset_dict: The dataset dict containing the evaluation split.\n    :ivar DataLoader dataloader: The data loader used to load the evaluation split data.\n    \"\"\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        step_id: str,\n        work_dir: Path,\n        model: Model,\n        dataset_dict: DatasetDictBase,\n        dataloader: DataLoader,\n    ) -> None:\n        self.workspace = workspace\n        self.step_id = step_id\n        self.work_dir = work_dir\n        self.model = model\n        self.dataset_dict = dataset_dict\n        self.dataloader = dataloader\n\n    def pre_eval_loop(self) -> None:\n        \"\"\"\n        Called right before the first batch is processed.\n        \"\"\"\n        pass\n\n    def post_eval_loop(self, aggregated_metrics: Dict[str, float]) -> None:\n        \"\"\"\n        Called after the evaluation loop completes with the final aggregated metrics.\n\n        
This is the last method that is called, so any cleanup can be done in this method.\n        \"\"\"\n        pass\n\n    def pre_batch(self, step: int, batch: Dict[str, Any]) -> None:\n        \"\"\"\n        Called directly before processing a batch.\n        \"\"\"\n        pass\n\n    def post_batch(self, step: int, batch_outputs: Dict[str, Any]) -> None:\n        \"\"\"\n        Called directly after processing a batch with the outputs of the batch.\n\n        .. tip::\n            This method can be used to modify ``batch_outputs`` in place, which is useful\n            in scenarios where you might need to aggregate metrics\n            in a special way other than a simple average. If that's the case, make sure\n            to set ``auto_aggregate_metrics`` to ``False`` in :class:`TorchEvalStep`.\n\n        \"\"\"\n        pass\n"
  },
  {
    "path": "tango/integrations/torch/exceptions.py",
    "content": "from tango.common.exceptions import TangoError\n\n\nclass StopEarly(TangoError):\n    \"\"\"\n    Callbacks can raise this exception to stop training early without crashing.\n\n    .. important::\n        During distributed training all workers must raise this exception at the same point\n        in the training loop, otherwise there will be a deadlock.\n    \"\"\"\n"
  },
  {
    "path": "tango/integrations/torch/format.py",
    "content": "from pathlib import Path\nfrom typing import Generic, TypeVar\n\nimport dill\nimport torch\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.format import Format\n\nT = TypeVar(\"T\")\n\n\n@Format.register(\"torch\")\nclass TorchFormat(Format[T], Generic[T]):\n    \"\"\"\n    This format writes the artifact using ``torch.save()``.\n\n    Unlike :class:`tango.format.DillFormat`, this has no special support for iterators.\n\n    .. tip::\n\n        Registered as a :class:`~tango.format.Format` under the name \"torch\".\n\n    \"\"\"\n\n    VERSION = \"002\"\n\n    def write(self, artifact: T, dir: PathOrStr):\n        filename = Path(dir) / \"data.pt\"\n        with open(filename, \"wb\") as f:\n            torch.save((self.VERSION, artifact), f, pickle_module=dill)\n\n    def read(self, dir: PathOrStr) -> T:\n        filename = Path(dir) / \"data.pt\"\n        with open(filename, \"rb\") as f:\n            version, artifact = torch.load(f, pickle_module=dill, map_location=torch.device(\"cpu\"))\n            if version > self.VERSION:\n                raise ValueError(\n                    f\"File {filename} is too recent for this version of {self.__class__}.\"\n                )\n            return artifact\n"
  },
  {
    "path": "tango/integrations/torch/model.py",
    "content": "import torch\n\nfrom tango.common.registrable import Registrable\n\n\nclass Model(torch.nn.Module, Registrable):\n    \"\"\"\n    This is a :class:`~tango.common.Registrable` mixin class that inherits from\n    :class:`torch.nn.Module`.\n    Its :meth:`~torch.nn.Module.forward()` method should return a :class:`dict` that\n    includes the ``loss`` during training and any tracked metrics during validation.\n    \"\"\"\n"
  },
  {
    "path": "tango/integrations/torch/optim.py",
    "content": "from typing import Type\n\nimport torch\n\nfrom tango.common.registrable import Registrable\n\n\nclass Optimizer(torch.optim.Optimizer, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a PyTorch\n    :class:`torch.optim.Optimizer`.\n\n    All `built-in PyTorch optimizers\n    <https://pytorch.org/docs/stable/optim.html#algorithms>`_\n    are registered according to their class name (e.g. \"torch::Adam\").\n\n    .. tip::\n\n        You can see a list of all available optimizers by running\n\n        .. testcode::\n\n            from tango.integrations.torch import Optimizer\n            for name in sorted(Optimizer.list_available()):\n                print(name)\n\n        .. testoutput::\n            :options: +ELLIPSIS\n\n            torch::ASGD\n            torch::Adadelta\n            torch::Adagrad\n            torch::Adam\n            torch::AdamW\n            ...\n\n    \"\"\"\n\n\nclass LRScheduler(torch.optim.lr_scheduler._LRScheduler, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of a PyTorch learning\n    rate scheduler.\n\n    All `built-in PyTorch learning rate schedulers\n    <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_\n    are registered according to their class name (e.g. \"torch::StepLR\").\n\n    .. tip::\n\n        You can see a list of all available schedulers by running\n\n        .. testcode::\n\n            from tango.integrations.torch import LRScheduler\n            for name in sorted(LRScheduler.list_available()):\n                print(name)\n\n        .. 
testoutput::\n            :options: +ELLIPSIS\n\n            torch::ChainedScheduler\n            torch::ConstantLR\n            torch::CosineAnnealingLR\n            ...\n    \"\"\"\n\n\n# Register all optimizers.\nfor name, cls in torch.optim.__dict__.items():\n    if (\n        isinstance(cls, type)\n        and issubclass(cls, torch.optim.Optimizer)\n        and not cls == torch.optim.Optimizer\n    ):\n        Optimizer.register(\"torch::\" + name)(cls)\n\n# Note: This is a hack. Remove after we upgrade the torch version.\nbase_class: Type\ntry:\n    base_class = torch.optim.lr_scheduler.LRScheduler\nexcept AttributeError:\n    base_class = torch.optim.lr_scheduler._LRScheduler\n\n# Register all learning rate schedulers.\nfor name, cls in torch.optim.lr_scheduler.__dict__.items():\n    if isinstance(cls, type) and issubclass(cls, base_class) and not cls == base_class:\n        LRScheduler.register(\"torch::\" + name)(cls)\n"
  },
  {
    "path": "tango/integrations/torch/train.py",
    "content": "import logging\nimport math\nimport os\nimport shutil\nfrom itertools import islice\nfrom typing import Any, Dict, List, Optional, Set, Union, cast\n\nimport more_itertools\nimport torch\nimport torch.distributed as dist\nfrom more_itertools import chunked\nfrom torch.utils.data import DistributedSampler\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.lazy import Lazy\nfrom tango.common.tqdm import Tqdm\nfrom tango.common.util import get_extra_imported_modules, import_extra_module\nfrom tango.format import Format\nfrom tango.step import Step, StepResources\nfrom tango.workspace import Workspace\n\nfrom .data import DataLoader\nfrom .exceptions import StopEarly\nfrom .format import TorchFormat\nfrom .model import Model\nfrom .train_callback import TrainCallback\nfrom .train_config import TrainConfig\nfrom .training_engine import TrainingEngine\nfrom .util import check_dataloader, check_dataset, set_seed_all\n\n\n@Step.register(\"torch::train\")\nclass TorchTrainStep(Step):\n    \"\"\"\n    A PyTorch training loop step that supports gradient accumulation, distributed training,\n    and AMP, with configurable dataloaders, callbacks, optimizer, and LR scheduler.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"torch::train\".\n\n    .. important::\n\n        The training loop will use GPU(s) automatically when available, as long as at least\n        ``device_count`` CUDA devices are available.\n\n        Distributed data parallel training is activated when the ``device_count`` is greater than 1.\n\n        You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.\n        For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``\n        (and ``device_count`` to 2).\n\n    .. 
warning::\n\n        During validation, the validation metric (specified by the ``val_metric_name`` parameter)\n        is aggregated by simply averaging across validation batches and distributed processes.\n        This behavior is usually correct when your validation metric is \"loss\" or \"accuracy\",\n        for example, but may not be correct for other metrics like \"F1\".\n\n        If this is not correct for your metric you will need to handle the aggregation\n        internally in your model or with a :class:`TrainCallback`\n        using the :meth:`TrainCallback.post_val_batch()` method.\n        Then set the parameter ``auto_aggregate_val_metric`` to ``False``.\n\n        Note that correctly aggregating your metric during distributed training will\n        involve distributed communication.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = TorchFormat()\n    SKIP_ID_ARGUMENTS = {\"distributed_port\", \"log_every\"}\n    METADATA = {\"artifact_kind\": \"model\"}\n\n    @property\n    def resources(self) -> StepResources:\n        return self.step_resources or StepResources(gpu_count=self.kwargs[\"device_count\"])\n\n    def run(  # type: ignore[override]\n        self,\n        model: Union[Lazy[Model], Model],  # Lazy has to come first\n        training_engine: Lazy[TrainingEngine],\n        dataset_dict: DatasetDictBase,\n        train_dataloader: Lazy[DataLoader],\n        *,\n        train_split: str = \"train\",\n        validation_split: Optional[str] = None,\n        validation_dataloader: Optional[Lazy[DataLoader]] = None,\n        seed: int = 42,\n        train_steps: Optional[int] = None,\n        train_epochs: Optional[int] = None,\n        validation_steps: Optional[int] = None,\n        grad_accum: int = 1,\n        log_every: int = 10,\n        checkpoint_every: int = 100,\n        validate_every: Optional[int] = None,\n        device_count: int = 1,\n        distributed_port: int = 54761,\n        
val_metric_name: str = \"loss\",\n        minimize_val_metric: bool = True,\n        auto_aggregate_val_metric: bool = True,\n        callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n        remove_stale_checkpoints: bool = True,\n    ) -> Model:\n        \"\"\"\n        Run a basic training loop to train the ``model``.\n\n        :param model:\n            The model to train. It should return a ``dict`` that includes the ``loss``\n            during training and the ``val_metric_name`` during validation.\n        :param training_engine:\n            The :class:`TrainingEngine` to use to train the model.\n        :param dataset_dict:\n            The train and optional validation data.\n        :param train_dataloader:\n            The data loader that generates training batches. The batches should be :class:`dict`\n            objects that will be used as ``kwargs`` for the model's ``forward()`` method.\n        :param train_split:\n            The name of the data split used for training in the ``dataset_dict``.\n            Default is \"train\".\n        :param validation_split:\n            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,\n            which means no validation.\n        :param validation_dataloader:\n            An optional data loader for generating validation batches. The batches should be\n            :class:`dict` objects. If not specified, but ``validation_split`` is given,\n            the validation ``DataLoader`` will be constructed from the same parameters\n            as the train ``DataLoader``.\n        :param seed:\n            Used to set the RNG states at the beginning of training.\n        :param train_steps:\n            The number of steps to train for. If not specified training will\n            stop after a complete iteration through the ``train_dataloader``.\n        :param train_epochs:\n            The number of epochs to train for. 
You cannot specify ``train_steps`` and ``train_epochs``\n            at the same time.\n        :param validation_steps:\n            The number of steps to validate for. If not specified, validation\n            will stop after a complete iteration through the ``validation_dataloader``.\n        :param grad_accum:\n            The number of gradient accumulation steps. Defaults to 1.\n\n            .. note::\n                This parameter - in conjunction with the settings of your data loader\n                and the number of distributed workers -\n                determines the *effective batch size* of your training run.\n\n        :param log_every:\n            Log every this many steps.\n        :param checkpoint_every:\n            Save a checkpoint every this many steps.\n        :param validate_every:\n            Run the validation loop every this many steps.\n        :param device_count:\n            The number of devices to train on, i.e. the number of distributed data parallel workers.\n        :param distributed_port:\n            The port of the distributed process group. Default is ``54761``.\n        :param val_metric_name:\n            The name of the validation metric, i.e. the key of the metric in the dictionary\n            returned by the forward pass of the model. Default is \"loss\".\n        :param minimize_val_metric:\n            Whether the validation metric is meant to be minimized (such as the loss).\n            Default is ``True``. When using a metric such as accuracy, you should set\n            this to ``False``.\n        :param auto_aggregate_val_metric:\n            If ``True`` (the default), the validation metric will be averaged across\n            validation batches and distributed processes. 
This may not be the correct\n            behavior for some metrics (such as F1), in which you should set this to\n            ``False`` and handle the aggregation internally in your model\n            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).\n        :param callbacks:\n            A list of :class:`TrainCallback`.\n        :param remove_stale_checkpoints:\n            If ``True`` (the default), stale checkpoints will be removed throughout training so that\n            only the latest and best checkpoints are kept.\n\n        :returns:\n            The trained model on CPU with the weights from the best checkpoint loaded.\n\n        \"\"\"\n\n        devices = self._get_devices(device_count)\n\n        return self._train(\n            model=model,\n            training_engine=training_engine,\n            dataset_dict=dataset_dict,\n            train_dataloader=train_dataloader,\n            train_split=train_split,\n            validation_split=validation_split,\n            validation_dataloader=validation_dataloader,\n            seed=seed,\n            train_steps=train_steps,\n            train_epochs=train_epochs,\n            validation_steps=validation_steps,\n            grad_accum=grad_accum,\n            log_every=log_every,\n            checkpoint_every=checkpoint_every,\n            validate_every=validate_every,\n            devices=devices,\n            distributed_port=distributed_port,\n            val_metric_name=val_metric_name,\n            minimize_val_metric=minimize_val_metric,\n            auto_aggregate_val_metric=auto_aggregate_val_metric,\n            callbacks=callbacks,\n            remove_stale_checkpoints=remove_stale_checkpoints,\n        )\n\n    def _get_devices(self, device_count: int) -> List[int]:\n        \"\"\"\n        Validates the device count, and returns the list of devices.\n        \"\"\"\n        # Validate device(s).\n        if device_count <= 0:\n            raise 
ConfigurationError(\"Invalid value for 'device_count'. Must be at least 1.\")\n        devices: List[int]\n        if torch.cuda.is_available() and torch.cuda.device_count() >= device_count:\n            devices = list(range(device_count))\n            self.logger.info(\"Training on %d GPU%s\", device_count, \"s\" if device_count > 1 else \"\")\n        else:\n            devices = [-1] * device_count\n            self.logger.info(\n                \"Training on CPU with %d worker%s\", device_count, \"s\" if device_count > 1 else \"\"\n            )\n        return devices\n\n    def _train(\n        self,\n        model: Union[Model, Lazy[Model]],\n        training_engine: Lazy[TrainingEngine],\n        dataset_dict: DatasetDictBase,\n        train_dataloader: Lazy[DataLoader],\n        *,\n        train_split: str = \"train\",\n        validation_split: Optional[str] = None,\n        validation_dataloader: Optional[Lazy[DataLoader]] = None,\n        seed: int = 42,\n        train_steps: Optional[int] = None,\n        train_epochs: Optional[int] = None,\n        validation_steps: Optional[int] = None,\n        grad_accum: int = 1,\n        log_every: int = 10,\n        checkpoint_every: int = 100,\n        validate_every: Optional[int] = None,\n        devices: Optional[List[int]] = None,\n        distributed_port: int = 54761,\n        val_metric_name: str = \"loss\",\n        minimize_val_metric: bool = True,\n        auto_aggregate_val_metric: bool = True,\n        callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n        remove_stale_checkpoints: bool = True,\n    ) -> Model:\n        is_distributed = False\n        num_workers = 1\n        if devices and len(devices) > 1:\n            is_distributed = True\n            num_workers = len(devices)\n\n        if validate_every is not None and validation_split is None:\n            raise ConfigurationError(\n                \"You have set a validation interval, but no validation split. 
\"\n                \"That's probably unintentional.\"\n            )\n\n        if (train_steps is not None) == (train_epochs is not None):\n            raise ConfigurationError(\n                \"One of 'train_steps' or 'train_epochs' needs to be specified, but not both.\"\n            )\n\n        if validate_every is not None and checkpoint_every is not None:\n            if checkpoint_every % validate_every != 0 and validate_every % checkpoint_every != 0:\n                raise ConfigurationError(\n                    \"'checkpoint_every' needs to be multiple of 'validate_every' or vice versa\"\n                )\n\n        config = TrainConfig(\n            self.unique_id,\n            self.work_dir,\n            step_name=self.name,\n            train_split=train_split,\n            validation_split=validation_split,\n            seed=seed,\n            train_steps=train_steps,\n            train_epochs=train_epochs,\n            grad_accum=grad_accum,\n            log_every=log_every,\n            checkpoint_every=checkpoint_every,\n            validate_every=validate_every,\n            validation_steps=validation_steps,\n            is_distributed=is_distributed,\n            devices=devices,\n            distributed_port=distributed_port,\n            val_metric_name=val_metric_name,\n            minimize_val_metric=minimize_val_metric,\n            auto_aggregate_val_metric=auto_aggregate_val_metric,\n            remove_stale_checkpoints=remove_stale_checkpoints,\n            world_size=num_workers,\n        )\n\n        final_model: Model\n        if is_distributed:\n            import torch.multiprocessing as mp\n\n            mp.spawn(\n                _train,\n                args=(\n                    self.workspace,\n                    config,\n                    model,\n                    training_engine,\n                    dataset_dict,\n                    train_dataloader,\n                    validation_dataloader,\n                    
callbacks,\n                    get_extra_imported_modules(),\n                ),\n                nprocs=num_workers,\n            )\n            self.logger.info(\"Constructing final model\")\n            if isinstance(model, Lazy):\n                final_model = model.construct()\n            else:\n                final_model = model\n        else:\n            final_model = _train(  # type: ignore[assignment]\n                0,\n                self.workspace,\n                config,\n                model,\n                training_engine,\n                dataset_dict,\n                train_dataloader,\n                validation_dataloader=validation_dataloader,\n                callbacks=callbacks,\n            )\n            assert final_model is not None\n            final_model = final_model.cpu()\n\n        # Load best checkpoint before returning model.\n        if config.final_weights_path.is_file():\n            self.logger.info(\n                f\"Loading best weights from {str(config.final_weights_path.resolve())}\"\n            )\n            state = torch.load(config.final_weights_path, map_location=\"cpu\")\n            # We use `strict=False` because there might be missing keys due to weight tying.\n            final_model.load_state_dict(state, strict=False)\n\n        return final_model\n\n\ndef _train(\n    worker_id: int,\n    workspace: Workspace,\n    config: TrainConfig,\n    model: Union[Model, Lazy[Model]],\n    training_engine: Lazy[TrainingEngine],\n    dataset_dict: DatasetDictBase,\n    train_dataloader: Lazy[DataLoader],\n    validation_dataloader: Optional[Lazy[DataLoader]] = None,\n    callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n    include_package: Optional[Set[str]] = None,\n) -> Optional[Model]:\n    # Set random seeds.\n    set_seed_all(config.seed)\n\n    config.worker_id = worker_id\n\n    if config.is_distributed and include_package:\n        # During distributed training we need to import 
`include_package` modules again\n        # in order to initialize the lazy objects.\n        for package_name in include_package:\n            import_extra_module(package_name)\n\n    if config.is_distributed:\n        import tango.common.logging as common_logging\n\n        common_logging.initialize_worker_logging(config.worker_id)\n    logger = logging.getLogger(TorchTrainStep.__name__)\n\n    training_engine: TrainingEngine = training_engine.construct(\n        train_config=config,\n        model=model,\n    )\n\n    # Check working directory to see if we should recover from a previous run.\n    initial_state: Optional[Dict[str, Any]] = None\n    if config.state_path.exists():\n        if config.is_local_main_process:\n            logger.info(f\"Recovering from previous run at {str(config.state_path.resolve())}\")\n        initial_state = training_engine.load_checkpoint(config.state_path)\n    device = config.worker_local_default_device\n\n    # Construct data loaders.\n    validation_dataloader_: Optional[DataLoader] = None\n    if config.validation_split is not None:\n        validation_dataset = dataset_dict[config.validation_split]\n        check_dataset(validation_dataset, config.validation_split)\n        if validation_dataloader is not None:\n            validation_dataloader_ = validation_dataloader.construct(dataset=validation_dataset)\n        else:\n            validation_dataloader_ = train_dataloader.construct(dataset=validation_dataset)\n    validation_dataloader: Optional[DataLoader] = validation_dataloader_\n    train_dataset = dataset_dict[config.train_split]\n    check_dataset(train_dataset, config.train_split)\n    train_dataloader: DataLoader = train_dataloader.construct(dataset=train_dataset)\n\n    if config.train_steps is None:\n        assert config.train_epochs is not None\n        try:\n            steps_per_epoch = len(train_dataloader)\n        except TypeError:\n            raise ConfigurationError(\"You must set 'train_steps' for 
streaming/iterable datasets\")\n        config.train_steps = math.ceil(\n            steps_per_epoch * (config.train_epochs or 1) / config.grad_accum\n        )\n\n    assert config.train_steps is not None  # for mypy\n\n    if validation_dataloader is not None:\n        if config.validation_steps is None:\n            try:\n                config.validation_steps = len(validation_dataloader)\n            except TypeError:\n                raise ConfigurationError(\n                    \"You must set 'validation_steps' for streaming/iterable datasets\"\n                )\n\n    # Make sure we're using a DistributedSampler during distributed training.\n    if config.is_distributed:\n        check_dataloader(train_dataloader)\n        if validation_dataloader is not None:\n            check_dataloader(validation_dataloader)\n\n    # The (training) loss for each batch, updated every training batch.\n    batch_loss: float = 0.0\n    # The value of the validation metric (could be loss), updated after every validation loop.\n    val_metric: Optional[float] = None\n    # The best validation metric over all validation set passes.\n    best_val_metric: Optional[float] = None\n    # The best validation metric over all validation set passes that correspond to a checkpoint.\n    # Could be different from `best_val_metric` if `checkpoint_every` > `validate_every`.\n    best_val_metric_checkpointed: Optional[float] = None\n    # The step to start training from.\n    start_step: int = 0\n    # The current training step.\n    step: int = start_step\n    # If we should do a validation pass after the current training batch.\n    should_validate_this_step: bool = False\n\n    # Load state from checkpoint.\n    if initial_state is not None:\n        val_metric = initial_state[f\"val_{config.val_metric_name}\"]\n        best_val_metric = initial_state[f\"best_{config.val_metric_name}\"]\n        best_val_metric_checkpointed = 
initial_state[f\"best_{config.val_metric_name}_checkpointed\"]\n        start_step = initial_state[\"training_steps\"]\n\n    # Initialize callbacks.\n    callbacks: List[TrainCallback] = [\n        callback.construct(\n            workspace=workspace,\n            train_config=config,\n            training_engine=training_engine,\n            dataset_dict=dataset_dict,\n            train_dataloader=train_dataloader,\n            validation_dataloader=validation_dataloader,\n        )\n        for callback in (callbacks or [])\n    ]\n    if initial_state:\n        for callback, state in zip(callbacks, initial_state[\"callbacks\"]):\n            callback.load_state_dict(state)\n\n    del initial_state\n\n    training_engine.model.train()\n    training_batches = enumerate(\n        islice(\n            _cycle_through_epochs(train_dataloader, config.is_distributed, config.grad_accum),\n            config.train_steps,\n        )\n    )\n\n    def is_best_checkpoint() -> bool:\n        \"\"\"\n        A closure that we'll call when saving checkpoints to check if we should link\n        the best checkpoint path to the current checkpoint file.\n        \"\"\"\n        if val_metric is not None:\n            if best_val_metric_checkpointed is not None:\n                return (\n                    config.minimize_val_metric and val_metric <= best_val_metric_checkpointed\n                ) or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed)\n            else:\n                return False\n        else:\n            # Without a validation loop we always treat the most recent checkpoint as the best.\n            return True\n\n    def save_state(step: int):\n        \"\"\"\n        A closure that we'll call every `checkpoint_every` steps in the train loop to\n        save model and training state.\n        \"\"\"\n        # Update best loss/metric trackers.\n        nonlocal best_val_metric_checkpointed\n        if should_validate_this_step 
and val_metric is not None:\n            if (\n                best_val_metric_checkpointed is None\n                or (config.minimize_val_metric and val_metric <= best_val_metric_checkpointed)\n                or (not config.minimize_val_metric and val_metric >= best_val_metric_checkpointed)\n            ):\n                best_val_metric_checkpointed = val_metric\n\n        train_state = {\n            \"training_steps\": step + 1,\n            f\"val_{config.val_metric_name}\": val_metric,\n            f\"best_{config.val_metric_name}\": best_val_metric,\n            f\"best_{config.val_metric_name}_checkpointed\": best_val_metric_checkpointed,\n            \"callbacks\": [\n                callback.state_dict() for callback in callbacks  # type: ignore[union-attr]\n            ],\n        }\n\n        # For some reason mypy can't figure out that `training_engine` is a `TrainingEngine`\n        # in this closure, and not a `Lazy[TrainingEngine]`.\n        cast(TrainingEngine, training_engine).save_checkpoint(\n            config.state_path_for_step(step), train_state\n        )\n\n        # Link to most recent state path.\n        # NOTE: While hard linking would be preferable to creating symlinks, some train engines\n        # require a whole directory to save their state instead of a single file, which\n        # means state_path_for_step will be a directory, so a hard link won't work.\n        if config.is_local_main_process:\n            if config.state_path.is_symlink():\n                config.state_path.unlink()\n            config.state_path.symlink_to(\n                config.state_path_for_step(step).relative_to(config.work_dir)\n            )\n\n            # Link to best state path.\n            if is_best_checkpoint():\n                if config.best_state_path.is_symlink():\n                    config.best_state_path.unlink()\n                config.best_state_path.symlink_to(\n                    
config.state_path_for_step(step).relative_to(config.work_dir)\n                )\n\n            # Clean up stale checkpoints.\n            if config.remove_stale_checkpoints:\n                checkpoints_to_keep = {\n                    config.best_state_path.resolve(),\n                    config.state_path.resolve(),\n                }\n                for path in config.work_dir.glob(\"checkpoint_state_step*\"):\n                    path = path.resolve()\n                    if path not in checkpoints_to_keep:\n                        if path.is_file():\n                            path.unlink()\n                        else:\n                            shutil.rmtree(path)\n\n        if config.is_distributed:\n            dist.barrier()\n\n    # Catch data loader up to where we left off before.\n    current_epoch: int = -1\n    if start_step > 0:\n        with Tqdm.tqdm(\n            training_batches,\n            desc=f\"Catching dataloader up to step {start_step}\",\n            total=start_step - 1,\n            disable=not config.is_local_main_process,\n        ) as batch_iter:\n            for step, (current_epoch, batch) in batch_iter:\n                del batch\n                if step >= start_step - 1:\n                    break\n\n    if config.is_distributed:\n        dist.barrier()\n\n    for callback in callbacks:\n        callback.pre_train_loop()\n\n    train_batch_iterator_tqdm = Tqdm.tqdm(\n        training_batches,\n        desc=\"Training\",\n        initial=start_step,\n        total=config.train_steps,\n        disable=not config.is_local_main_process,\n    )\n    train_batch_iterator = more_itertools.peekable(train_batch_iterator_tqdm)\n    try:\n        for step, (epoch, batch) in train_batch_iterator:\n            if epoch != current_epoch:\n                # Start of new epoch.\n                if epoch > 0:\n                    # Call post-epoch callbacks for the last epoch.\n                    for callback in callbacks:\n             
           callback.post_epoch(step, current_epoch)\n                for callback in callbacks:\n                    callback.pre_epoch(step, epoch)\n                current_epoch = epoch\n\n            # Pre-batch callback.\n            for callback in callbacks:\n                callback.pre_batch(step, current_epoch, batch)\n            batch_loss = 0.0\n            batch_outputs = []\n            for micro_batch_idx, micro_batch in enumerate(batch):\n                # Get loss.\n                micro_batch_loss, micro_batch_outputs = training_engine.forward_train(\n                    micro_batch, micro_batch_idx, len(batch)\n                )\n                if torch.isnan(micro_batch_loss):\n                    raise ValueError(\"nan loss encountered\")\n                batch_loss += micro_batch_loss.detach().item()\n                batch_outputs.append(\n                    {\n                        key: output.detach() if isinstance(output, torch.Tensor) else output\n                        for key, output in micro_batch_outputs.items()\n                    }\n                )\n\n                # Calculate gradients.\n                training_engine.backward(micro_batch_loss)\n\n                # Clean up in case it saves memory.\n                del micro_batch\n                del micro_batch_loss\n                del micro_batch_outputs\n\n            # Post-batch callback.\n            for callback in callbacks:\n                callback.post_batch(step, current_epoch, batch_loss, batch_outputs)\n\n            del batch\n\n            training_engine.step()\n\n            # Find out whether we should validate\n            if config.validation_split is None:\n                # If we can't validate, we don't.\n                should_validate_this_step = False\n            elif step == config.train_steps - 1:\n                # If we're at the end of the training run, we always validate.\n                should_validate_this_step = True\n            
elif config.validate_every is not None and (step + 1) % config.validate_every == 0:\n                # If validate_every is given, we use that to decide.\n                should_validate_this_step = True\n            elif config.validate_every is None and epoch != train_batch_iterator.peek()[1][0]:\n                # If validate_every is not given, we validate at the end of the epoch.\n                should_validate_this_step = True\n            else:\n                # Otherwise, we don't validate.\n                should_validate_this_step = False\n\n            # Gather average loss across all workers.\n            if (\n                config.should_log_this_step(step) or should_validate_this_step\n            ) and config.is_distributed:\n                batch_loss_tensor = torch.tensor(batch_loss, device=device)\n                dist.all_reduce(batch_loss_tensor)\n                batch_loss = batch_loss_tensor.item() / config.world_size\n\n            if config.should_log_this_step(step):\n                # Callbacks.\n                for callback in callbacks:\n                    callback.log_batch(step, current_epoch, batch_loss, batch_outputs)\n\n                # Update progress bar.\n                metrics_to_log: Dict[str, float] = {\"batch_loss\": batch_loss}\n                if val_metric is not None:\n                    metrics_to_log[f\"val_{config.val_metric_name}\"] = val_metric\n                if best_val_metric is not None:\n                    metrics_to_log[f\"best_val_{config.val_metric_name}\"] = best_val_metric\n                if config.is_local_main_process:\n                    train_batch_iterator_tqdm.set_postfix(**metrics_to_log)\n\n            # Validate.\n            if should_validate_this_step:\n                assert validation_dataloader is not None\n                assert config.validation_steps is not None\n\n                # Prepare model for validation.\n                training_engine.model.eval()\n\n                
running_metric = 0.0\n                with Tqdm.tqdm(\n                    islice(validation_dataloader, config.validation_steps),\n                    desc=\"Validating\",\n                    total=config.validation_steps,\n                    leave=False,\n                    disable=not config.is_local_main_process,\n                ) as val_batch_iterator:\n                    for val_step, val_batch in enumerate(val_batch_iterator):\n                        for callback in callbacks:\n                            callback.pre_val_batch(step, val_step, current_epoch, val_batch)\n\n                        # Get metric.\n                        outputs = training_engine.forward_eval(val_batch)\n\n                        for callback in callbacks:\n                            callback.post_val_batch(step, val_step, current_epoch, outputs)\n                        metric = outputs[config.val_metric_name]\n\n                        if config.auto_aggregate_val_metric:\n                            running_metric += metric if isinstance(metric, float) else metric.item()\n                            val_metric = running_metric / (val_step + 1)\n                        else:\n                            val_metric = metric if isinstance(metric, float) else metric.item()\n\n                        # Average metric across all workers.\n                        if (\n                            config.is_distributed\n                            and config.should_log_this_val_step(val_step)\n                            and config.auto_aggregate_val_metric\n                        ):\n                            val_metric_tensor = torch.tensor(val_metric, device=device)\n                            dist.all_reduce(val_metric_tensor)\n                            val_metric = val_metric_tensor.item() / config.world_size\n\n                        # Update progress bar.\n                        if config.is_local_main_process and config.should_log_this_val_step(\n               
             val_step\n                        ):\n                            val_batch_iterator.set_postfix(**{config.val_metric_name: val_metric})\n\n                        # Clean up.\n                        del val_batch\n                        del outputs\n                        del metric\n\n                assert val_metric is not None\n\n                # Reset model to train mode.\n                training_engine.model.train()\n\n                if (\n                    best_val_metric is None\n                    or (config.minimize_val_metric and val_metric <= best_val_metric)\n                    or (not config.minimize_val_metric and val_metric >= best_val_metric)\n                ):\n                    best_val_metric = val_metric\n\n                # Checkpoint.\n                if config.should_checkpoint_this_step(step):\n                    save_state(step)\n\n                # Post validation callback.\n                for callback in callbacks:\n                    callback.post_val_loop(step, current_epoch, val_metric, best_val_metric)\n\n                # Reset model to train mode again in case the callbacks messed with it.\n                if callbacks:\n                    training_engine.model.train()\n\n                # Update progress bar again.\n                metrics_to_log = {\n                    \"batch_loss\": batch_loss,\n                    f\"val_{config.val_metric_name}\": val_metric,\n                    f\"best_{config.val_metric_name}\": best_val_metric,\n                }\n                if config.is_local_main_process:\n                    train_batch_iterator_tqdm.set_postfix(**metrics_to_log)\n            else:\n                # Checkpoint.\n                if config.should_checkpoint_this_step(step):\n                    save_state(step)\n\n        # End train loop\n\n        # Final post-epoch callback.\n        for callback in callbacks:\n            callback.post_epoch(step, current_epoch)\n    except 
StopEarly:\n        if config.is_local_main_process:\n            logger.info(\"Stopping early!\")\n    finally:\n        train_batch_iterator_tqdm.close()\n\n    if config.is_distributed:\n        dist.barrier()\n\n    # If we haven't saved a checkpoint yet, do it now.\n    if not config.best_state_path.exists():\n        save_state(step)\n\n    for callback in callbacks:\n        callback.post_train_loop(step, current_epoch)\n\n    if config.is_local_main_process:\n        # It's possible this file already exists if the step previously failed after\n        # already saving the final weights.\n        if config.final_weights_path.is_file():\n            os.remove(config.final_weights_path)\n        training_engine.save_complete_weights_from_checkpoint(\n            config.best_state_path, config.final_weights_path\n        )\n\n    if not config.is_distributed:\n        return training_engine.model\n    else:\n        return None\n\n\ndef _cycle_through_epochs(dataloader: DataLoader, is_distributed: bool, grad_accum: int):\n    epoch = 0\n    while True:\n        if is_distributed and isinstance(dataloader.sampler, DistributedSampler):\n            dataloader.sampler.set_epoch(epoch)\n        for batch in chunked(dataloader, grad_accum):\n            yield epoch, batch\n        epoch += 1\n"
  },
  {
    "path": "tango/integrations/torch/train_callback.py",
    "content": "import logging\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional\n\nfrom tango.common.dataset_dict import DatasetDictBase\nfrom tango.common.registrable import Registrable\nfrom tango.workspace import Workspace\n\nfrom .data import DataLoader\nfrom .exceptions import StopEarly\nfrom .model import Model\nfrom .train_config import TrainConfig\nfrom .training_engine import TrainingEngine\n\n\nclass TrainCallback(Registrable):\n    \"\"\"\n    A :class:`TrainCallback` is a :class:`~tango.common.Registrable` class\n    that can be used within :class:`TorchTrainStep` to customize behavior in the training\n    loop. You can set the training callbacks with the ``callbacks`` parameter to :class:`TorchTrainStep`.\n\n    .. tip::\n        All of the parameters to this base class will be automatically set within\n        the training loop, so you shouldn't include them in your config for your callbacks.\n\n    .. tip::\n        You can access the model being trained through :attr:`self.model <model>`.\n\n    .. important::\n        The ``step`` argument to callback methods is the total/overall number of training steps\n        so far, independent of the current epoch.\n\n    .. 
seealso::\n        See :class:`~tango.integrations.wandb.WandbTrainCallback` for an example\n        implementation.\n\n    :ivar Workspace workspace: The tango workspace being used.\n    :ivar TrainConfig train_config: The training config.\n    :ivar TrainingEngine training_engine: The engine used to train the model.\n    :ivar tango.common.DatasetDictBase dataset_dict: The dataset dict containing train and\n        optional validation splits.\n    :ivar DataLoader train_dataloader: The dataloader used for the training split.\n    :ivar DataLoader validation_dataloader: Optional dataloader used for the validation split.\n    \"\"\"\n\n    def __init__(\n        self,\n        workspace: Workspace,\n        train_config: TrainConfig,\n        training_engine: TrainingEngine,\n        dataset_dict: DatasetDictBase,\n        train_dataloader: DataLoader,\n        validation_dataloader: Optional[DataLoader] = None,\n    ) -> None:\n        self.workspace = workspace\n        self.train_config = train_config\n        self.training_engine = training_engine\n        self.dataset_dict = dataset_dict\n        self.train_dataloader = train_dataloader\n        self.validation_dataloader = validation_dataloader\n        self.logger = logging.getLogger(self.__class__.__name__)\n\n    @property\n    def step_id(self) -> str:\n        \"\"\"\n        The unique ID of the current :class:`~tango.Step`.\n        \"\"\"\n        return self.train_config.step_id\n\n    @property\n    def step_name(self) -> Optional[str]:\n        \"\"\"\n        The name of the current :class:`~tango.Step`.\n        \"\"\"\n        return self.train_config.step_name\n\n    @property\n    def work_dir(self) -> Path:\n        \"\"\"\n        The working directory of the current train step.\n        \"\"\"\n        return self.train_config.work_dir\n\n    @property\n    def is_local_main_process(self) -> bool:\n        \"\"\"\n        This is ``True`` if the current worker is the main distributed worker 
of the current node, or if\n        we are not using distributed training.\n        \"\"\"\n        return self.train_config.is_local_main_process\n\n    @property\n    def model(self) -> Model:\n        \"\"\"\n        The :class:`Model` being trained.\n        \"\"\"\n        return self.training_engine.model\n\n    def state_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Return any state that needs to be kept after a restart.\n\n        Some callbacks need to maintain state across restarts. This is the callback's opportunity to\n        save its state. It will be restored using :meth:`load_state_dict`.\n        \"\"\"\n        return {}\n\n    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:\n        \"\"\"\n        Load the state on a restart.\n\n        Some callbacks need to maintain state across restarts. This is the callback's opportunity to\n        restore its state. It gets saved using :meth:`state_dict`.\n        \"\"\"\n        pass\n\n    def pre_train_loop(self) -> None:\n        \"\"\"\n        Called right before the first batch is processed, or after a restart.\n        \"\"\"\n        pass\n\n    def post_train_loop(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called after the training loop completes.\n\n        This is the last method that is called, so any cleanup can be done in this method.\n        \"\"\"\n        pass\n\n    def pre_epoch(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called right before the start of an epoch. Epochs start at 0.\n        \"\"\"\n        pass\n\n    def post_epoch(self, step: int, epoch: int) -> None:\n        \"\"\"\n        Called after an epoch is completed. Epochs start at 0.\n        \"\"\"\n        pass\n\n    def pre_batch(self, step: int, epoch: int, batch: List[Dict[str, Any]]) -> None:\n        \"\"\"\n        Called directly before processing a batch.\n\n        .. 
note::\n            The type of ``batch`` is a list because with gradient accumulation there will\n            be more than one \"micro batch\" in the batch.\n\n        \"\"\"\n        pass\n\n    def post_batch(\n        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]\n    ) -> None:\n        \"\"\"\n        Called directly after processing a batch, but before unscaling gradients,\n        clipping gradients, and taking an optimizer step.\n\n        .. note::\n            The ``batch_loss`` here is the loss local to the current worker, not the\n            overall (average) batch loss across distributed workers.\n\n            If you need the average loss, use :meth:`log_batch()`.\n\n        .. note::\n            The type of ``batch_outputs`` is a list because with gradient accumulation there will\n            be more than one \"micro batch\" in the batch.\n\n        \"\"\"\n        pass\n\n    def log_batch(\n        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]\n    ) -> None:\n        \"\"\"\n        Called after the optimizer step. Here ``batch_loss`` is the average loss across\n        all distributed workers.\n\n        .. note::\n            This callback method is not necessarily called on every step.\n            The frequency depends on the value of the ``log_every`` parameter of\n            :class:`TorchTrainStep`.\n\n        .. 
note::\n            The type of ``batch_outputs`` is a list because with gradient accumulation there will\n            be more than one \"micro batch\" in the batch.\n\n        \"\"\"\n        pass\n\n    def pre_val_batch(\n        self, step: int, val_step: int, epoch: int, val_batch: Dict[str, Any]\n    ) -> None:\n        \"\"\"\n        Called right before a validation batch is processed.\n        \"\"\"\n        pass\n\n    def post_val_batch(\n        self, step: int, val_step: int, epoch: int, val_batch_outputs: Dict[str, Any]\n    ) -> None:\n        \"\"\"\n        Called right after a validation batch is processed with the outputs of the batch.\n\n        .. tip::\n            This method can be used to modify ``val_batch_outputs`` in place, which is useful\n            in scenarios like distributed training where you might need to aggregate metrics\n            in a special way other than a simple average. If that's the case, make sure\n            to set ``auto_aggregate_val_metric`` to ``False`` in :class:`TorchTrainStep`.\n\n        \"\"\"\n        pass\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: float, best_val_metric: float\n    ) -> None:\n        \"\"\"\n        Called right after the validation loop finishes.\n        \"\"\"\n        pass\n\n\n@TrainCallback.register(\"torch::stop_early\")\nclass StopEarlyCallback(TrainCallback):\n    \"\"\"\n    A :class:`TrainCallback` for early stopping. Training is stopped early after\n    ``patience`` steps without an improvement to the validation metric.\n\n    .. 
tip::\n\n        Registered as a :class:`TrainCallback` under the name \"torch::stop_early\".\n    \"\"\"\n\n    def __init__(self, *args, patience: int = 10000, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.patience = patience\n        self.best_step = 0\n        self.best_val_metric: Optional[float] = None\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: float, best_val_metric: float\n    ) -> None:\n        # We can't rely on the best_val_metric parameter, because then we can't detect when the metric stays\n        # the same for many steps.\n        # An improvement is a strict decrease when the metric is minimized, otherwise a strict increase.\n        if (\n            self.best_val_metric is None\n            or (self.train_config.minimize_val_metric and val_metric < self.best_val_metric)\n            or (not self.train_config.minimize_val_metric and val_metric > self.best_val_metric)\n        ):\n            self.best_step = step\n            self.best_val_metric = val_metric\n        elif step > self.best_step + self.patience:\n            raise StopEarly\n\n    def state_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Return any state that needs to be kept after a restart.\n        \"\"\"\n        return {\n            \"patience\": self.patience,\n            \"best_step\": self.best_step,\n            \"best_val_metric\": self.best_val_metric,\n        }\n\n    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:\n        \"\"\"\n        Load the state on a restart.\n        \"\"\"\n        self.patience = state_dict[\"patience\"]\n        self.best_step = state_dict[\"best_step\"]\n        self.best_val_metric = state_dict[\"best_val_metric\"]\n"
  },
  {
    "path": "tango/integrations/torch/train_config.py",
    "content": "from dataclasses import asdict, dataclass\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional\n\nimport torch\n\n\n@dataclass\nclass TrainConfig:\n    \"\"\"\n    Encapsulates the parameters of :class:`TorchTrainStep`. This is used to pass all the training\n    options to :class:`TrainCallback`.\n    \"\"\"\n\n    step_id: str\n    \"\"\"\n    The unique ID of the current step.\n    \"\"\"\n\n    work_dir: Path\n    \"\"\"\n    The working directory for the training run.\n    \"\"\"\n\n    step_name: Optional[str] = None\n    \"\"\"\n    The name of the current step.\n\n    .. note::\n        The same step can be run under different names.\n    \"\"\"\n\n    worker_id: int = 0\n    \"\"\"\n    The ID of the distributed worker.\n    \"\"\"\n\n    train_split: str = \"train\"\n    \"\"\"\n    The name of the training split.\n    \"\"\"\n\n    validation_split: Optional[str] = None\n    \"\"\"\n    The name of the validation split.\n    \"\"\"\n\n    seed: int = 42\n    \"\"\"\n    The random seed.\n    \"\"\"\n\n    train_steps: Optional[int] = None\n    \"\"\"\n    The number of steps to train for.\n    \"\"\"\n\n    train_epochs: Optional[int] = None\n    \"\"\"\n    The number of epochs to train for.\n\n    You cannot specify `train_steps` and `train_epochs` at the same time.\n    \"\"\"\n\n    validation_steps: Optional[int] = None\n    \"\"\"\n    The number of validation steps.\n\n    The default is to validate on the entire validation set.\n    \"\"\"\n\n    grad_accum: int = 1\n    \"\"\"\n    The number of micro-batches per gradient accumulation mini-batch.\n    \"\"\"\n\n    log_every: int = 10\n    \"\"\"\n    Controls the frequency of log updates, in number of optimizer steps\n    \"\"\"\n\n    checkpoint_every: int = 100\n    \"\"\"\n    Controls the frequency of checkpoints, in number of optimizer steps\n    \"\"\"\n\n    validate_every: Optional[int] = None\n    \"\"\"\n    Controls the frequency of the validation 
loop, in number of optimizer steps\n    \"\"\"\n\n    is_distributed: bool = False\n    \"\"\"\n    Whether or not the training job is distributed.\n    \"\"\"\n\n    devices: Optional[List[int]] = None\n    \"\"\"\n    The devices used (for distributed jobs).\n    \"\"\"\n\n    distributed_address: str = \"127.0.0.1\"\n    \"\"\"\n    The IP address of the main distributed process.\n    \"\"\"\n\n    distributed_port: int = 54761\n    \"\"\"\n    The port of the main distributed process.\n    \"\"\"\n\n    val_metric_name: str = \"loss\"\n    \"\"\"\n    The name of the validation metric to track.\n    \"\"\"\n\n    minimize_val_metric: bool = True\n    \"\"\"\n    Should be ``True`` when the validation metric being tracked should be minimized.\n    \"\"\"\n\n    auto_aggregate_val_metric: bool = True\n    \"\"\"\n    Controls automatic aggregation of validation metric.\n    \"\"\"\n\n    remove_stale_checkpoints: bool = True\n    \"\"\"\n    Controls removal of stale checkpoints.\n    \"\"\"\n\n    world_size: int = 1\n    \"\"\"\n    The number of distributed workers.\n    \"\"\"\n\n    _worker_local_default_device: Optional[torch.device] = None\n\n    _device_type: Optional[str] = None  # either \"cuda\" or \"cpu\"\n\n    @property\n    def worker_local_default_device(self) -> torch.device:\n        \"\"\"\n        The default ``torch`` device for the current worker.\n        \"\"\"\n        if self._worker_local_default_device is not None:\n            return self._worker_local_default_device\n        else:\n            if self.devices:\n                device_id = self.devices[self.worker_id]\n                if device_id >= 0:\n                    device = torch.device(f\"cuda:{device_id}\")\n                else:\n                    device = torch.device(\"cpu\")\n            elif torch.cuda.is_available():\n                device = torch.device(\"cuda\")\n            else:\n                device = torch.device(\"cpu\")\n            
self._worker_local_default_device = device\n            return device\n\n    @property\n    def device_type(self) -> str:\n        if self._device_type is None:\n            device_type = (\n                \"cpu\" if self.worker_local_default_device == torch.device(\"cpu\") else \"cuda\"\n            )\n            self._device_type = device_type\n            return device_type\n        else:\n            return self._device_type\n\n    @property\n    def is_local_main_process(self) -> bool:\n        \"\"\"\n        Whether the local process is the main distributed worker.\n        \"\"\"\n        return self.worker_id == 0\n\n    @property\n    def state_path(self) -> Path:\n        \"\"\"\n        The path to the latest state checkpoint file.\n        \"\"\"\n        return self.work_dir / \"checkpoint_state_latest\"\n\n    @property\n    def best_state_path(self) -> Path:\n        \"\"\"\n        The path to the best state checkpoint file according to the validation metric or training\n        loss (if no validation split is given).\n        \"\"\"\n        return self.work_dir / \"checkpoint_state_best\"\n\n    def state_path_for_step(self, step: int) -> Path:\n        return self.work_dir / f\"checkpoint_state_step{step + 1}\"\n\n    @property\n    def final_weights_path(self) -> Path:\n        return self.work_dir / \"weights.pt\"\n\n    def should_log_this_step(self, step: int) -> bool:\n        assert self.train_steps is not None\n        return step == 0 or (step + 1) % self.log_every == 0 or step == self.train_steps - 1\n\n    def should_checkpoint_this_step(self, step: int) -> bool:\n        assert self.train_steps is not None\n        return ((step + 1) % self.checkpoint_every == 0) or step == self.train_steps - 1\n\n    def should_log_this_val_step(self, val_step: int) -> bool:\n        assert self.validation_steps is not None\n        return val_step % self.log_every == 0 or val_step == self.validation_steps - 1\n\n    def as_dict(self) -> Dict[str, 
Any]:\n        return {k: v for k, v in asdict(self).items() if not k.startswith(\"_\")}\n"
  },
  {
    "path": "tango/integrations/torch/training_engine.py",
    "content": "import os\nimport tempfile\nfrom abc import abstractmethod\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Tuple, Union, cast\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\n\nfrom tango.common import Lazy, Registrable, Tqdm\n\nfrom .model import Model\nfrom .optim import LRScheduler, Optimizer\nfrom .train_config import TrainConfig\nfrom .util import move_to_device\n\n\nclass TrainingEngine(Registrable):\n    \"\"\"\n    A :class:`TrainingEngine` defines and drives the strategy for training a model\n    in :class:`TorchTrainStep`.\n\n    :ivar TrainConfig train_config: The training config.\n    :ivar Model model: The model being trained.\n    :ivar Optimizer optimizer: The optimizer being used to train the model.\n    :ivar LRScheduler lr_scheduler: The optional learning rate scheduler.\n    \"\"\"\n\n    default_implementation = \"torch\"\n    \"\"\"\n    The default implementation is :class:`TorchTrainingEngine`.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_config: TrainConfig,\n        model: Union[Model, Lazy[Model]],\n        optimizer: Lazy[Optimizer],\n        *,\n        lr_scheduler: Optional[Lazy[LRScheduler]] = None,\n    ) -> None:\n        self.train_config = train_config\n        self.model = self._construct_model(model)\n        self.optimizer = self._construct_optimizer(optimizer)\n        self.lr_scheduler: Optional[LRScheduler] = None\n        if lr_scheduler is not None:\n            self.lr_scheduler = self._construct_lr_scheduler(lr_scheduler)\n\n    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:\n        if isinstance(model, Lazy):\n            model = model.construct()\n        return model.to(self.train_config.worker_local_default_device)\n\n    def _construct_optimizer(self, optimizer: Lazy[Optimizer]) -> Optimizer:\n        optimizer: Optimizer = optimizer.construct(params=self.model.parameters())\n        return optimizer\n\n    def 
_construct_lr_scheduler(self, lr_scheduler: Lazy[LRScheduler]) -> LRScheduler:\n        lr_scheduler: LRScheduler = lr_scheduler.construct(optimizer=self.optimizer)\n        return lr_scheduler\n\n    @abstractmethod\n    def forward_train(\n        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int\n    ) -> Tuple[torch.Tensor, Dict[str, Any]]:\n        \"\"\"\n        Run a forward training pass on the model.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n        \"\"\"\n        Run a forward evaluation pass on the model.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def backward(self, loss: torch.Tensor) -> None:\n        \"\"\"\n        Run a backwards pass on the model. This will always be called after :meth:`forward_train()`.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def step(self) -> None:\n        \"\"\"\n        Take an optimization step. This will always be called after :meth:`backward()`.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:\n        \"\"\"\n        Save a training checkpoint with model state, optimizer state, etc., as well\n        as the arbitrary ``client_state`` to the given ``checkpoint_dir``.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:\n        \"\"\"\n        Load a checkpoint to resume training. 
Should return the same ``client_state`` saved\n        in :meth:`save_checkpoint()`.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def save_complete_weights_from_checkpoint(\n        self, checkpoint_dir: Path, weights_path: Path\n    ) -> None:\n        \"\"\"\n        Gather the final weights from the best checkpoint and save to the file at ``weights_path``.\n        \"\"\"\n        raise NotImplementedError\n\n\n@TrainingEngine.register(\"torch\")\nclass TorchTrainingEngine(TrainingEngine):\n    \"\"\"\n    This train engine only uses native PyTorch functionality to provide\n    vanilla distributed data parallel training and AMP.\n\n    .. tip::\n        Registered as a :class:`TrainingEngine` under the name \"torch\".\n\n    .. important::\n        Only the parameters listed below should be defined in a configuration\n        file. The other parameters will be automatically passed to the constructor\n        within :class:`TorchTrainStep`.\n\n    :param amp:\n        Use automatic mixed precision. Default is ``False``.\n    :param max_grad_norm:\n        If set, gradients will be clipped to have this max norm. Default is ``None``.\n    :param amp_use_bfloat16:\n        Set to ``True`` to force using the ``bfloat16`` datatype in mixed precision training.\n        Only applicable when ``amp=True``. 
If not specified, the default behavior will be\n        to use ``bfloat16`` when training with AMP on CPU, otherwise not.\n    \"\"\"\n\n    def __init__(\n        self,\n        train_config: TrainConfig,\n        model: Union[Model, Lazy[Model]],\n        optimizer: Lazy[Optimizer],\n        *,\n        lr_scheduler: Optional[Lazy[LRScheduler]] = None,\n        amp: bool = False,\n        max_grad_norm: Optional[float] = None,\n        amp_use_bfloat16: Optional[bool] = None,\n    ) -> None:\n        self.device = train_config.worker_local_default_device\n        if amp_use_bfloat16 is None:\n            amp_use_bfloat16 = True if train_config.device_type == \"cpu\" else False\n\n        self.amp = amp\n        self.amp_dtype = torch.bfloat16 if amp_use_bfloat16 else torch.float16\n        self.max_grad_norm = max_grad_norm\n        self.grad_scaler: Optional[torch.cuda.amp.GradScaler] = (\n            None if not amp else torch.cuda.amp.GradScaler()\n        )\n\n        if train_config.is_distributed:\n            # Initialize distributed process group.\n            backend: str\n            if train_config.device_type != \"cpu\":\n                torch.cuda.set_device(self.device)\n                backend = \"nccl\"\n            else:\n                backend = \"gloo\"\n            dist.init_process_group(\n                backend=backend,\n                init_method=f\"tcp://{train_config.distributed_address}:{train_config.distributed_port}\",\n                world_size=train_config.world_size,\n                rank=train_config.worker_id,\n            )\n\n        super().__init__(train_config, model, optimizer, lr_scheduler=lr_scheduler)\n\n    def _construct_model(self, model: Union[Model, Lazy[Model]]) -> Model:\n        if isinstance(model, Lazy):\n            model = model.construct()\n        model.to(self.train_config.worker_local_default_device)\n        # Wrap model with DDP wrapper.\n        if self.train_config.is_distributed:\n            
model = cast(Model, nn.parallel.DistributedDataParallel(model))\n        return model\n\n    def forward_train(\n        self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int\n    ) -> Tuple[torch.Tensor, Dict[str, Any]]:\n        if micro_batch_idx == 0:\n            self.optimizer.zero_grad(set_to_none=True)\n\n        # Move tensors to right device.\n        micro_batch = move_to_device(micro_batch, self.device)\n\n        with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):\n            outputs = self.model(**micro_batch)\n            micro_batch_loss = outputs[\"loss\"] / num_micro_batches\n\n        return micro_batch_loss, outputs\n\n    def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]:\n        # Move tensors to right device.\n        batch = move_to_device(batch, self.device)\n\n        with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype):\n            with torch.inference_mode():\n                outputs = self.model(**batch)\n\n        return outputs\n\n    def backward(self, loss: torch.Tensor) -> None:\n        if self.grad_scaler is not None:\n            self.grad_scaler.scale(loss).backward()\n        else:\n            loss.backward()\n\n    def clip_grad_norm(self) -> None:\n        if self.max_grad_norm is not None:\n            nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)\n\n    def step(self) -> None:\n        # Unscale gradients.\n        if self.grad_scaler is not None:\n            self.grad_scaler.unscale_(self.optimizer)\n\n        # Clip gradients.\n        self.clip_grad_norm()\n\n        # Take optimizer step.\n        if self.grad_scaler is not None:\n            self.grad_scaler.step(self.optimizer)\n            self.grad_scaler.update()\n        else:\n            self.optimizer.step()\n\n        # Adjust LR schedule.\n        if self.lr_scheduler is not None:\n            
self.lr_scheduler.step()\n\n    def get_model_state(self) -> Dict[str, torch.Tensor]:\n        if self.train_config.is_distributed:\n            return self.model.module.state_dict()  # type: ignore[union-attr]\n        else:\n            return self.model.state_dict()\n\n    def load_model_state(self, state_dict: Dict[str, torch.Tensor]) -> None:\n        if self.train_config.is_distributed:\n            self.model.module.load_state_dict(state_dict)  # type: ignore\n        else:\n            self.model.load_state_dict(state_dict)  # type: ignore\n\n    def save_checkpoint(self, checkpoint_dir: Path, client_state: Dict[str, Any]) -> None:\n        checkpoint_dir.mkdir(exist_ok=True)\n\n        def save_state(state: Dict[str, Any], name: str):\n            temp_state_file = tempfile.NamedTemporaryFile(\n                \"w+b\", dir=checkpoint_dir, delete=False, suffix=\".pt\"\n            )\n            try:\n                with Tqdm.wrapattr(\n                    temp_state_file,\n                    \"write\",\n                    desc=f\"Saving {name} state\",\n                    leave=False,\n                    disable=not self.train_config.is_local_main_process,\n                ) as f:\n                    torch.save(state, f)\n                temp_state_file.close()\n                os.replace(\n                    temp_state_file.name,\n                    checkpoint_dir / f\"worker{self.train_config.worker_id}_{name}.pt\",\n                )\n            finally:\n                if os.path.exists(temp_state_file.name):\n                    os.remove(temp_state_file.name)\n\n        save_state(self.get_model_state(), \"model\")\n        save_state(self.optimizer.state_dict(), \"optimizer\")\n        if self.lr_scheduler is not None:\n            save_state(self.lr_scheduler.state_dict(), \"lr_scheduler\")\n        if self.grad_scaler is not None:\n            save_state(self.grad_scaler.state_dict(), \"grad_scaler\")\n        save_state(client_state, 
\"trainer\")\n\n    def load_checkpoint(self, checkpoint_dir: Path) -> Dict[str, Any]:\n        self.load_model_state(\n            torch.load(checkpoint_dir / f\"worker{self.train_config.worker_id}_model.pt\")\n        )\n        self.optimizer.load_state_dict(\n            torch.load(checkpoint_dir / f\"worker{self.train_config.worker_id}_optimizer.pt\")\n        )\n        if self.lr_scheduler is not None:\n            self.lr_scheduler.load_state_dict(\n                torch.load(checkpoint_dir / f\"worker{self.train_config.worker_id}_lr_scheduler.pt\")\n            )\n        if self.grad_scaler is not None:\n            self.grad_scaler.load_state_dict(\n                torch.load(checkpoint_dir / f\"worker{self.train_config.worker_id}_grad_scaler.pt\")\n            )\n        return torch.load(checkpoint_dir / f\"worker{self.train_config.worker_id}_trainer.pt\")\n\n    def save_complete_weights_from_checkpoint(\n        self, checkpoint_dir: Path, weights_path: Path\n    ) -> None:\n        os.link(checkpoint_dir.resolve() / \"worker0_model.pt\", weights_path)\n"
  },
  {
    "path": "tango/integrations/torch/util.py",
    "content": "import random\nimport warnings\nfrom collections import UserDict\nfrom typing import Dict, Optional, TypeVar, Union\n\nimport torch\nimport torch.distributed as dist\nfrom torch.utils.data import DistributedSampler, IterableDataset\n\nfrom .data import DataLoader\n\nT = TypeVar(\"T\")\n\n\ndef move_to_device(o: T, device: torch.device) -> T:\n    if isinstance(o, torch.Tensor):\n        return o.to(device)  # type: ignore[return-value]\n    elif isinstance(o, dict) or isinstance(o, UserDict):\n        return {k: move_to_device(v, device) for k, v in o.items()}  # type: ignore[return-value]\n    elif isinstance(o, list):\n        return [move_to_device(x, device) for x in o]  # type: ignore[return-value]\n    elif isinstance(o, tuple):\n        return tuple((move_to_device(x, device) for x in o))  # type: ignore[return-value]\n    else:\n        return o\n\n\ndef check_dataset(dataset, split: str):\n    try:\n        len(dataset)\n    except TypeError:\n        if not isinstance(dataset, IterableDataset):\n            warnings.warn(\n                f\"Dataset for {split} split appears to be a streaming/iterable dataset, \"\n                \"but is not an instance of 'torch.utils.data.IterableDataset'. 
This could cause issues \"\n                \"within the DataLoader.\",\n                UserWarning,\n            )\n\n\ndef check_dataloader(dataloader: DataLoader):\n    # If using a regular dataset and not streaming/iterable dataset, we\n    # should probably be using a `DistributedSampler`.\n    if not isinstance(dataloader.dataset, IterableDataset) and not isinstance(\n        dataloader.sampler, DistributedSampler\n    ):\n        warnings.warn(\n            \"DistributedSampler is required for dataloader during distributed training, \"\n            f\"found {type(dataloader.sampler)} instead.\",\n            UserWarning,\n        )\n\n\ndef set_seed_all(seed: int):\n    random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(seed)\n    try:\n        import numpy as np\n    except ModuleNotFoundError:\n        pass\n    else:\n        np.random.seed(seed)\n\n\ndef resolve_device(device: Optional[Union[int, str, torch.device]] = None) -> torch.device:\n    if device is None:\n        if torch.cuda.is_available():\n            # TODO (epwalsh, dirkgr): automatically pick which GPU to use when there are multiple\n            return torch.device(\"cuda\")\n        else:\n            return torch.device(\"cpu\")\n    elif isinstance(device, int):\n        if device >= 0:\n            return torch.device(f\"cuda:{device}\")\n        else:\n            return torch.device(\"cpu\")\n    elif isinstance(device, str):\n        return torch.device(device)\n    elif isinstance(device, torch.device):\n        return device\n    else:\n        raise TypeError(f\"unexpected type for 'device': '{device}'\")\n\n\ndef peak_gpu_memory(reset: bool = False) -> Dict[int, int]:\n    \"\"\"\n    Get the peak GPU memory usage in MiB by distributed worker rank.\n\n    :returns:\n        Keys are rank ids as integers (from 0 up to world size - 1).\n        Values are memory usage as integers in MiB.\n        Returns an 
empty `dict` if GPUs are not available.\n    \"\"\"\n    if not torch.cuda.is_available():\n        return {}\n\n    device = torch.device(\"cuda\")\n\n    results_dict: Dict[int, int] = {}\n    if dist.is_available() and dist.is_initialized():\n        # If the backend is not 'nccl', we're training on CPU.\n        if dist.get_backend() != \"nccl\":\n            return {}\n\n        global_rank = dist.get_rank()\n        world_size = dist.get_world_size()\n        peak_mb = torch.cuda.max_memory_allocated(device) // 1048576\n        peak_mb_tensor = torch.tensor([global_rank, peak_mb], device=device)\n        # All of these tensors will be gathered into this list.\n        gather_results = [torch.tensor([0, 0], device=device) for _ in range(world_size)]\n\n        dist.all_gather(gather_results, peak_mb_tensor)\n\n        for peak_mb_tensor in gather_results:\n            results_dict[int(peak_mb_tensor[0])] = int(peak_mb_tensor[1])\n    else:\n        # Convert bytes to MiB so the value matches the distributed branch and the docstring.\n        results_dict = {0: torch.cuda.max_memory_allocated(device) // 1048576}\n\n    if reset:\n        # Reset peak stats.\n        torch.cuda.reset_max_memory_allocated(device)\n\n    return results_dict\n"
  },
  {
    "path": "tango/integrations/transformers/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"transformers\" extra\n    (e.g. ``pip install tango[transformers]``) or just install the ``transformers`` library after the fact\n    (e.g. ``pip install transformers``).\n\nComponents for Tango integration with `🤗 Transformers <https://huggingface.co/docs/transformers/>`_.\n\nThis integration provides some useful steps and also registers PyTorch components from the transformers\nlibrary under the corresponding class from the `torch <torch.html>`_ integration, such as:\n\n- :class:`~tango.integrations.torch.Model`: All transformers \"auto\" model classes are registered\n  according to their class names (e.g. \"transformers::AutoModelForCausalLM::from_pretrained\"\n  or \"transformers::AutoModelForCausalLM::from_config\").\n\n  For example, to instantiate a pretrained transformer model from params:\n\n  .. testcode::\n\n      from tango.integrations.torch import Model\n\n      model = Model.from_params({\n          \"type\": \"transformers::AutoModel::from_pretrained\",\n          \"pretrained_model_name_or_path\": \"epwalsh/bert-xsmall-dummy\",\n      })\n\n  Or to instantiate a transformer model from params without loading pretrained weights:\n\n  .. testcode::\n\n      from tango.integrations.torch import Model\n\n      model = Model.from_params({\n          \"type\": \"transformers::AutoModel::from_config\",\n          \"config\": {\"pretrained_model_name_or_path\": \"epwalsh/bert-xsmall-dummy\"},\n      })\n\n  .. tip::\n\n        You can see a list of all of the available auto model constructors from transformers by running:\n\n        .. 
testcode::\n\n            from tango.integrations.torch import Model\n            from tango.integrations.transformers import *\n\n            available_models = []\n\n            for name in sorted(Model.list_available()):\n                if name.startswith(\"transformers::AutoModel\"):\n                    available_models.append(name)\n\n- :class:`~tango.integrations.torch.Optimizer`: All optimizers from transformers are registered according\n  to their class names (e.g. \"transformers::Adafactor\").\n\n  .. tip::\n\n        You can see a list of all of the available optimizers from transformers by running:\n\n        .. testcode::\n\n            from tango.integrations.torch import Optimizer\n            from tango.integrations.transformers import *\n\n            for name in sorted(Optimizer.list_available()):\n                if name.startswith(\"transformers::\"):\n                    print(name)\n\n        .. testoutput::\n\n            transformers::Adafactor\n            transformers::AdamW\n            transformers::LayerWiseDummyOptimizer\n\n- :class:`~tango.integrations.torch.LRScheduler`: All learning rate scheduler functions from transformers\n  are registered according to their type name (e.g. \"transformers::linear\").\n\n  .. tip::\n\n        You can see a list of all of the available scheduler functions from transformers by running:\n\n        .. testcode::\n\n            from tango.integrations.torch import LRScheduler\n            from tango.integrations.transformers import *\n\n            for name in sorted(LRScheduler.list_available()):\n                if name.startswith(\"transformers::\"):\n                    print(name)\n\n        .. 
testoutput::\n\n            transformers::constant\n            transformers::constant_with_warmup\n            transformers::cosine\n            transformers::cosine_with_min_lr\n            transformers::cosine_with_restarts\n            transformers::inverse_sqrt\n            transformers::linear\n            transformers::polynomial\n            transformers::reduce_lr_on_plateau\n\n- :class:`~tango.integrations.torch.DataCollator`: All data collators from transformers\n  are registered according to their class name (e.g. \"transformers::DefaultDataCollator\").\n\n  You can instantiate any of these from a config / params like so:\n\n  .. testcode::\n\n      from tango.integrations.torch import DataCollator\n\n      collator = DataCollator.from_params({\n          \"type\": \"transformers::DataCollatorWithPadding\",\n          \"tokenizer\": {\n              \"pretrained_model_name_or_path\": \"epwalsh/bert-xsmall-dummy\",\n          },\n      })\n\n  .. tip::\n\n        You can see a list of all of the available data collators from transformers by running:\n\n        .. testcode::\n\n            from tango.integrations.torch import DataCollator\n            from tango.integrations.transformers import *\n\n            for name in sorted(DataCollator.list_available()):\n                if name.startswith(\"transformers::\"):\n                    print(name)\n\n        .. 
testoutput::\n\n            transformers::DataCollatorForLanguageModeling\n            transformers::DataCollatorForPermutationLanguageModeling\n            transformers::DataCollatorForSOP\n            transformers::DataCollatorForSeq2Seq\n            transformers::DataCollatorForTokenClassification\n            transformers::DataCollatorForWholeWordMask\n            transformers::DataCollatorWithPadding\n            transformers::DefaultDataCollator\n\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    import transformers\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"transformers\")\n\n__all__ = [\n    \"RunGeneration\",\n    \"RunGenerationDataset\",\n    \"Tokenizer\",\n    \"Config\",\n    \"add_soft_prompt\",\n    \"FinetuneWrapper\",\n    \"FinetuneStep\",\n    \"TokenizeText2TextData\",\n]\n\nfrom .config import Config\nfrom .data import *  # noqa: F403\nfrom .finetune import FinetuneStep, FinetuneWrapper, TokenizeText2TextData\nfrom .model import *  # noqa: F403\nfrom .optim import *  # noqa: F403\nfrom .run_generation import RunGeneration, RunGenerationDataset\nfrom .soft_prompt import add_soft_prompt\nfrom .tokenizer import Tokenizer\n"
  },
  {
    "path": "tango/integrations/transformers/config.py",
    "content": "from transformers import AutoConfig, PretrainedConfig\n\nfrom tango.common import Registrable\n\n\nclass Config(PretrainedConfig, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of transformers'\n    :class:`~transformers.PretrainedConfig`.\n    \"\"\"\n\n    default_implementation = \"auto\"\n    \"\"\"\n    The default registered implementation just calls\n    :meth:`transformers.AutoConfig.from_pretrained()`.\n    \"\"\"\n\n\nConfig.register(\"auto\", constructor=\"from_pretrained\")(AutoConfig)\n"
  },
  {
    "path": "tango/integrations/transformers/data.py",
    "content": "from dataclasses import fields, is_dataclass\nfrom typing import Callable\n\nfrom transformers.data import data_collator as transformers_data_collator\n\nfrom tango.integrations.torch.data import DataCollator\n\nfrom .tokenizer import Tokenizer\n\n\n# Some data collators take a tokenizer, so in order to instantiate those collators from params,\n# we need to use a factory function that takes our registrable version of a tokenizer as\n# an argument.\ndef data_collator_with_tokenizer_factory(cls) -> Callable[..., DataCollator]:\n    def factory(tokenizer: Tokenizer, **kwargs) -> DataCollator:\n        return cls(tokenizer=tokenizer, **kwargs)\n\n    return factory\n\n\nfor name, cls in transformers_data_collator.__dict__.items():\n    if (\n        isinstance(cls, type)\n        and is_dataclass(cls)\n        and \"DataCollator\" in name\n        and hasattr(cls, \"__call__\")\n    ):\n        for field in fields(cls):\n            if field.name == \"tokenizer\":\n                factory_func = data_collator_with_tokenizer_factory(cls)\n                DataCollator.register(\"transformers::\" + name)(factory_func)  # type: ignore\n                break\n        else:\n            DataCollator.register(\"transformers::\" + name)(cls)\n"
  },
  {
    "path": "tango/integrations/transformers/finetune.py",
    "content": "import logging\nfrom os import PathLike\nfrom typing import List, Optional, Union, cast\n\nimport datasets as ds\nfrom transformers import (\n    AutoConfig,\n    AutoModelForCausalLM,\n    AutoModelForSeq2SeqLM,\n    DataCollatorForSeq2Seq,\n    PreTrainedModel,\n)\n\nfrom tango.common import Lazy, Params\nfrom tango.format import Format\nfrom tango.integrations.datasets import DatasetsFormat, convert_to_tango_dataset_dict\nfrom tango.integrations.torch import (\n    DataCollator,\n    DataLoader,\n    Model,\n    TorchFormat,\n    TrainCallback,\n    TrainingEngine,\n)\nfrom tango.integrations.torch.train import TorchTrainStep\nfrom tango.integrations.transformers import Tokenizer\nfrom tango.step import Step\n\nlogger = logging.getLogger(__name__)\n\nSEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys()  # type: ignore\nCAUSAL = AutoModelForCausalLM._model_mapping.keys()  # type: ignore\n\n\nclass FinetuneWrapper(PreTrainedModel):\n    \"\"\"\n    Wrapper `PreTrainedModel` class that returns either a `Seq2SeqLM` or `CausalLM` model.\n    \"\"\"\n\n    @classmethod\n    def from_pretrained(  # type: ignore\n        cls,\n        pretrained_model_name_or_path: Union[str, PathLike],\n        num_tokens: Optional[int] = None,\n        **kwargs,\n    ) -> PreTrainedModel:\n        \"\"\"\n        :param pretrained_model_name_or_path:\n            The name of the model to return. 
Any name that works in the transformers library works here.\n        :param num_tokens:\n            The number of token embeddings to have.\n        \"\"\"\n        try:\n            model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path, **kwargs)\n        except ValueError:\n            model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)\n\n        if num_tokens is not None:\n            model.resize_token_embeddings(num_tokens)\n        return model\n\n\nModel.register(\"transformers::finetune::from_pretrained\", constructor=\"from_pretrained\")(\n    FinetuneWrapper\n)\n\n\ndef _add_special_tokens(tokenizer: Tokenizer) -> None:\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n    if tokenizer.pad_token is None:\n        tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n    if tokenizer.sep_token is None:\n        tokenizer.add_special_tokens({\"sep_token\": \"[SEP]\"})\n    if tokenizer.eos_token is None:\n        tokenizer.add_special_tokens({\"eos_token\": \"[EOS]\"})\n\n\ndef tokenize_data(\n    data: ds.DatasetDict,\n    tokenizer: Tokenizer,\n    num_workers: int = 1,\n    source_field: str = \"source\",\n    target_field: str = \"target\",\n    max_source_length: Optional[int] = 1024,\n    max_target_length: Optional[int] = 1024,\n    pad_to_max_length: bool = False,\n    ignore_pad_token_for_loss: bool = True,\n    concat_source_target: bool = False,\n) -> ds.DatasetDict:\n    \"\"\"\n    Returns a `DatasetDict` with tokenized source and target fields.\n\n    :param data:\n        The original dataset dict containing the source and target fields.\n    :param tokenizer:\n        The tokenizer to use.\n    :param num_workers:\n        The number of workers to use for processing the data.\n    :param source_field:\n        The string name of the field containing the source sequence.\n    :param target_field:\n        The string name of the field 
containing the target sequence.\n    :param max_source_length:\n        The maximum number of tokens in the source sequence.\n    :param max_target_length:\n        The maximum number of tokens in the target sequence.\n    :param pad_to_max_length:\n        Whether to pad to the maximum length when tokenizing.\n    :param ignore_pad_token_for_loss:\n        Whether to ignore the padded tokens for calculating loss.\n        If set to True, all the pad tokens in the labels are replaced\n        by -100, which is ignored by the loss function.\n    :param concat_source_target:\n        If the downstream model is decoder-only, like \"gpt2\", the source\n        and target sequences need to be concatenated and fed to the model\n        together.\n    \"\"\"\n    padding = \"max_length\" if pad_to_max_length else False\n\n    _add_special_tokens(tokenizer)\n\n    def preprocess_function(examples):\n        # remove pairs where at least one record is None\n        inputs, targets = [], []\n        input_lengths = []\n        for i in range(len(examples[source_field])):\n            if examples[source_field][i] is not None and examples[target_field][i] is not None:\n                if not concat_source_target:\n                    inputs.append(examples[source_field][i])\n                    targets.append(examples[target_field][i])\n                else:\n                    text = (\n                        examples[source_field][i]\n                        + tokenizer.sep_token\n                        + examples[target_field][i]\n                        + tokenizer.eos_token\n                    )\n                    inputs.append(text)\n                    targets.append(text)\n                    input_lengths.append(len(examples[source_field][i]))\n\n        model_inputs = tokenizer(\n            inputs, max_length=max_source_length, padding=padding, truncation=True\n        )\n\n        if not concat_source_target:\n            # Setup the tokenizer for targets\n   
         with tokenizer.as_target_tokenizer():\n                labels = tokenizer(\n                    targets, max_length=max_target_length, padding=padding, truncation=True\n                )\n        else:\n            labels = {\"input_ids\": []}\n            for input_ids in model_inputs[\"input_ids\"]:\n                label_start_idx = input_ids.index(tokenizer.sep_token_id)\n                label_ids = [-100] * len(input_ids)\n                label_ids[label_start_idx + 1 :] = input_ids[label_start_idx + 1 :]\n                labels[\"input_ids\"].append(label_ids)\n\n        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100\n        # when we want to ignore padding in the loss.\n        if padding == \"max_length\" and ignore_pad_token_for_loss:\n            labels[\"input_ids\"] = [\n                [(lb if lb != tokenizer.pad_token_id else -100) for lb in label]\n                for label in labels[\"input_ids\"]\n            ]\n\n        model_inputs[\"labels\"] = labels[\"input_ids\"]\n        return model_inputs\n\n    data = data.map(\n        preprocess_function,\n        batched=True,\n        num_proc=num_workers,\n        remove_columns=list(data.column_names.values())[0],  # remove all old columns\n        desc=\"Tokenizing dataset\",\n    )\n\n    return data\n\n\n@Step.register(\"transformers::tokenize_text2text\")\nclass TokenizeText2TextData(Step):\n    \"\"\"\n    A step that tokenizes data containing source and target sequences.\n\n    .. 
tip::\n        Registered as a :class:`~tango.step.Step` under the name \"transformers::tokenize_text2text\".\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT = DatasetsFormat()\n\n    def run(  # type: ignore[override]\n        self,\n        data: ds.DatasetDict,\n        tokenizer: Tokenizer,\n        num_workers: int = 1,\n        source_field: str = \"source\",\n        target_field: str = \"target\",\n        max_source_length: Optional[int] = 1024,\n        max_target_length: Optional[int] = 1024,\n        pad_to_max_length: bool = False,\n        ignore_pad_token_for_loss: bool = True,\n        concat_source_target: bool = False,\n    ) -> ds.DatasetDict:\n        \"\"\"\n        Returns a `DatasetDict` with tokenized source and target fields.\n\n        :param data:\n            The original dataset dict containing the source and target fields.\n        :param tokenizer:\n            The tokenizer to use.\n        :param num_workers:\n            The number of workers to use for processing the data.\n        :param source_field:\n            The string name of the field containing the source sequence.\n        :param target_field:\n            The string name of the field containing the target sequence.\n        :param max_source_length:\n            The maximum number of tokens in the source sequence.\n        :param max_target_length:\n            The maximum number of tokens in the target sequence.\n        :param pad_to_max_length:\n            Whether to pad to the maximum length when tokenizing.\n        :param ignore_pad_token_for_loss:\n            Whether to ignore the padded tokens for calculating loss.\n            If set to True, all the pad tokens in the labels are replaced\n            by -100, which is ignored by the loss function.\n        :param concat_source_target:\n            If the downstream model is decoder-only, like \"gpt2\", the source\n            and target sequences need to be concatenated and fed to 
the model\n            together.\n\n        .. tip::\n            If concat_source_target is set to True, we pad all sequences to max\n            length here. Otherwise, we leave it to the appropriate\n            :class:`~tango.integrations.torch.DataCollator` object.\n        \"\"\"\n        return tokenize_data(\n            data,\n            tokenizer=tokenizer,\n            num_workers=num_workers,\n            source_field=source_field,\n            target_field=target_field,\n            max_source_length=max_source_length,\n            max_target_length=max_target_length,\n            pad_to_max_length=pad_to_max_length,\n            ignore_pad_token_for_loss=ignore_pad_token_for_loss,\n            concat_source_target=concat_source_target,\n        )\n\n\n@Step.register(\"transformers::finetune\")\nclass FinetuneStep(TorchTrainStep):\n    \"\"\"\n    Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional\n    preprocessing for data.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"transformers::finetune\".\n\n    .. important::\n\n        The training loop will use GPU(s) automatically when available, as long as at least\n        ``device_count`` CUDA devices are available.\n\n        Distributed data parallel training is activated when the ``device_count`` is greater than 1.\n\n        You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``.\n        For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1``\n        (and ``device_count`` to 2).\n\n    .. 
warning::\n\n        During validation, the validation metric (specified by the ``val_metric_name`` parameter)\n        is aggregated by simply averaging across validation batches and distributed processes.\n        This behavior is usually correct when your validation metric is \"loss\" or \"accuracy\",\n        for example, but may not be correct for other metrics like \"F1\".\n\n        If this is not correct for your metric you will need to handle the aggregation\n        internally in your model or with a :class:`TrainCallback`\n        using the :meth:`TrainCallback.post_val_batch()` method.\n        Then set the parameter ``auto_aggregate_val_metric`` to ``False``.\n\n        Note that correctly aggregating your metric during distributed training will\n        involve distributed communication.\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = TorchFormat()\n    SKIP_ID_ARGUMENTS = {\"distributed_port\", \"log_every\"}\n\n    def run(  # type: ignore[override]\n        self,\n        model: Lazy[Model],\n        tokenizer: Tokenizer,\n        training_engine: Lazy[TrainingEngine],\n        dataset_dict: ds.DatasetDict,\n        train_dataloader: Lazy[DataLoader],\n        *,\n        train_split: str = \"train\",\n        validation_split: Optional[str] = None,\n        validation_dataloader: Optional[Lazy[DataLoader]] = None,\n        source_field: str = \"source\",\n        target_field: str = \"target\",\n        max_source_length: Optional[int] = 1024,\n        max_target_length: Optional[int] = 1024,\n        seed: int = 42,\n        train_steps: Optional[int] = None,\n        train_epochs: Optional[int] = None,\n        validation_steps: Optional[int] = None,\n        grad_accum: int = 1,\n        log_every: int = 10,\n        checkpoint_every: int = 100,\n        validate_every: Optional[int] = None,\n        device_count: int = 1,\n        distributed_port: int = 54761,\n        val_metric_name: str = \"loss\",\n   
     minimize_val_metric: bool = True,\n        auto_aggregate_val_metric: bool = True,\n        callbacks: Optional[List[Lazy[TrainCallback]]] = None,\n        remove_stale_checkpoints: bool = True,\n    ) -> Model:\n        \"\"\"\n        Run a basic training loop to train the ``model``.\n\n        :param model:\n            The model to train. It should return a ``dict`` that includes the ``loss``\n            during training and the ``val_metric_name`` during validation.\n        :param tokenizer:\n            The tokenizer to use for tokenizing source and target sequences.\n        :param training_engine:\n            The :class:`TrainingEngine` to use to train the model.\n        :param dataset_dict:\n            The train and optional validation data.\n        :param train_dataloader:\n            The data loader that generates training batches. The batches should be :class:`dict`\n            objects that will be used as ``kwargs`` for the model's ``forward()`` method.\n        :param train_split:\n            The name of the data split used for training in the ``dataset_dict``.\n            Default is \"train\".\n        :param validation_split:\n            Optional name of the validation split in the ``dataset_dict``. Default is ``None``,\n            which means no validation.\n        :param validation_dataloader:\n            An optional data loader for generating validation batches. The batches should be\n            :class:`dict` objects. 
If not specified, but ``validation_split`` is given,\n            the validation ``DataLoader`` will be constructed from the same parameters\n            as the train ``DataLoader``.\n        :param source_field:\n            The string name of the field containing the source sequence.\n        :param target_field:\n            The string name of the field containing the target sequence.\n        :param max_source_length:\n            The maximum number of tokens in the source sequence.\n        :param max_target_length:\n            The maximum number of tokens in the target sequence.\n        :param seed:\n            Used to set the RNG states at the beginning of training.\n        :param train_steps:\n            The number of steps to train for. If not specified, training will\n            stop after a complete iteration through the ``train_dataloader``.\n        :param train_epochs:\n            The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs``\n            at the same time.\n        :param validation_steps:\n            The number of steps to validate for. If not specified, validation\n            will stop after a complete iteration through the ``validation_dataloader``.\n        :param grad_accum:\n            The number of gradient accumulation steps. Defaults to 1.\n\n            .. note::\n                This parameter - in conjunction with the settings of your data loader\n                and the number of distributed workers -\n                determines the *effective batch size* of your training run.\n\n        :param log_every:\n            Log every this many steps.\n        :param checkpoint_every:\n            Save a checkpoint every this many steps.\n        :param validate_every:\n            Run the validation loop every this many steps.\n        :param device_count:\n            The number of devices to train on, i.e. 
the number of distributed data parallel workers.\n        :param distributed_port:\n            The port of the distributed process group. Default = \"54761\".\n        :param val_metric_name:\n            The name of the validation metric, i.e. the key of the metric in the dictionary\n            returned by the forward pass of the model. Default is \"loss\".\n        :param minimize_val_metric:\n            Whether the validation metric is meant to be minimized (such as the loss).\n            Default is ``True``. When using a metric such as accuracy, you should set\n            this to ``False``.\n        :param auto_aggregate_val_metric:\n            If ``True`` (the default), the validation metric will be averaged across\n            validation batches and distributed processes. This may not be the correct\n            behavior for some metrics (such as F1), in which you should set this to\n            ``False`` and handle the aggregation internally in your model\n            or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`).\n        :param callbacks:\n            A list of :class:`TrainCallback`.\n        :param remove_stale_checkpoints:\n            If ``True`` (the default), stale checkpoints will be removed throughout training so that\n            only the latest and best checkpoints are kept.\n\n        :returns:\n            The trained model on CPU with the weights from the best checkpoint loaded.\n\n        \"\"\"\n        devices = self._get_devices(device_count)\n\n        is_distributed = False\n        if devices and len(devices) > 1:\n            is_distributed = True\n\n        # Setup the tokenizer\n        _add_special_tokens(tokenizer)\n\n        # Hacky way to deal with resizing the model embeddings.\n        model_params_dict = model._params.as_dict()\n        if \"fairscale\" in model_params_dict[\"type\"]:\n            model_params_dict[\"model\"][\"num_tokens\"] = len(tokenizer)  # type: ignore\n        
else:\n            model_params_dict[\"num_tokens\"] = len(tokenizer)  # type: ignore\n\n        model = Lazy(\n            model._constructor,\n            Params(model_params_dict),\n            constructor_extras=model._constructor_extras,\n        )\n\n        # Get the config to check in order to check if the model is seq2seq or causal.\n        config = AutoConfig.from_pretrained(tokenizer.name_or_path)\n        seq2seq: bool = type(config) in SEQ2SEQ\n\n        dataset_dict = tokenize_data(\n            dataset_dict,\n            tokenizer=tokenizer,\n            source_field=source_field,\n            target_field=target_field,\n            max_source_length=max_source_length,\n            max_target_length=max_target_length,\n            concat_source_target=not seq2seq,\n        )\n\n        if is_distributed:\n            from torch.utils.data.distributed import DistributedSampler\n\n            sampler = Lazy(DistributedSampler, drop_last=True, shuffle=True)\n            train_dataloader = Lazy(\n                train_dataloader._constructor,\n                train_dataloader._params,\n                constructor_extras=train_dataloader._constructor_extras,\n                sampler=sampler,\n            )\n\n        collate_fn: DataCollator\n        collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer))\n\n        train_dataloader = Lazy(\n            train_dataloader._constructor,\n            train_dataloader._params,\n            constructor_extras=train_dataloader._constructor_extras,\n            collate_fn=collate_fn,\n        )\n\n        return self._train(\n            model=model,\n            training_engine=training_engine,\n            dataset_dict=convert_to_tango_dataset_dict(dataset_dict),\n            train_dataloader=train_dataloader,\n            train_split=train_split,\n            validation_split=validation_split,\n            validation_dataloader=validation_dataloader,\n            seed=seed,\n            
train_steps=train_steps,\n            train_epochs=train_epochs,\n            validation_steps=validation_steps,\n            grad_accum=grad_accum,\n            log_every=log_every,\n            checkpoint_every=checkpoint_every,\n            validate_every=validate_every,\n            devices=devices,\n            distributed_port=distributed_port,\n            val_metric_name=val_metric_name,\n            minimize_val_metric=minimize_val_metric,\n            auto_aggregate_val_metric=auto_aggregate_val_metric,\n            callbacks=callbacks,\n            remove_stale_checkpoints=remove_stale_checkpoints,\n        )\n"
  },
  {
    "path": "tango/integrations/transformers/ia3.py",
    "content": "import re\nfrom dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom transformers import PreTrainedModel\nfrom transformers.modeling_utils import Conv1D\n\n\n@dataclass\nclass WithIA3Config:\n    \"\"\"\n    A class for configuring which layers to modify with IA3 adaptors.\n\n\n    :param ia3_param_names:\n        A string used as the name for all ia3 parameters\n    :param attention_modules:\n        A regex that matches all attention modules which are parents to the keys and value layers to modify.\n    :param mlp_modules:\n        A regex that matches all modules that are parents to the feed forward layer to modify.\n    :param mlp_layers:\n        A regex that matches the feed forward layer in the modules specified by `mlp_modles`.\n    :param fused_qkv_layers:\n        A regex that matches the combined query, key, and value layer in the modules specified\n        by `attention_modules`.\n    :param k_layers:\n        A regex that matches the key layer in the modules specified by `attention_modules`.\n    :param v_layers:\n        A regex that matches the value layer in the modules specified by `attention_modules`.\n    \"\"\"\n\n    ia3_param_names: str\n    attention_modules: str\n    mlp_modules: str\n    mlp_layers: str\n    fused_qkv_layers: Optional[str] = None\n    k_layers: Optional[str] = None\n    v_layers: Optional[str] = None\n\n\nGPT_J_IA3_CONFIG = WithIA3Config(\n    attention_modules=\".*attn\",\n    k_layers=\"k_proj\",\n    v_layers=\"v_proj\",\n    mlp_modules=\".*mlp\",\n    mlp_layers=\"fc_in\",\n    ia3_param_names=\"ia3\",\n)\n\nGPT_2_IA3_CONFIG = WithIA3Config(\n    attention_modules=\".*attn\",\n    fused_qkv_layers=\"c_attn\",\n    mlp_modules=\".*mlp\",\n    mlp_layers=\"c_fc\",\n    ia3_param_names=\"ia3\",\n)\n\nOPT_IA3_CONFIG = WithIA3Config(\n    attention_modules=\".*self_attn\",\n    k_layers=\"k_proj\",\n    v_layers=\"v_proj\",\n  
  mlp_modules=r\".*layers\\.\\d*\",\n    mlp_layers=\"fc1\",\n    ia3_param_names=\"ia3\",\n)\n\nBLOOM_IA3_CONFIG = WithIA3Config(\n    attention_modules=\".*self_attention\",\n    fused_qkv_layers=\"query_key_value\",\n    mlp_modules=\".*mlp\",\n    mlp_layers=\"dense_h_to_4h\",\n    ia3_param_names=\"ia3\",\n)\n\nMODEL_NAME_TO_CONFIG = {\n    \"sshleifer/tiny-gpt2\": GPT_2_IA3_CONFIG,\n    \"gpt2\": GPT_2_IA3_CONFIG,\n    \"gpt2-medium\": GPT_2_IA3_CONFIG,\n    \"gpt2-large\": GPT_2_IA3_CONFIG,\n    \"gpt2-xl\": GPT_2_IA3_CONFIG,\n    \"bigscience/bloom-560m\": BLOOM_IA3_CONFIG,\n    \"bigscience/bloom-1b1\": BLOOM_IA3_CONFIG,\n    \"bigscience/bloom-1b7\": BLOOM_IA3_CONFIG,\n    \"bigscience/bloom-3b\": BLOOM_IA3_CONFIG,\n    \"bigscience/bloom-7b1\": BLOOM_IA3_CONFIG,\n    \"bigscience/bloom\": BLOOM_IA3_CONFIG,\n    \"facebook/opt-125m\": OPT_IA3_CONFIG,\n    \"facebook/opt-350m\": OPT_IA3_CONFIG,\n    \"facebook/opt-1.3b\": OPT_IA3_CONFIG,\n    \"facebook/opt-2.7b\": OPT_IA3_CONFIG,\n    \"facebook/opt-6.7b\": OPT_IA3_CONFIG,\n    \"facebook/opt-13b\": OPT_IA3_CONFIG,\n    \"facebook/opt-30b\": OPT_IA3_CONFIG,\n    \"facebook/opt-66b\": OPT_IA3_CONFIG,\n    \"EleutherAI/gpt-j-6B\": GPT_J_IA3_CONFIG,\n}\n\n\nclass WithIA3(nn.Module):\n    def __init__(self, ia3_param_names: str, unfuse_size: Optional[int] = None):\n        super().__init__()\n        self.ia3_param_names = ia3_param_names\n\n        # if (q,k,v) are stacked into one layer\n        if unfuse_size is not None:\n            # IA3 only operates on k and v (not q), thus the \"* 2\"\n            setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1)))\n        else:\n            setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1)))  # type: ignore\n\n    def scale_by_ia3(self, x):\n        ia3_params = getattr(self, self.ia3_param_names)\n\n        if ia3_params.requires_grad:\n            if self.unfuse_size is not None:\n                # non_q means k 
and v\n                q, non_q = x[:, :, : self.unfuse_size], x[:, :, self.unfuse_size :]  # type: ignore\n                ia3_params = getattr(self, self.ia3_param_names)\n                non_q = non_q * ia3_params.flatten()\n                x = torch.cat([q, non_q], dim=2)\n            else:\n                x = x * ia3_params.flatten()\n\n        return x\n\n\nclass LinearWithIA3(WithIA3):\n    def __init__(\n        self, linear_layer: nn.Linear, ia3_param_names: str, unfuse_size: Optional[int] = None\n    ):\n        \"\"\"\n        A replacement for :class:`~torch.nn.Linear` modified with an IA3 adaptor\n\n\n        :param linear_layer:\n            A :class:`~torch.nn.Linear` layer to adapt.\n        :param ia3_param_names:\n            A `str` to use as the name of ia3 parameters.\n        :param unfuse_size:\n            An `int` indicating hidden dimension of the query, key, and value vectors.\n            To be used only when the layer to modify is a fused projection of query,\n            key, and value vectors in an attention mechanism.\n        \"\"\"\n        assert unfuse_size is None or (linear_layer.out_features == unfuse_size * 3)\n        self.in_features = linear_layer.in_features\n        self.out_features = linear_layer.out_features\n        self.unfuse_size = unfuse_size\n\n        super().__init__(ia3_param_names, unfuse_size)\n\n        self.weight = linear_layer.weight\n        self.bias = linear_layer.bias\n\n    def forward(self, x):\n        x = F.linear(x, self.weight, self.bias)\n        return self.scale_by_ia3(x)\n\n\nclass Conv1DWithIA3(WithIA3):\n    def __init__(\n        self, conv1d_layer: Conv1D, ia3_param_names: str, unfuse_size: Optional[int] = None\n    ):\n        \"\"\"\n        A replacement for :class:`~transformers.modeling_utils.Conv1D` modified with an IA3 adaptor\n\n\n        :param conv1d_layer:\n            A :class:`~transformers.modeling_utils.Conv1D` layer to adapt.\n        :param ia3_param_names:\n          
  A `str` to use as the name of ia3 parameters.\n        :param unfuse_size:\n            An `int` indicating hidden dimension of the query, key, and value vectors.\n            To be used only when the layer to modify is a fused projection of query,\n            key, and value vectors in an attention mechanism.\n        \"\"\"\n        assert unfuse_size is None or (conv1d_layer.nf == unfuse_size * 3)\n\n        # nf: number of output features; nx: number of input features\n        self.out_features = conv1d_layer.nf\n        self.unfuse_size = unfuse_size\n\n        super().__init__(ia3_param_names, unfuse_size)\n\n        self.weight = conv1d_layer.weight\n        self.bias = conv1d_layer.bias\n\n    def forward(self, x):\n        # copied and pasted from the original Conv1D implementation\n        size_out = x.size()[:-1] + (self.out_features,)\n        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)\n        x = x.view(size_out)  # ... * self.nf\n\n        return self.scale_by_ia3(x)\n\n\ndef modify_with_ia3(\n    transformer: PreTrainedModel,\n    *,\n    config: Optional[WithIA3Config] = None,\n    only_ia3_requires_grad: bool = True,\n) -> PreTrainedModel:\n    \"\"\"\n    A function to add ia3 adaptors to the given transformer. Code modified from\n    `t-few <https://github.com/r-three/t-few/blob/217cfa3b73aa66a07594826e4ebbbc516b331461/src/models/lora.py>`_\n    and Qinyuan Ye\n\n\n    :param transformer:\n        A :class:`~transformers.PreTrainedModel` to modify.\n    :param config:\n        A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers to modify.\n    :param only_ia3_requires_grad:\n        A `bool`, `True` if `requires_grad` should only be set on ia3 parameters in the output model.\n\n    Examples\n    --------\n\n    You can use the provided configurations:\n\n    .. 
testcode::\n\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n        from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG\n\n        model = AutoModelForCausalLM.from_pretrained(\"sshleifer/tiny-gpt2\")\n        model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)\n\n    Or you can write your own configuration with regexes matching the layers to modify and their parents:\n\n    .. testcode::\n\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n        from tango.integrations.transformers.ia3 import WithIA3Config, modify_with_ia3\n\n        my_config = WithIA3Config(\n            attention_modules=\".*attn\",\n            fused_qkv_layers=\"c_attn\",\n            mlp_modules=\".*mlp\",\n            mlp_layers=\"c_fc\",\n            ia3_param_names=\"ia3\",\n        )\n\n        model = AutoModelForCausalLM.from_pretrained(\"sshleifer/tiny-gpt2\")\n        model = modify_with_ia3(model, config=my_config)\n    \"\"\"\n    if config is None:\n        model_name = transformer.config._name_or_path  # type: ignore\n        assert (\n            model_name in MODEL_NAME_TO_CONFIG\n        ), f\"{model_name} does not have a pre made configuration; please make your own.\"\n        config = MODEL_NAME_TO_CONFIG[model_name]\n\n    for m_name, module in dict(transformer.named_modules()).items():  # type: ignore\n        if re.fullmatch(config.attention_modules, m_name) or re.fullmatch(\n            config.mlp_modules, m_name\n        ):\n            attn_layers = [\n                regex\n                for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers)\n                if regex is not None\n            ]\n            layers_to_change = (\n                \"|\".join(attn_layers)\n                if re.fullmatch(config.attention_modules, m_name)\n                else config.mlp_layers\n            )\n            for c_name, layer in dict(module.named_children()).items():\n                if 
re.fullmatch(layers_to_change, c_name):\n                    assert isinstance(layer, Conv1D) or isinstance(\n                        layer, nn.Linear\n                    ), \"This code only supports Conv1D and nn.Linear\"\n                    adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3\n                    new_module = adaptor_class(\n                        layer,\n                        config.ia3_param_names,\n                        unfuse_size=transformer.config.hidden_size  # type: ignore\n                        if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name)\n                        else None,\n                    )\n                    setattr(module, c_name, new_module)\n\n    if only_ia3_requires_grad:\n        transformer.requires_grad_(False)  # type: ignore\n        for p_name, v in dict(transformer.named_parameters()).items():  # type: ignore\n            if re.fullmatch(\".*\" + config.ia3_param_names + \".*\", p_name):\n                v.requires_grad_(True)\n\n    return transformer\n"
  },
  {
    "path": "tango/integrations/transformers/model.py",
    "content": "from typing import Optional, Type\n\nfrom transformers.models.auto import modeling_auto\n\nfrom tango.common.exceptions import IntegrationMissingError\nfrom tango.integrations.torch.model import Model\n\nfrom .config import Config\n\n\ndef auto_model_wrapper_factory(cls: type) -> Type[Model]:\n    class AutoModelWrapper(cls, Model):  # type: ignore\n        @classmethod\n        def from_pretrained(\n            cls, pretrained_model_name_or_path: str, config: Optional[Config] = None, **kwargs\n        ) -> Model:\n            return super().from_pretrained(pretrained_model_name_or_path, config=config, **kwargs)\n\n        @classmethod\n        def from_config(cls, config: Config, **kwargs) -> Model:\n            return super().from_config(config, **kwargs)\n\n    return AutoModelWrapper\n\n\nfor name, cls in modeling_auto.__dict__.items():\n    if isinstance(cls, type) and name.startswith(\"AutoModel\"):\n        wrapped_cls = auto_model_wrapper_factory(cls)\n        Model.register(\n            \"transformers::\" + name + \"::from_pretrained\", constructor=\"from_pretrained\"\n        )(wrapped_cls)\n        Model.register(\"transformers::\" + name + \"::from_config\", constructor=\"from_config\")(\n            wrapped_cls\n        )\n\ntry:\n    from transformers.models.auto import modeling_flax_auto\n\n    from tango.integrations.flax.model import Model as FlaxModel\n\n    def flax_auto_model_wrapper_factory(cls: type) -> Type[FlaxModel]:\n        class AutoModelWrapper(cls, FlaxModel):  # type: ignore\n            @classmethod\n            def from_pretrained(\n                cls, pretrained_model_name_or_path: str, config: Optional[Config] = None, **kwargs\n            ) -> FlaxModel:\n                return super().from_pretrained(\n                    pretrained_model_name_or_path, config=config, **kwargs\n                )\n\n            @classmethod\n            def from_config(cls, config: Config, **kwargs) -> FlaxModel:\n                
return super().from_config(config, **kwargs)\n\n        return AutoModelWrapper\n\n    for name, cls in modeling_flax_auto.__dict__.items():\n        if isinstance(cls, type) and name.startswith(\"FlaxAutoModel\"):\n            wrapped_cls_ = flax_auto_model_wrapper_factory(cls)\n            FlaxModel.register(\n                \"transformers::\" + name + \"::from_pretrained\", constructor=\"from_pretrained\"\n            )(wrapped_cls_)\n            FlaxModel.register(\n                \"transformers::\" + name + \"::from_config\", constructor=\"from_config\"\n            )(wrapped_cls_)\n\nexcept ModuleNotFoundError:\n    pass\nexcept IntegrationMissingError:\n    pass\n"
  },
  {
    "path": "tango/integrations/transformers/optim.py",
    "content": "import torch\nfrom transformers import optimization as transformers_optim\n\nfrom tango.integrations.torch.optim import LRScheduler, Optimizer\n\n# Register all transformers optimizers.\nfor name, cls in transformers_optim.__dict__.items():\n    if (\n        isinstance(cls, type)\n        and issubclass(cls, torch.optim.Optimizer)\n        and not cls == torch.optim.Optimizer\n    ):\n        Optimizer.register(\"transformers::\" + name)(cls)\n\n\n# Register all transformers scheduler factory functions.\nfor scheduler_type, scheduler_func in transformers_optim.TYPE_TO_SCHEDULER_FUNCTION.items():\n    name = scheduler_type.value\n    LRScheduler.register(\"transformers::\" + name)(scheduler_func)  # type: ignore\n"
  },
  {
    "path": "tango/integrations/transformers/run_generation.py",
    "content": "import logging\nimport typing\nfrom typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union, cast\n\nimport more_itertools\nimport torch\nfrom datasets import Dataset\nfrom datasets import DatasetDict as HfDatasetDict\nfrom transformers import (\n    AutoModelForCausalLM,\n    AutoModelForSeq2SeqLM,\n    AutoTokenizer,\n    CTRLLMHeadModel,\n    CTRLTokenizer,\n    GPT2LMHeadModel,\n    GPT2Tokenizer,\n    OpenAIGPTLMHeadModel,\n    OpenAIGPTTokenizer,\n    PreTrainedModel,\n    PreTrainedTokenizer,\n    PreTrainedTokenizerFast,\n    TransfoXLLMHeadModel,\n    TransfoXLTokenizer,\n    XLMTokenizer,\n    XLMWithLMHeadModel,\n    XLNetLMHeadModel,\n    XLNetTokenizer,\n)\n\nfrom tango import Format, JsonFormat, SqliteDictFormat, Step\nfrom tango.common import DatasetDict\nfrom tango.common.sequences import MappedSequence, SqliteSparseSequence\nfrom tango.common.tqdm import Tqdm\nfrom tango.integrations.torch import Model\nfrom tango.integrations.torch.util import resolve_device, set_seed_all\n\nlogger = logging.getLogger(__name__)\n\n#\n# A lot of the code in this step is stolen from the run_generation.py script in transformers. 
Unfortunately their\n# examples don't ship when you `pip install transformers`, so we have to duplicate it here.\n#\n\nMAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop\n\nMODEL_CLASSES = {\n    \"gpt2\": (GPT2LMHeadModel, GPT2Tokenizer),\n    \"ctrl\": (CTRLLMHeadModel, CTRLTokenizer),\n    \"openai-gpt\": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),\n    \"xlnet\": (XLNetLMHeadModel, XLNetTokenizer),\n    \"transfo-xl\": (TransfoXLLMHeadModel, TransfoXLTokenizer),\n    \"xlm\": (XLMWithLMHeadModel, XLMTokenizer),\n}\n\n# Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia\n# in https://github.com/rusiaaman/XLNet-gen#methodology\n# and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e\nPREFIX = \"\"\"In 1991, the remains of Russian Tsar Nicholas II and his family\n(except for Alexei and Maria) are discovered.\nThe voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the\nremainder of the story. 1883 Western Siberia,\na young Grigori Rasputin is asked by his father and a group of men to perform magic.\nRasputin has a vision and denounces one of the men as a horse thief. Although his\nfather initially slaps him for making such an accusation, Rasputin watches as the\nman is chased outside and beaten. Twenty years later, Rasputin sees a vision of\nthe Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,\nwith people, even a bishop, begging for his blessing. 
<eod> </s> <eos>\"\"\"\n\nSEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys()  # type: ignore\nCAUSAL = AutoModelForCausalLM._model_mapping.keys()  # type: ignore\n\n\ndef adjust_length_to_model(length, model):\n    max_sequence_length = (\n        model.config.max_position_embeddings\n        if hasattr(model.config, \"max_position_embeddings\")\n        else MAX_LENGTH\n    )\n    if length < 0 and max_sequence_length > 0:\n        length = max_sequence_length\n    elif 0 < max_sequence_length < length:\n        length = max_sequence_length  # No generation bigger than model size\n    elif length < 0:\n        length = MAX_LENGTH  # avoid infinite loop\n    return length\n\n\n@typing.no_type_check  # mypy has somehow lost the ability to tell what PreTrainedTokenizer and Model are.\ndef _generate(\n    model: Model,\n    # TODO: Change type to `Tokenizer` once HF includes `convert_tokens_to_ids` in `PretrainedTokenizerBase` class.\n    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],\n    prompts: Iterable[str],\n    *,\n    batch_size: int = 4,\n    max_length: int = 20,\n    temperature: float = 1.0,\n    repetition_penalty: float = 1.0,\n    k: int = 0,\n    p: float = 0.9,\n    prefix: str = \"\",\n    xlm_language: str = \"\",\n    seed: int = 42,\n    num_return_sequences: int = 1,\n    fp16: bool = False,\n) -> Iterable[List[str]]:\n    if not isinstance(model.config, tuple(SEQ2SEQ + CAUSAL)):\n        raise NotImplementedError(\n            \"This function is only defined for huggingface models seq2seq/causal models.\"\n        )\n\n    device = resolve_device()\n    set_seed_all(seed)\n\n    tokenizer_kwargs: Dict[str, Any] = {}\n    tokenizer.padding_side = \"left\"\n\n    if tokenizer.pad_token is None:\n        tokenizer.pad_token = tokenizer.eos_token\n    if tokenizer.pad_token is None:\n        tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n\n    if tokenizer.eos_token is None:\n        
tokenizer.add_special_tokens({\"eos_token\": \"[EOS]\"})\n\n    eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)\n    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)\n\n    # Seq2Seq models don't return their own prefix.\n    seq2seq_model = model.config_class in SEQ2SEQ\n\n    # HF does not do this? WTF?\n    model.eval()\n\n    model.to(device)\n    if fp16:\n        model.half()\n\n    def prepare_batch_without_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:\n        result = tokenizer.batch_encode_plus(\n            prompts,\n            add_special_tokens=False,\n            return_tensors=\"pt\",\n            padding=True,\n            **tokenizer_kwargs,\n        )\n        result = {key: tensor.to(device) for key, tensor in result.items()}\n        return result\n\n    def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]:\n        if len(prefix) > 0:\n            prompts = [f\"{prefix} {t}\" for t in prompts]\n        return prepare_batch_without_prefix(prompts)\n\n    prepare_batch_fn = prepare_batch_with_prefix\n    num_prefix_tokens: Optional[int] = None\n\n    # transformer model-specific exceptions\n    if isinstance(model, PreTrainedModel) and model.config_class:\n        if model.config_class.model_type == \"xlm\":\n            use_lang_emb = hasattr(model.config, \"use_lang_emb\") and model.config.use_lang_emb\n            if hasattr(model.config, \"lang2id\") and use_lang_emb:\n                model.config.lang_id = xlm_language\n            # Original HF code ignores the prefix, but it looks like a bug?\n            prepare_batch_fn = prepare_batch_without_prefix\n            num_prefix_tokens = 0\n        elif model.config_class.model_type in {\"xlnet\", \"transfo-xl\"}:\n            prefix = prefix if prefix else PREFIX\n        if model.__class__.__name__ in [\"TransfoXLLMHeadModel\"]:\n            # This actually doesn't work in the current version of transformers, which is 
probably a bug in the\n            # transformers library.\n            tokenizer_kwargs = {\"add_space_before_punct_symbol\": True}\n\n    if num_prefix_tokens is None:\n        num_prefix_tokens = len(tokenizer.tokenize(prefix))\n\n    batches = more_itertools.chunked(Tqdm.tqdm(prompts, desc=\"Pre-processing prompts\"), batch_size)\n    encoded_batches = map(prepare_batch_fn, batches)\n\n    for encoded_batch in Tqdm.tqdm(encoded_batches, desc=\"Processing batches\"):\n        if seq2seq_model:\n            length = max_length\n        else:\n            length = adjust_length_to_model(max_length + encoded_batch[\"input_ids\"].size(1), model)\n        with torch.inference_mode():\n            generated_sequences: torch.Tensor = model.generate(  # type: ignore\n                **encoded_batch,\n                max_length=length,\n                temperature=temperature,\n                top_k=k,\n                top_p=p,\n                repetition_penalty=repetition_penalty,\n                do_sample=True,\n                num_return_sequences=num_return_sequences,\n                synced_gpus=False,  # Needs to be True if we have more than one GPU running.\n            )\n\n        generated_sequences = generated_sequences.view(\n            -1, num_return_sequences, *generated_sequences.shape[1:]\n        ).to(\"cpu\")\n\n        def strip_special_tokens(t: torch.Tensor) -> torch.Tensor:\n            # amazing that torch has no capability for this\n            start = 0\n            while start < len(t) and int(t[start]) in {0, eos_token_id, pad_token_id}:\n                start += 1\n            end = len(t)\n            while int(t[end - 1]) in {0, eos_token_id, pad_token_id} and end > start:\n                end -= 1\n            return t[start:end]\n\n        # strip padding\n        generated_sequences_list = [\n            [strip_special_tokens(sequence) for sequence in per_prompt_sequences]\n            for per_prompt_sequences in generated_sequences\n  
      ]\n\n        # strip prefix\n        if not seq2seq_model:\n            generated_sequences_list = [\n                [sequence[num_prefix_tokens:] for sequence in per_prompt_sequences]\n                for per_prompt_sequences in generated_sequences_list\n            ]\n\n        texts = [\n            tokenizer.batch_decode(per_prompt_sequences, clean_up_tokenization_spaces=True)\n            for per_prompt_sequences in generated_sequences_list\n        ]\n\n        yield from texts\n\n\ndef _generate_with_model_name(model_name: str, *args, **kwargs) -> Iterable[List[str]]:\n    try:\n        model = AutoModelForSeq2SeqLM.from_pretrained(model_name)\n    except ValueError:\n        model = AutoModelForCausalLM.from_pretrained(model_name)\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n    return _generate(model, tokenizer, *args, **kwargs)\n\n\n@Step.register(\"transformers::run_generation\")\nclass RunGeneration(Step[Iterable[List[str]]]):\n    \"\"\"\n    A step that runs seq2seq Huggingface models in inference mode.\n\n    .. 
tip::\n        Registered as a :class:`~tango.step.Step` under the name \"transformers::run_generation\".\n    \"\"\"\n\n    FORMAT: Format = JsonFormat(\"gz\")\n    VERSION = \"001\"\n    SKIP_ID_ARGUMENTS = {\"batch_size\"}\n\n    # TODO: multiple GPUs\n\n    def run(  # type: ignore\n        self,\n        model: Union[str, Model],\n        prompts: Iterable[str],\n        *,\n        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,\n        batch_size: int = 4,\n        max_length: int = 20,\n        temperature: float = 1.0,\n        repetition_penalty: float = 1.0,\n        k: int = 0,\n        p: float = 0.9,\n        prefix: str = \"\",\n        xlm_language: str = \"\",\n        seed: int = 42,\n        num_return_sequences: int = 1,\n        fp16: bool = False,\n    ) -> Iterable[List[str]]:\n        \"\"\"\n        Run a Huggingface seq2seq model in inference mode.\n\n        :param model:\n            The name of the model to run. Any name that works in the transformers library works here.\n            Or, you can directly provide the model to run.\n        :param prompts:\n            The prompts to run through the model. You can specify prompts directly in the config, but\n            more commonly the prompts are produced by another step that reads a dataset, for example.\n        :param tokenizer:\n            The tokenizer to run.\n        :param batch_size:\n            The number of sequences to process at one time. This has no bearing on the output, so\n            you can change this number without invalidating cached results.\n        :param max_length:\n            The maximum number of tokens/word pieces that the model will generate. 
For models that\n            extend the prompt, the prefix does not count towards this limit.\n        :param temperature:\n            Passed directly to transformer's ``generate()`` method.\n            The value used to model the next token probabilities.\n        :param repetition_penalty:\n            Passed directly to transformer's ``generate()`` method.\n            The parameter for repetition penalty. 1.0 means no penalty.\n        :param k:\n            Passed directly to transformer's ``generate()`` method.\n            The number of highest probability vocabulary tokens to keep for top-k-filtering.\n        :param p:\n            Passed directly to transformer's ``generate()`` method.\n            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher\n            are kept for generation.\n        :param prefix:\n            A prefix that gets prepended to all prompts.\n        :param xlm_language:\n            For the XLM model, this is a way to specify the language you want to use.\n        :param seed:\n            Random seed\n        :param num_return_sequences:\n            The number of generations to return for each prompt.\n        :param fp16:\n            Whether to use 16-bit floats.\n\n        :returns:\n            Returns an iterator of lists of strings. 
Each list contains the predictions for one prompt.\n        \"\"\"\n        if isinstance(model, str):\n            try:\n                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))\n            except ValueError:\n                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))\n\n        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)\n\n        return _generate(\n            model,\n            tokenizer,\n            prompts,\n            batch_size=batch_size,\n            max_length=max_length,\n            temperature=temperature,\n            repetition_penalty=repetition_penalty,\n            k=k,\n            p=p,\n            prefix=prefix,\n            xlm_language=xlm_language,\n            seed=seed,\n            num_return_sequences=num_return_sequences,\n            fp16=fp16,\n        )\n\n\n@Step.register(\"transformers::run_generation_dataset\")\nclass RunGenerationDataset(Step[DatasetDict]):\n    \"\"\"\n    A step that runs seq2seq Huggingface models in inference mode.\n\n    This is similar to :class:`RunGeneration`, but it takes a dataset as input and produces\n    a new dataset as output, which contains the predictions in a new field.\n\n    .. 
tip::\n        Registered as a :class:`~tango.step.Step` under the name \"transformers::run_generation_dataset\".\n    \"\"\"\n\n    FORMAT: Format = SqliteDictFormat()\n    VERSION = \"002\"\n    SKIP_ID_ARGUMENTS = {\"batch_size\"}\n\n    def run(  # type: ignore\n        self,\n        model: Union[str, Model],\n        input: Union[DatasetDict, HfDatasetDict],\n        prompt_field: str,\n        *,\n        tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None,\n        output_field: Optional[str] = None,\n        splits: Optional[Union[str, Set[str]]] = None,\n        batch_size: int = 4,\n        max_length: int = 20,\n        temperature: float = 1.0,\n        repetition_penalty: float = 1.0,\n        k: int = 0,\n        p: float = 0.9,\n        prefix: str = \"\",\n        xlm_language: str = \"\",\n        seed: int = 42,\n        num_return_sequences: int = 1,\n        fp16: bool = False,\n    ) -> DatasetDict:\n        \"\"\"\n        Augment an input dataset with generations from a Huggingface seq2seq model.\n\n        :param model:\n            The name of the model to run. Any name that works in the transformers library works here.\n            Or, you can directly provide the model to run.\n        :param input:\n            The input dataset.\n        :param prompt_field:\n            The field in the dataset that contains the text of the prompts.\n        :param tokenizer:\n            The tokenizer to use. If this is not specified, the tokenizer is loaded from the model's\n            ``name_or_path``.\n        :param output_field:\n            The field in the dataset that we will write the predictions into. In the result, this field\n            will contain ``List[str]``. If this is not specified, the field is named after ``prompt_field``\n            with ``_generated`` appended.\n        :param splits:\n            A split, or set of splits, to process. If this is not specified, we will process all splits.\n        :param batch_size:\n            The number of sequences to process at one time. 
This has no bearing on the output, so\n            you can change this number without invalidating cached results.\n        :param max_length:\n            The maximum number of tokens/word pieces that the model will generate. For models that\n            extend the prompt, the prefix does not count towards this limit.\n        :param temperature:\n            Passed directly to transformer's `generate()` method.\n            The value used to model the next token probabilities.\n        :param repetition_penalty:\n            Passed directly to transformer's `generate()` method.\n            The parameter for repetition penalty. 1.0 means no penalty.\n        :param k:\n            Passed directly to transformer's `generate()` method.\n            The number of highest probability vocabulary tokens to keep for top-k-filtering.\n        :param p:\n            Passed directly to transformer's `generate()` method.\n            If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher\n            are kept for generation.\n        :param prefix:\n            A prefix that gets pre-pended to all prompts.\n        :param xlm_language:\n            For the XLM model, this is a way to specify the language you want to use.\n        :param seed:\n            Random seed\n        :param num_return_sequences:\n            The number of generations to return for each prompt.\n        :param fp16:\n            Whether to use 16-bit floats.\n\n        :returns:\n            Returns a dataset with an extra field containing the predictions.\n        \"\"\"\n\n        if isinstance(model, str):\n            try:\n                model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model))\n            except ValueError:\n                model = cast(Model, AutoModelForCausalLM.from_pretrained(model))\n\n        tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path)\n\n        if isinstance(input, HfDatasetDict):\n  
          input = DatasetDict(input, {})\n        if splits is None:\n            splits = input.keys()\n        elif isinstance(splits, str):\n            splits = {splits}\n\n        result: Dict[str, Sequence] = {}\n        for split_name, input_split in input.items():\n            if split_name in splits:\n                output_split = SqliteSparseSequence(self.work_dir / f\"{split_name}.sqlite\")\n                if len(output_split) > 0:\n                    logger.info(\n                        \"Found %d items already generated. Will generate %d more.\",\n                        len(output_split),\n                        len(input_split) - len(output_split),\n                    )\n                if len(output_split) > 0:\n                    if isinstance(input_split, Dataset):\n                        input_split = input_split.select(range(len(output_split), len(input_split)))\n                    else:\n                        input_split = input_split[len(output_split) :]\n                prompts = MappedSequence(lambda i: i[prompt_field], input_split)\n                generations = _generate(\n                    model,\n                    tokenizer,\n                    prompts,\n                    batch_size=batch_size,\n                    max_length=max_length,\n                    temperature=temperature,\n                    repetition_penalty=repetition_penalty,\n                    k=k,\n                    p=p,\n                    prefix=prefix,\n                    xlm_language=xlm_language,\n                    seed=seed,\n                    num_return_sequences=num_return_sequences,\n                    fp16=fp16,\n                )\n                for instance, generation in zip(input_split, generations):\n                    output_split.append(\n                        {**instance, **{output_field or prompt_field + \"_generated\": generation}}\n                    )\n                result[split_name] = output_split\n            
else:\n                result[split_name] = input_split\n\n        return DatasetDict(result, input.metadata)\n"
  },
  {
    "path": "tango/integrations/transformers/soft_prompt.py",
    "content": "import inspect\nimport logging\nimport random\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom torch import nn\nfrom transformers import PreTrainedModel\nfrom transformers.modeling_outputs import (\n    CausalLMOutputWithCrossAttentions,\n    Seq2SeqModelOutput,\n)\n\nfrom tango.integrations.torch import Model\n\nlogger = logging.getLogger(__name__)\n\n\ndef _get_bound_args_with_decorators(fn, *args, **kwargs):\n    while True:\n        try:\n            fn = fn.__wrapped__\n        except AttributeError:\n            break\n    signature = inspect.Signature.from_callable(fn)\n    return signature.bind(*args, **kwargs)\n\n\ndef add_soft_prompt(\n    model: Model,\n    prompt_length: int,\n    *,\n    only_prompt_is_trainable: bool = True,\n    initialize_from_top_embeddings: Optional[int] = 5000,\n    random_seed: int = 1940,\n) -> None:\n    \"\"\"\n    Takes a regular huggingface transformer, and equips it with a soft prompt.\n\n    Example:\n\n    .. testcode::\n        import transformers\n\n        model = transformers.AutoModelForCausalLM.from_pretrained(\"gpt2\")\n        tokenizer = transformers.AutoTokenizer.from_pretrained(\"gpt2\")\n        generated = model.generate(tokenizer.encode(\"It was the best of times.\", return_tensors=\"pt\"))\n        original_output = tokenizer.decode(generated[0])\n\n        add_soft_prompt(model, prompt_length=3)\n        generated = model.generate(tokenizer.encode(\"It was the best of times.\", return_tensors=\"pt\"))\n        prompted_output = tokenizer.decode(generated[0])\n\n    :param model: the original huggingface transformer. 
This model is augmented in-place!\n    :param prompt_length: the length of the soft prompt, in tokens\n    :param only_prompt_is_trainable: freezes the original model's weights, leaving only the prompt trainable\n    :param initialize_from_top_embeddings: Prompt embeddings are initialized from a random selection of the top n\n                                           word piece embeddings from the original model. This is how you set n.\n    :param random_seed: random seed used to initialize the prompt embeddings\n\n    \"\"\"\n    assert isinstance(model, PreTrainedModel)\n\n    original_embedding: nn.Embedding = model.get_input_embeddings()  # type: ignore\n    prompt_embedding = nn.Parameter(\n        torch.empty(\n            1,\n            prompt_length,\n            original_embedding.embedding_dim,\n            dtype=original_embedding.weight.dtype,\n            device=original_embedding.weight.device,\n        )\n    )\n    r = random.Random(random_seed)\n    if initialize_from_top_embeddings is None:\n        initialize_from_top_embeddings = original_embedding.num_embeddings\n    indices = torch.tensor(r.sample(range(initialize_from_top_embeddings), prompt_length))\n    with torch.no_grad():\n        prompt_embedding.copy_(original_embedding(indices).unsqueeze(0))\n\n    if only_prompt_is_trainable:\n        for param in model.parameters():\n            param.requires_grad = False\n\n    # find unique parameter name\n    parameter_name = \"prompt_embedding\"\n    parameter_name_index = 0\n    while True:\n        try:\n            model.get_parameter(parameter_name)\n        except AttributeError:\n            break\n        parameter_name_index += 1\n        parameter_name = f\"prompt_embedding_{parameter_name_index}\"\n    model.register_parameter(parameter_name, prompt_embedding)\n\n    def patch_tensor(kwargs: Dict[str, torch.Tensor], key: str, value: Any = 0) -> None:\n        t = kwargs.get(key)\n        if t is None:\n            return\n        
prefix = t.new_full((t.size(0), prompt_length) + t.shape[2:], value)\n        kwargs[key] = torch.cat([prefix, t], dim=1)\n\n    def patch_tensor_with_indices(\n        kwargs: Dict[str, torch.Tensor], key: str, offset: int = 0\n    ) -> None:\n        t = kwargs.get(key)\n        if t is None:\n            return\n        kwargs[key] = torch.cat(\n            [\n                # Create the prompt indices on the same device as `t` so that `torch.cat()` works on GPU.\n                torch.arange(0, prompt_length, dtype=t.dtype, device=t.device)\n                .unsqueeze(0)\n                .expand(t.size(0), prompt_length),\n                t + offset,\n            ],\n            dim=1,\n        )\n\n    old_forward = model.forward\n\n    def new_forward(*args, **kwargs):\n        # Massage the input to include the prompt\n        if kwargs.get(\"past_key_values\") is not None:\n            # If we have already been running this model, we don't need to do anything with the prefix now.\n            return old_forward(*args, **kwargs)\n        if kwargs.get(\"encoder_outputs\") is not None:\n            # For encoder/decoder models, this runs only on the encoder. 
If we already have encoder outputs,\n            # we don't have to do anything.\n            return old_forward(*args, **kwargs)\n\n        inputs_embeds: Optional[torch.Tensor] = None\n        input_ids = kwargs.pop(\"input_ids\", None)\n        if input_ids is not None:\n            inputs_embeds = original_embedding(input_ids)\n\n        inputs_embeds = kwargs.get(\"inputs_embeds\", inputs_embeds)\n        if inputs_embeds is not None:\n            kwargs[\"inputs_embeds\"] = torch.cat(\n                [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1\n            )\n\n        patch_tensor(kwargs, \"labels\")\n        patch_tensor(kwargs, \"attention_mask\", 1)\n        patch_tensor(kwargs, \"token_type_ids\")\n        patch_tensor_with_indices(kwargs, \"position_ids\", prompt_length)\n\n        # Run the model\n        result = old_forward(*args, **kwargs)\n\n        # Massage the output to look like the prompt was never there\n        unpatch_tensor = lambda t: t[:, prompt_length:]  # noqa: E731\n        unpatch_attention_tensor = lambda t: t[:, :, prompt_length:]  # noqa: E731\n        unpatch_kv_tensor = unpatch_attention_tensor\n        if isinstance(result, CausalLMOutputWithCrossAttentions):\n            if result.logits is not None:\n                result.logits = unpatch_tensor(result.logits)\n            if result.hidden_states is not None:\n                result.hidden_states = tuple(map(unpatch_tensor, result.hidden_states))\n            if result.attentions is not None:\n                result.attentions = tuple(map(unpatch_attention_tensor, result.attentions))\n            if result.cross_attentions is not None:\n                result.cross_attentions = tuple(\n                    map(unpatch_attention_tensor, result.cross_attentions)\n                )\n            return result\n        elif isinstance(result, Seq2SeqModelOutput):\n            if result.last_hidden_state is not None:\n                
result.last_hidden_state = unpatch_tensor(result.last_hidden_state)\n            if result.past_key_values is not None:\n                result.past_key_values = tuple(map(unpatch_kv_tensor, result.past_key_values))\n            if result.encoder_hidden_states is not None:\n                result.encoder_hidden_states = tuple(\n                    map(unpatch_tensor, result.encoder_hidden_states)\n                )\n            if result.encoder_attentions is not None:\n                result.encoder_attentions = tuple(\n                    map(unpatch_attention_tensor, result.encoder_attentions)\n                )\n            if result.cross_attentions is not None:\n                result.cross_attentions = tuple(\n                    map(unpatch_attention_tensor, result.cross_attentions)\n                )\n            return result\n        else:\n            logger.warning(\n                \"Unexpected result type from the transformer in soft_prompt_transformer: `%s`\",\n                result.__class__,\n            )\n            return result\n\n    model.forward = new_forward  # type: ignore\n\n    # For encoder/decoder models, HF doesn't call `forward()` like it should when you use `generate()`. Instead, it\n    # calls the encoder separately, and then passes the results into `forward()`. So in that case, we have to patch\n    # this too.\n    if model.config.is_encoder_decoder:\n        old_generate = model.generate\n\n        def new_generate(*args, **kwargs):\n            args = (model,) + args\n            ba = _get_bound_args_with_decorators(old_generate, *args, **kwargs)\n            del ba.arguments[\"self\"]\n\n            if \"encoder_outputs\" in ba.arguments:\n                # For encoder/decoder models, this runs only on the encoder. 
If we already have encoder outputs,\n                # we don't have to do anything.\n                return old_generate(*ba.args, **ba.kwargs)\n\n            inputs_embeds: Optional[torch.Tensor] = None\n            inputs = ba.arguments.pop(\"inputs\", None)\n            if inputs is not None:\n                inputs_embeds = original_embedding(inputs)\n\n            inputs_embeds = ba.arguments.pop(\"inputs_embeds\", inputs_embeds)\n            if inputs_embeds is not None:\n                inputs_embeds = torch.cat(\n                    [prompt_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds], dim=1\n                )\n\n            assert callable(model.get_encoder)\n            encoder = model.get_encoder()\n            kwargs = ba.kwargs\n            kwargs[\"encoder_outputs\"] = encoder(inputs_embeds=inputs_embeds, return_dict=True)\n\n            return old_generate(*ba.args, **kwargs)\n\n        model.generate = new_generate  # type: ignore\n\n\ndef _with_soft_prompt(\n    model: Model,\n    prompt_length: int,\n    *,\n    only_prompt_is_trainable: bool = True,\n    initialize_from_top_embeddings: Optional[int] = 5000,\n    random_seed: int = 1940,\n) -> Model:\n    \"\"\"To initialize a soft-prompt model as a Registrable (i.e., to use it from a config file), we need a variant\n    of this function that returns the resulting model. This is that variant.\"\"\"\n    add_soft_prompt(\n        model,\n        prompt_length,\n        only_prompt_is_trainable=only_prompt_is_trainable,\n        initialize_from_top_embeddings=initialize_from_top_embeddings,\n        random_seed=random_seed,\n    )\n    return model\n\n\nModel.register(\"transformers::with_soft_prompt\")(_with_soft_prompt)  # type: ignore\n"
  },
  {
    "path": "tango/integrations/transformers/tokenizer.py",
    "content": "from transformers import AutoTokenizer\nfrom transformers.tokenization_utils_base import PreTrainedTokenizerBase\n\nfrom tango.common import Registrable\n\n\nclass Tokenizer(PreTrainedTokenizerBase, Registrable):\n    \"\"\"\n    A :class:`~tango.common.Registrable` version of transformers'\n    :class:`~transformers.PreTrainedTokenizerBase`.\n    \"\"\"\n\n    default_implementation = \"auto\"\n    \"\"\"\n    The default registered implementation just calls\n    :meth:`transformers.AutoTokenizer.from_pretrained()`.\n    \"\"\"\n\n\nTokenizer.register(\"auto\", constructor=\"from_pretrained\")(AutoTokenizer)\n"
  },
  {
    "path": "tango/integrations/wandb/__init__.py",
    "content": "\"\"\"\n.. important::\n    To use this integration you should install ``tango`` with the \"wandb\" extra\n    (e.g. ``pip install tango[wandb]``) or just install the ``wandb`` library after the fact\n    (e.g. ``pip install wandb``).\n\nComponents for Tango integration with `Weights & Biases <https://wandb.ai/>`_.\n\nOverview\n--------\n\nThe main components provided by this integration are the :class:`WandbWorkspace` and\nthe :class:`WandbTrainCallback`.\n\nThe :class:`WandbWorkspace` is a :class:`~tango.workspace.Workspace` implementation that is\ngreat for collaboration. It tracks Tango runs and steps in the W&B project of your choosing\nand uses W&B Artifacts to cache step results in the cloud so that they're accessible anywhere.\n\nAnd if you're training PyTorch models via the :class:`~tango.integrations.torch.TorchTrainStep`,\nyou can use the :class:`WandbTrainCallback` to track metrics throughout the run.\n\n\"\"\"\n\nfrom tango.common.exceptions import IntegrationMissingError\n\ntry:\n    import wandb\nexcept ModuleNotFoundError:\n    raise IntegrationMissingError(\"wandb\")\n\n__all__ = [\"WandbWorkspace\", \"WandbStepCache\"]\n\nfrom .step_cache import WandbStepCache\nfrom .workspace import WandbWorkspace\n\ntry:\n    import torch\nexcept ModuleNotFoundError:\n    pass\nelse:\n    from .torch_train_callback import WandbTrainCallback\n\n    __all__.append(\"WandbTrainCallback\")\n\ntry:\n    import flax\n    import jax\n    import tensorflow  # flax has a tensorflow dependency\nexcept ModuleNotFoundError:\n    pass\nelse:\n    from .flax_train_callback import WandbFlaxTrainCallback\n\n    __all__.append(\"WandbFlaxTrainCallback\")\n"
  },
  {
    "path": "tango/integrations/wandb/flax_train_callback.py",
    "content": "from typing import Any, Dict, List, Optional\n\nimport jax\nimport wandb\nfrom flax import jax_utils\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.integrations.flax.train_callback import TrainCallback\n\nfrom .workspace import WandbWorkspace\n\n\n@TrainCallback.register(\"wandb::log_flax\")\nclass WandbFlaxTrainCallback(TrainCallback):\n    \"\"\"\n    A flax :class:`~tango.integrations.flax.TrainCallback` for use with\n    the :class:`~tango.integrations.flax.FlaxTrainStep` that logs training and\n    validation metrics to W&B.\n\n    This can be used with any :class:`~tango.workspace.Workspace` implementation,\n    including :class:`WandbWorkspace`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.integrations.flax.TrainCallback`\n        under the name \"wandb::log_flax\".\n\n    .. important::\n\n        When this callback is used with the :class:`WandbWorkspace` it will log metrics\n        to the same W&B project that the workspace uses. The ``group`` and ``name``\n        parameters will also automatically be set, so a :class:`~tango.common.exceptions.ConfigurationError`\n        will be raised if any of ``project``, ``entity``, ``group``, or ``name`` are set in this callback.\n\n    :param project:\n            W&B project to associated this run with.\n\n    :param entity:\n        W&B entity (user or organization) to associated this run with.\n\n    :param group:\n        W&B group to associated this run with.\n\n    :param name:\n        Set the name of the run in W&B. 
If not set, the default will be the name of the step.\n\n    :param notes:\n        Arbitrary notes to add in W&B to this run.\n\n    :param tags:\n        Arbitrary tags to add in W&B to this run.\n\n    :param watch_model:\n        If ``True``, ``wandb.watch()`` is called to collect gradients and other information\n        about the model throughout training.\n        See `docs.wandb.ai/ref/python/watch <https://docs.wandb.ai/ref/python/watch>`_.\n\n    :param wandb_config:\n        Arbitrary configuration fields to set in W&B for this run.\n        See `docs.wandb.ai/guides/track/config <https://docs.wandb.ai/guides/track/config>`_.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        project: Optional[str] = None,\n        entity: Optional[str] = None,\n        group: Optional[str] = None,\n        name: Optional[str] = None,\n        notes: Optional[str] = None,\n        tags: Optional[List[str]] = None,\n        watch_model: bool = False,\n        wandb_config: Optional[Dict[str, Any]] = None,\n        **kwargs,\n    ) -> None:\n        super().__init__(*args, **kwargs)\n\n        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:\n            err_msg_template = \"Cannot set '{var_name}' in WandbTrainCallback \"\n            if isinstance(self.workspace, WandbWorkspace):\n                err_msg_template += \"since it has already been set from the WandbWorkspace.\"\n            else:\n                err_msg_template += \"since a W&B run has already been initialized.\"\n            for var, var_name in [\n                (project, \"project\"),\n                (entity, \"entity\"),\n                (group, \"group\"),\n                (name, \"name\"),\n            ]:\n                if var is not None:\n                    raise ConfigurationError(err_msg_template.format(var_name=var_name))\n\n        self.project = (\n            project if not isinstance(self.workspace, WandbWorkspace) else 
self.workspace.project\n        )\n        self.entity = (\n            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity\n        )\n        self.group = group or self.step_id\n        self.notes = notes\n        self.tags = tags\n        self.watch_model = watch_model\n        self.wandb_config = self.train_config.as_dict()\n\n        if wandb_config is not None:\n            self.wandb_config.update(wandb_config)\n        if wandb.run is None:\n            self.wandb_config[\"job_type\"] = \"train_metrics\"\n\n        self.run_name: str = name or self.step_name or \"train\"\n\n        self.run_id: str = (\n            wandb.run.id if wandb.run is not None else self.step_id  # type: ignore[attr-defined]\n        )\n        self.resume: Optional[str] = None\n        self.should_finalize_run: bool = (\n            wandb.run is None\n        )  # if we have to start out own W&B run, we need to finish it\n\n    def state_dict(self) -> Dict[str, Any]:\n        return {}\n\n    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:\n        self.resume = \"allow\"\n\n    def pre_train_loop(self) -> None:\n        if wandb.run is None:\n            if self.run_id is None:\n                self.run_id = self.step_id\n\n            wandb.init(\n                id=self.run_id,\n                dir=str(self.work_dir),\n                project=self.project,\n                entity=self.entity,\n                group=self.group,\n                name=self.run_name,\n                notes=self.notes,\n                config=self.wandb_config,\n                tags=self.tags,\n                job_type=\"train_metrics\",\n            )\n        else:\n            # We are already running inside of a W&B run, possibly because\n            # we're using the WandbWorkspace.\n            wandb.config.update(self.wandb_config)\n            if self.tags:\n                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)\n         
   if self.notes:\n                wandb.run.notes = self.notes\n\n        if self.watch_model:\n            wandb.watch(self.model)\n\n    def post_train_loop(self, step: int, epoch: int) -> None:\n        if self.should_finalize_run:\n            wandb.finish()\n\n    def log_batch(self, step: int, epoch: int, train_metrics: Dict) -> None:\n        if len(jax.devices()) > 1:\n            train_metrics = jax_utils.unreplicate(train_metrics)\n        metrics = {\"train/loss\": train_metrics[\"loss\"], \"epoch\": epoch}\n        wandb.log(metrics, step=step + 1)\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: Optional[float], best_val_metric: Optional[float]\n    ) -> None:\n        wandb.log(\n            {\n                f\"val/{self.train_config.val_metric_name}\": val_metric,\n                f\"val/best_{self.train_config.val_metric_name}\": best_val_metric,\n                \"epoch\": epoch,\n            },\n            step=step + 1,\n        )\n"
  },
  {
    "path": "tango/integrations/wandb/step_cache.py",
    "content": "import logging\nfrom typing import Any, Optional, Union\n\nimport wandb\nfrom retry import retry\nfrom wandb.errors import Error as WandbError\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.util import make_safe_filename, tango_cache_dir\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_caches.remote_step_cache import RemoteNotFoundError, RemoteStepCache\nfrom tango.step_info import StepInfo\n\nfrom .util import ArtifactKind, check_environment, is_missing_artifact_error\n\nlogger = logging.getLogger(__name__)\n\n\n@StepCache.register(\"wandb\")\nclass WandbStepCache(RemoteStepCache):\n    \"\"\"\n    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`WandbWorkspace`.\n    It stores the results of steps on W&B as Artifacts.\n\n    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a\n    step's resulting subsequent times should be fast.\n\n    :param project: The W&B project to use.\n    :param entity: The W&B entity (user or organization account) to use.\n\n    .. 
tip::\n        Registered as :class:`~tango.step_cache.StepCache` under the name \"wandb\".\n    \"\"\"\n\n    def __init__(self, project: str, entity: str):\n        check_environment()\n        super().__init__(\n            tango_cache_dir()\n            / \"wandb_cache\"\n            / make_safe_filename(entity)\n            / make_safe_filename(project)\n        )\n        self.project = project\n        self.entity = entity\n\n    @property\n    def wandb_client(self) -> wandb.Api:\n        return wandb.Api(overrides={\"entity\": self.entity, \"project\": self.project})\n\n    @property\n    def client(self):\n        \"\"\"\n        To maintain compatibility\n        \"\"\"\n        return self.wandb_client\n\n    @property\n    def wandb_project_url(self) -> str:\n        \"\"\"\n        The URL of the W&B project this workspace uses.\n        \"\"\"\n        app_url = self.wandb_client.client.app_url\n        app_url = app_url.rstrip(\"/\")\n        return f\"{app_url}/{self.entity}/{self.project}\"\n\n    def _step_artifact_name(self, step: Union[Step, StepInfo]) -> str:\n        if isinstance(step, Step):\n            return step.class_name\n        else:\n            return step.step_class_name\n\n    def _step_result_remote(  # type: ignore\n        self, step: Union[Step, StepInfo]\n    ) -> Optional[wandb.Artifact]:\n        artifact_kind = (step.metadata or {}).get(\"artifact_kind\", ArtifactKind.STEP_RESULT.value)\n        try:\n            return self.wandb_client.artifact(\n                f\"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}\",\n                type=artifact_kind,\n            )\n        except WandbError as exc:\n            if is_missing_artifact_error(exc):\n                return None\n            else:\n                raise\n\n    def create_step_result_artifact(self, step: Step, objects_dir: Optional[PathOrStr] = None):\n        self._upload_step_remote(step, objects_dir)\n\n    def 
get_step_result_artifact(self, step: Union[Step, StepInfo]) -> Optional[wandb.Artifact]:\n        artifact_kind = (step.metadata or {}).get(\"artifact_kind\", ArtifactKind.STEP_RESULT.value)\n        try:\n            return self.wandb_client.artifact(\n                f\"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}\",\n                type=artifact_kind,\n            )\n        except WandbError as exc:\n            if is_missing_artifact_error(exc):\n                return None\n            else:\n                raise\n\n    def _upload_step_remote(self, step: Step, objects_dir: Optional[PathOrStr] = None) -> Any:\n        \"\"\"\n        Create an artifact for the result of a step.\n        \"\"\"\n        artifact_kind = (step.metadata or {}).get(\"artifact_kind\", ArtifactKind.STEP_RESULT.value)\n        artifact = wandb.Artifact(self._step_artifact_name(step), type=artifact_kind)\n\n        # Add files\n        if objects_dir is not None:\n            artifact.add_dir(str(objects_dir))\n\n        # Log/persist the artifact to W&B.\n        artifact.save()\n        artifact.wait()\n\n        # Add an alias for the step's unique ID.\n        # Only after we've logged the artifact can we add an alias.\n        artifact.aliases.append(step.unique_id)\n        artifact.save()\n        artifact.wait()\n\n    def get_step_result_artifact_url(self, step: Union[Step, StepInfo]) -> str:\n        artifact_kind = (step.metadata or {}).get(\"artifact_kind\", ArtifactKind.STEP_RESULT.value)\n        return (\n            f\"{self.wandb_project_url}/artifacts/{artifact_kind}\"\n            f\"/{self._step_artifact_name(step)}/{step.unique_id}\"\n        )\n\n    @retry(exceptions=(wandb.errors.CommError,), delay=10, backoff=2, max_delay=120)\n    def use_step_result_artifact(self, step: Union[Step, StepInfo]) -> None:\n        \"\"\"\n        \"Use\" the artifact corresponding to the result of a step.\n        \"\"\"\n        if wandb.run 
is None:\n            raise RuntimeError(\"This can only be called from within a W&B run\")\n        wandb.run.use_artifact(\n            f\"{self.entity}/{self.project}/{self._step_artifact_name(step)}:{step.unique_id}\"\n        )\n\n    def _download_step_remote(self, step_result, target_dir: PathOrStr):\n        try:\n            step_result.download(root=target_dir)\n        except (WandbError, ValueError):\n            raise RemoteNotFoundError()\n\n    def __len__(self) -> int:\n        completed_cacheable_step_runs = self.wandb_client.runs(\n            f\"{self.entity}/{self.project}\",\n            filters={  # type: ignore\n                \"config.job_type\": \"step\",\n                \"config.cacheable\": True,\n                \"state\": \"finished\",\n            },\n        )\n        return len(list(completed_cacheable_step_runs))\n"
  },
  {
    "path": "tango/integrations/wandb/torch_train_callback.py",
    "content": "from typing import Any, Dict, List, Optional\n\nimport torch\nimport wandb\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.integrations.torch.train_callback import TrainCallback\nfrom tango.integrations.torch.util import peak_gpu_memory\n\nfrom .util import check_environment\nfrom .workspace import WandbWorkspace\n\n\n@TrainCallback.register(\"wandb::log\")\nclass WandbTrainCallback(TrainCallback):\n    \"\"\"\n    A torch :class:`~tango.integrations.torch.TrainCallback` for use with\n    the :class:`~tango.integrations.torch.TorchTrainStep` that logs training and\n    validation metrics to W&B.\n\n    This can be used with any :class:`~tango.workspace.Workspace` implementation,\n    including :class:`WandbWorkspace`.\n\n    .. tip::\n\n        Registered as a :class:`~tango.integrations.torch.TrainCallback`\n        under the name \"wandb::log\".\n\n    .. important::\n\n        When this callback is used with the :class:`WandbWorkspace` it will log metrics\n        to the same W&B project that the workspace uses. The ``group`` and ``name``\n        parameters will also automatically be set, so a :class:`~tango.common.exceptions.ConfigurationError`\n        will be raised if any of ``project``, ``entity``, ``group``, or ``name`` are set in this callback.\n\n    :param project:\n        W&B project to associate this run with.\n\n    :param entity:\n        W&B entity (user or organization) to associate this run with.\n\n    :param group:\n        W&B group to associate this run with.\n\n    :param name:\n        Set the name of the run in W&B. 
If not set, the default will be the name of the step.\n\n    :param notes:\n        Arbitrary notes to add in W&B to this run.\n\n    :param tags:\n        Arbitrary tags to add in W&B to this run.\n\n    :param watch_model:\n        If ``True``, ``wandb.watch()`` is called to collect gradients and other information\n        about the model throughout training.\n        See `docs.wandb.ai/ref/python/watch <https://docs.wandb.ai/ref/python/watch>`_.\n\n    :param wandb_config:\n        Arbitrary configuration fields to set in W&B for this run.\n        See `docs.wandb.ai/guides/track/config <https://docs.wandb.ai/guides/track/config>`_.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        *args,\n        project: Optional[str] = None,\n        entity: Optional[str] = None,\n        group: Optional[str] = None,\n        name: Optional[str] = None,\n        notes: Optional[str] = None,\n        tags: Optional[List[str]] = None,\n        watch_model: bool = False,\n        wandb_config: Optional[Dict[str, Any]] = None,\n        **kwargs,\n    ) -> None:\n        super().__init__(*args, **kwargs)\n\n        if self.is_local_main_process:\n            check_environment()\n\n        if isinstance(self.workspace, WandbWorkspace) or wandb.run is not None:\n            err_msg_template = \"Cannot set '{var_name}' in WandbTrainCallback \"\n            if isinstance(self.workspace, WandbWorkspace):\n                err_msg_template += \"since it has already been set from the WandbWorkspace.\"\n            else:\n                err_msg_template += \"since a W&B run has already been initialized.\"\n            for var, var_name in [\n                (project, \"project\"),\n                (entity, \"entity\"),\n                (group, \"group\"),\n                (name, \"name\"),\n            ]:\n                if var is not None:\n                    raise ConfigurationError(err_msg_template.format(var_name=var_name))\n\n        self.project = (\n            project 
if not isinstance(self.workspace, WandbWorkspace) else self.workspace.project\n        )\n        self.entity = (\n            entity if not isinstance(self.workspace, WandbWorkspace) else self.workspace.entity\n        )\n        self.group = group or self.step_id\n        self.notes = notes or self._get_default_notes()\n        self.tags = tags\n        self.watch_model = watch_model\n\n        self.wandb_config = self.train_config.as_dict()\n        del self.wandb_config[\"worker_id\"]\n        if wandb_config is not None:\n            self.wandb_config.update(wandb_config)\n        if wandb.run is None:\n            self.wandb_config[\"job_type\"] = \"train_metrics\"\n\n        self.run_name: str = name or self.step_name or \"train\"\n        if self.train_config.is_distributed:\n            self.run_name += f\" (rank {self.train_config.worker_id})\"\n\n        self.run_id: str = (\n            wandb.run.id  # type: ignore[attr-defined]\n            if wandb.run is not None\n            else self.step_id + f\"-rank{self.train_config.worker_id}\"\n        )\n        self.resume: Optional[str] = None\n        self.should_finalize_run: bool = (\n            wandb.run is None\n        )  # if we have to start our own W&B run, we need to finish it\n\n    def state_dict(self) -> Dict[str, Any]:\n        return {}\n\n    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:\n        self.resume = \"allow\"\n\n    def pre_train_loop(self) -> None:\n        if wandb.run is None:\n            if self.run_id is None:\n                self.run_id = self.step_id + f\"-rank{self.train_config.worker_id}\"\n            # Initialize a new W&B run.\n            wandb.init(\n                id=self.run_id,\n                dir=str(self.work_dir),\n                project=self.project,\n                entity=self.entity,\n                group=self.group,\n                name=self.run_name,\n                notes=self.notes,\n                config=self.wandb_config,\n  
              tags=self.tags,\n                job_type=\"train_metrics\",\n            )\n        else:\n            # We are already running inside of a W&B run, possibly because\n            # we're using the WandbWorkspace.\n            wandb.config.update(self.wandb_config)\n            if self.tags:\n                wandb.run.tags = (wandb.run.tags or tuple()) + tuple(self.tags)\n            if self.notes:\n                wandb.run.notes = self.notes\n\n        if self.watch_model:\n            wandb.watch(self.training_engine.model)\n\n        # Log GPU memory statistics.\n        if torch.cuda.is_available():\n            torch.cuda.reset_peak_memory_stats()\n        peak_gpu_mbs = peak_gpu_memory()\n        if self.is_local_main_process:\n            metrics = {f\"sys/worker{rank}_peak_gpu_mem\": mbs for rank, mbs in peak_gpu_mbs.items()}\n            metrics[\"epoch\"] = 0\n            wandb.log(metrics, step=0)\n\n    def post_train_loop(self, step: int, epoch: int) -> None:\n        if self.should_finalize_run:\n            wandb.finish()\n\n    def log_batch(\n        self, step: int, epoch: int, batch_loss: float, batch_outputs: List[Dict[str, Any]]\n    ) -> None:\n        peak_gpu_mbs = peak_gpu_memory()\n        if self.is_local_main_process:\n            metrics = {\n                \"train/loss\": batch_loss,\n                \"train/lr\": self.training_engine.optimizer.param_groups[0][\"lr\"],\n                \"epoch\": epoch,\n            }\n            metrics.update(\n                {f\"sys/worker{rank}_peak_gpu_mem\": mbs for rank, mbs in peak_gpu_mbs.items()}\n            )\n            wandb.log(\n                metrics,\n                step=step + 1,\n            )\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: float, best_val_metric: float\n    ) -> None:\n        if self.is_local_main_process:\n            wandb.log(\n                {\n                    f\"val/{self.train_config.val_metric_name}\": 
val_metric,\n                    f\"val/best_{self.train_config.val_metric_name}\": best_val_metric,\n                    \"epoch\": epoch,\n                },\n                step=step + 1,\n            )\n\n    def _get_default_notes(self) -> str:\n        notes = (\n            f'Metrics for Tango step \"{self.step_name}\" from worker {self.train_config.worker_id}.'\n        )\n        if isinstance(self.workspace, WandbWorkspace):\n            notes += f\"\\nMain run for step: {self.workspace.wandb_project_url}/runs/{self.step_id}/overview\"\n        return notes\n"
  },
  {
    "path": "tango/integrations/wandb/util.py",
    "content": "import os\nimport re\nimport warnings\nfrom enum import Enum\n\nfrom wandb.errors import Error as WandbError\n\n_API_KEY_WARNING_ISSUED = False\n_SILENCE_WARNING_ISSUED = False\n\n\ndef is_missing_artifact_error(err: WandbError):\n    \"\"\"\n    Check if a specific W&B error is caused by a 404 on the artifact we're looking for.\n    \"\"\"\n    # This is brittle, but at least we have a test for it.\n\n    # This is a workaround for a bug in the wandb API\n    if err.message == \"'NoneType' object has no attribute 'get'\":\n        return True\n\n    if re.search(r\"^artifact '.*' not found in '.*'$\", err.message):\n        return True\n\n    return (\"does not contain artifact\" in err.message) or (\n        \"Unable to fetch artifact with name\" in err.message\n    )\n\n\ndef check_environment():\n    global _API_KEY_WARNING_ISSUED, _SILENCE_WARNING_ISSUED\n    if \"WANDB_API_KEY\" not in os.environ and not _API_KEY_WARNING_ISSUED:\n        warnings.warn(\n            \"Missing environment variable 'WANDB_API_KEY' required to authenticate to Weights & Biases.\",\n            UserWarning,\n        )\n        _API_KEY_WARNING_ISSUED = True\n    if \"WANDB_SILENT\" not in os.environ and not _SILENCE_WARNING_ISSUED:\n        warnings.warn(\n            \"The Weights & Biases client may produce a lot of log messages. \"\n            \"You can silence these by setting the environment variable 'WANDB_SILENT=true'\",\n            UserWarning,\n        )\n        _SILENCE_WARNING_ISSUED = True\n\n\nclass RunKind(Enum):\n    STEP = \"step\"\n    TANGO_RUN = \"tango_run\"\n\n\nclass ArtifactKind(Enum):\n    STEP_RESULT = \"step_result\"\n"
  },
  {
    "path": "tango/integrations/wandb/workspace.py",
    "content": "import logging\nimport tempfile\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Any, Dict, Iterable, Iterator, Optional, TypeVar, Union\nfrom urllib.parse import ParseResult\n\nimport pytz\nimport wandb\n\nfrom tango.common.exceptions import StepStateError\nfrom tango.common.file_lock import FileLock\nfrom tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, Workspace\n\nfrom .step_cache import WandbStepCache\nfrom .util import RunKind, check_environment\n\nT = TypeVar(\"T\")\n\nlogger = logging.getLogger(__name__)\n\n\n@Workspace.register(\"wandb\")\nclass WandbWorkspace(Workspace):\n    \"\"\"\n    This is a :class:`~tango.workspace.Workspace` that tracks Tango runs in a W&B project.\n    It also stores step results as W&B Artifacts via :class:`WandbStepCache`.\n\n    Each Tango run with this workspace will generate multiple runs in your W&B project.\n    There will always be a W&B run corresponding to each Tango run with the same name,\n    which will contain some metadata about the Tango run. Then there will be one W&B run\n    for each cacheable step that runs with a name corresponding to the name of the step.\n    So if your Tango run includes 3 cacheable steps, that will result in a total of 4 new runs in W&B.\n\n    :param project: The W&B project to use for the workspace.\n    :param entity: The W&B entity (user or organization account) to use for the workspace.\n\n    .. tip::\n        Registered as a :class:`~tango.workspace.Workspace` under the name \"wandb\".\n\n    .. 
tip::\n        If you want to change the artifact kind for step result artifacts uploaded\n        to W&B, add a field called ``artifact_kind`` to the ``metadata`` of\n        the :class:`~tango.step.Step` class.\n\n        This can be useful if you want model objects to be added to the model zoo.\n        In that you would set ``artifact_kind = \"model\"``.\n        For example, your config for the step would look like this:\n\n        .. code-block::\n\n            { type: \"trainer\", step_metadata: { artifact_kind: \"model\" }, ... }\n\n        Or just add this to the ``METADATA`` class attribute:\n\n        .. code-block::\n\n            @Step.register(\"trainer\")\n            class TrainerStep(Step):\n                METADATA = {\"artifact_kind\": \"model\"}\n    \"\"\"\n\n    def __init__(self, project: str, entity: Optional[str] = None):\n        check_environment()\n        super().__init__()\n        self.project = project\n        self._entity = entity\n        self.cache = WandbStepCache(project=self.project, entity=self.entity)\n        self.steps_dir = tango_cache_dir() / \"wandb_workspace\"\n        self.locks: Dict[Step, FileLock] = {}\n        self._running_step_info: Dict[str, StepInfo] = {}\n\n    def __getstate__(self):\n        \"\"\"\n        We override `__getstate__()` to customize how instances of this class are pickled\n        since we don't want to persist certain attributes.\n        \"\"\"\n        out = super().__getstate__()\n        out[\"locks\"] = {}\n        return out\n\n    @property\n    def wandb_client(self) -> wandb.Api:\n        overrides = {\"project\": self.project}\n        if self._entity is not None:\n            overrides[\"entity\"] = self._entity\n        return wandb.Api(overrides=overrides)\n\n    @property\n    def entity(self) -> str:\n        return self._entity or self.wandb_client.default_entity\n\n    @property\n    def url(self) -> str:\n        return f\"wandb://{self.entity}/{self.project}\"\n\n    
@classmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:\n        entity = parsed_url.netloc\n        project = parsed_url.path\n        if project:\n            project = project.strip(\"/\")\n        return cls(project=project, entity=entity)\n\n    @property\n    def step_cache(self) -> StepCache:\n        return self.cache\n\n    @property\n    def wandb_project_url(self) -> str:\n        \"\"\"\n        The URL of the W&B project this workspace uses.\n        \"\"\"\n        app_url = self.wandb_client.client.app_url\n        app_url = app_url.rstrip(\"/\")\n        return f\"{app_url}/{self.entity}/{self.project}\"\n\n    def _get_unique_id(self, step_or_unique_id: Union[Step, str]) -> str:\n        if isinstance(step_or_unique_id, Step):\n            unique_id = step_or_unique_id.unique_id\n        else:\n            unique_id = step_or_unique_id\n        return unique_id\n\n    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:\n        unique_id = self._get_unique_id(step_or_unique_id)\n        path = self.steps_dir / unique_id\n        path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    def work_dir(self, step: Step) -> Path:\n        path = self.step_dir(step) / \"work\"\n        path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        unique_id = self._get_unique_id(step_or_unique_id)\n        if unique_id in self._running_step_info:\n            return self._running_step_info[unique_id]\n        step_info = self._get_updated_step_info(\n            unique_id,\n            step_name=step_or_unique_id.name if isinstance(step_or_unique_id, Step) else None,\n        )\n        if step_info is None:\n            raise KeyError(step_or_unique_id)\n        else:\n            return step_info\n\n    def step_starting(self, step: Step) -> None:\n        if wandb.run is not None:\n            raise RuntimeError(\n       
         \"There is already a W&B run initialized, cannot initialize another one.\"\n            )\n\n        work_dir = self.work_dir(step)\n\n        lock_path = self.step_dir(step) / \"lock\"\n        lock = FileLock(lock_path, read_only_ok=True)\n        lock.acquire_with_updates(desc=f\"acquiring lock for '{step.name}'\")\n        self.locks[step] = lock\n\n        step_info = self._get_updated_step_info(step.unique_id) or StepInfo.new_from_step(step)\n        if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}:\n            raise StepStateError(\n                step,\n                step_info.state,\n                context=\"If you are certain the step is not running somewhere else, delete the lock \"\n                f\"file at {lock_path}.\",\n            )\n\n        try:\n            # Initialize W&B run for the step.\n            wandb.init(\n                name=step_info.step_name,\n                job_type=RunKind.STEP.value,\n                group=step.unique_id,\n                dir=str(work_dir),\n                entity=self.entity,\n                project=self.project,\n                # For cacheable steps we can just use the step's unique ID as the W&B run ID,\n                # but not for uncacheable steps, since those might be run more than once\n                # and will need a unique W&B run ID each time.\n                id=step.unique_id if step.cache_results else None,\n                resume=\"allow\" if step.cache_results else None,\n                notes=\"\\n\".join(\n                    [\n                        f'Tango step \"{step.name}\"',\n                        f\"\\N{bullet} type: {step_info.step_class_name}\",\n                        f\"\\N{bullet} ID: {step.unique_id}\",\n                    ]\n                ),\n                config={\n                    \"job_type\": RunKind.STEP.value,\n                    \"_run_suite_id\": self._generate_run_suite_id(),  # used for 
testing only\n                },\n            )\n\n            assert wandb.run is not None\n            logger.info(\n                \"Tracking '%s' step on Weights and Biases: %s/runs/%s/overview\",\n                step.name,\n                self.wandb_project_url,\n                wandb.run.id,\n            )\n\n            # \"Use\" all of the result artifacts for this step's dependencies in order to declare\n            # those dependencies to W&B.\n            for dependency in step.dependencies:\n                self.cache.use_step_result_artifact(dependency)\n\n            # Update StepInfo to mark as running.\n            step_info.start_time = utc_now_datetime()\n            step_info.end_time = None\n            step_info.error = None\n            step_info.result_location = None\n            wandb.run.config.update({\"step_info\": step_info.to_json_dict()}, allow_val_change=True)\n            self._running_step_info[step.unique_id] = step_info\n        except:  # noqa: E722\n            lock.release()\n            del self.locks[step]\n            raise\n\n    def step_finished(self, step: Step, result: T) -> T:\n        if wandb.run is None:\n            raise RuntimeError(\n                f\"{self.__class__.__name__}.step_finished() called outside of a W&B run. 
\"\n                f\"Did you forget to call {self.__class__.__name__}.step_starting() first?\"\n            )\n\n        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(\n            step.unique_id\n        )\n        if step_info is None:\n            raise KeyError(step.unique_id)\n\n        try:\n            if step.cache_results:\n                self.step_cache[step] = result\n                if hasattr(result, \"__next__\"):\n                    assert isinstance(result, Iterator)\n                    # Caching the iterator will consume it, so we write it to the\n                    # cache and then read from the cache for the return value.\n                    result = self.step_cache[step]\n                step_info.result_location = self.cache.get_step_result_artifact_url(step)\n            else:\n                # Create an empty artifact in order to build the DAG in W&B.\n                self.cache.create_step_result_artifact(step)\n\n            step_info.end_time = utc_now_datetime()\n            wandb.run.config.update({\"step_info\": step_info.to_json_dict()}, allow_val_change=True)\n\n            # Finalize the step's W&B run.\n            wandb.finish()\n        finally:\n            self.locks[step].release()\n            del self.locks[step]\n            if step.unique_id in self._running_step_info:\n                del self._running_step_info[step.unique_id]\n\n        return result\n\n    def step_failed(self, step: Step, e: BaseException) -> None:\n        if wandb.run is None:\n            raise RuntimeError(\n                f\"{self.__class__.__name__}.step_failed() called outside of a W&B run. 
\"\n                f\"Did you forget to call {self.__class__.__name__}.step_starting() first?\"\n            )\n\n        step_info = self._running_step_info.get(step.unique_id) or self._get_updated_step_info(\n            step.unique_id\n        )\n        if step_info is None:\n            raise KeyError(step.unique_id)\n\n        try:\n            # Update StepInfo, marking the step as failed.\n            if step_info.state != StepState.RUNNING:\n                raise StepStateError(step, step_info.state)\n            step_info.end_time = utc_now_datetime()\n            step_info.error = exception_to_string(e)\n            wandb.run.config.update({\"step_info\": step_info.to_json_dict()}, allow_val_change=True)\n\n            # Finalize the step's W&B run.\n            wandb.finish(exit_code=1)\n        finally:\n            self.locks[step].release()\n            del self.locks[step]\n            if step.unique_id in self._running_step_info:\n                del self._running_step_info[step.unique_id]\n\n    def remove_step(self, step_unique_id: str):\n        \"\"\"\n        Removes cached step using the given unique step id\n        :raises KeyError: If there is no step with the given name.\n        \"\"\"\n        raise NotImplementedError()\n\n    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:\n        all_steps = set(targets)\n        for step in targets:\n            all_steps |= step.recursive_dependencies\n\n        wandb_run_id: str\n        wandb_run_name: str\n        with tempfile.TemporaryDirectory() as temp_dir_name:\n            with wandb.init(  # type: ignore[union-attr]\n                job_type=RunKind.TANGO_RUN.value,\n                entity=self.entity,\n                project=self.project,\n                name=name,\n                dir=temp_dir_name,\n                config={\n                    \"job_type\": RunKind.TANGO_RUN.value,  # need this in the config so we can filter runs by this\n       
             \"_run_suite_id\": self._generate_run_suite_id(),  # used for testing only\n                },\n            ) as wandb_run:\n                wandb_run_id = wandb_run.id\n                wandb_run_name = wandb_run.name  # type: ignore[assignment]\n                logger.info(\"Registering run %s with Weights and Biases\", wandb_run.name)\n                logger.info(\n                    \"View run at: %s/runs/%s/overview\", self.wandb_project_url, wandb_run_id\n                )\n\n                # Collect step info for all steps.\n                step_ids: Dict[str, bool] = {}\n                step_name_to_info: Dict[str, Dict[str, Any]] = {}\n                for step in all_steps:\n                    step_info = StepInfo.new_from_step(step)\n                    step_name_to_info[step.name] = {\n                        k: v for k, v in step_info.to_json_dict().items() if v is not None\n                    }\n                    step_ids[step.unique_id] = True\n\n                # Update config with step info.\n                wandb_run.config.update({\"steps\": step_name_to_info, \"_step_ids\": step_ids})\n\n                # Update notes.\n                notes = \"Tango run\\n--------------\"\n                cacheable_steps = {step for step in all_steps if step.cache_results}\n                if cacheable_steps:\n                    notes += \"\\nCacheable steps:\\n\"\n                    for step in sorted(cacheable_steps, key=lambda step: step.name):\n                        notes += f\"\\N{bullet} {step.name}\"\n                        dependencies = step.dependencies\n                        if dependencies:\n                            notes += \", depends on: \" + \", \".join(\n                                sorted(\n                                    [f\"'{dep.name}'\" for dep in dependencies],\n                                )\n                            )\n                        notes += \"\\n  \\N{rightwards arrow with hook} \"\n   
                     notes += f\"{self.wandb_project_url}/runs/{step.unique_id}/overview\\n\"\n                wandb_run.notes = notes\n\n        return self.registered_run(wandb_run_name)\n\n    def _generate_run_suite_id(self) -> str:\n        return wandb.util.generate_id()\n\n    def registered_runs(self) -> Dict[str, Run]:\n        runs: Dict[str, Run] = {}\n        matching_runs = list(\n            self.wandb_client.runs(\n                f\"{self.entity}/{self.project}\",\n                filters={\"config.job_type\": RunKind.TANGO_RUN.value},  # type: ignore\n            )\n        )\n        for wandb_run in matching_runs:\n            runs[wandb_run.name] = self._get_run_from_wandb_run(wandb_run)\n        return runs\n\n    def registered_run(self, name: str) -> Run:\n        matching_runs = list(\n            self.wandb_client.runs(\n                f\"{self.entity}/{self.project}\",\n                filters={\"display_name\": name, \"config.job_type\": RunKind.TANGO_RUN.value},  # type: ignore\n            )\n        )\n        if not matching_runs:\n            raise KeyError(f\"Run '{name}' not found in workspace\")\n        elif len(matching_runs) > 1:\n            raise ValueError(f\"Found more than one run named '{name}' in W&B project\")\n        return self._get_run_from_wandb_run(matching_runs[0])\n\n    def _get_run_from_wandb_run(\n        self,\n        wandb_run: wandb.apis.public.Run,\n    ) -> Run:\n        step_name_to_info = {}\n        for step_name, step_info_dict in wandb_run.config[\"steps\"].items():\n            step_info = StepInfo.from_json_dict(step_info_dict)\n            if step_info.cacheable:\n                updated_step_info = self._get_updated_step_info(\n                    step_info.unique_id, step_name=step_name\n                )\n                if updated_step_info is not None:\n                    step_info = updated_step_info\n            step_name_to_info[step_name] = step_info\n        return Run(\n            
name=wandb_run.name,\n            steps=step_name_to_info,\n            start_date=datetime.strptime(wandb_run.created_at, \"%Y-%m-%dT%H:%M:%S\").replace(\n                tzinfo=pytz.utc\n            ),\n        )\n\n    def _get_updated_step_info(\n        self, step_id: str, step_name: Optional[str] = None\n    ) -> Optional[StepInfo]:\n        # First try to find the W&B run corresponding to the step. This will only\n        # work if the step execution was started already.\n        filters = {\n            \"config.job_type\": RunKind.STEP.value,\n            \"config.step_info.unique_id\": step_id,\n        }\n        if step_name is not None:\n            filters[\"display_name\"] = step_name\n        for wandb_run in self.wandb_client.runs(\n            f\"{self.entity}/{self.project}\",\n            filters=filters,  # type: ignore\n        ):\n            step_info = StepInfo.from_json_dict(wandb_run.config[\"step_info\"])\n            # Might need to fix the step info the step failed and we failed to update the config.\n            if step_info.start_time is None:\n                step_info.start_time = datetime.strptime(\n                    wandb_run.created_at, \"%Y-%m-%dT%H:%M:%S\"\n                ).replace(tzinfo=pytz.utc)\n            if wandb_run.state in {\"failed\", \"finished\"}:\n                if step_info.end_time is None:\n                    step_info.end_time = datetime.strptime(\n                        wandb_run.heartbeatAt, \"%Y-%m-%dT%H:%M:%S\"\n                    ).replace(tzinfo=pytz.utc)\n                if wandb_run.state == \"failed\" and step_info.error is None:\n                    step_info.error = \"Exception\"\n            return step_info\n\n        # If the step hasn't been started yet, we'll have to pull the step info from the\n        # registered run.\n        filters = {\n            \"config.job_type\": RunKind.TANGO_RUN.value,\n            f\"config._step_ids.{step_id}\": True,\n        }\n        if step_name is 
not None:\n            filters[f\"config.steps.{step_name}.unique_id\"] = step_id\n        for wandb_run in self.wandb_client.runs(\n            f\"{self.entity}/{self.project}\",\n            filters=filters,  # type: ignore\n        ):\n            if step_name is not None:\n                step_info_data = wandb_run.config[\"steps\"][step_name]\n            else:\n                step_info_data = next(\n                    d for d in wandb_run.config[\"steps\"].values() if d[\"unique_id\"] == step_id\n                )\n            step_info = StepInfo.from_json_dict(step_info_data)\n            return step_info\n\n        return None\n"
  },
  {
    "path": "tango/py.typed",
    "content": ""
  },
  {
    "path": "tango/settings.py",
    "content": "from dataclasses import dataclass\nfrom pathlib import Path\nfrom typing import Any, ClassVar, Dict, List, Optional\n\nimport yaml\n\nfrom .common.aliases import PathOrStr\nfrom .common.from_params import FromParams\nfrom .common.params import Params\n\n\n@dataclass\nclass TangoGlobalSettings(FromParams):\n    \"\"\"\n    Defines global settings for tango.\n    \"\"\"\n\n    workspace: Optional[Dict[str, Any]] = None\n    \"\"\"\n    Parameters to initialize a :class:`~tango.workspace.Workspace` with.\n    \"\"\"\n\n    executor: Optional[Dict[str, Any]] = None\n    \"\"\"\n    Parameters to initialize an :class:`~tango.executor.Executor` with.\n    \"\"\"\n\n    include_package: Optional[List[str]] = None\n    \"\"\"\n    A list of modules where custom registered steps or classes can be found.\n    \"\"\"\n\n    log_level: Optional[str] = None\n    \"\"\"\n    The log level to use. Options are \"debug\", \"info\", \"warning\", and \"error\".\n\n    .. note::\n        This does not affect the :data:`~tango.common.logging.cli_logger`\n        or logs from :class:`~tango.common.Tqdm` progress bars.\n\n    \"\"\"\n\n    file_friendly_logging: Optional[bool] = None\n    \"\"\"\n    If this flag is set to ``True``, we add newlines to tqdm output, even on an interactive terminal, and we slow\n    down tqdm's output to only once every 10 seconds.\n    \"\"\"\n\n    multiprocessing_start_method: str = \"spawn\"\n    \"\"\"\n    The ``start_method`` to use when starting new multiprocessing workers. Can be \"fork\", \"spawn\",\n    or \"forkserver\". 
Default is \"spawn\".\n\n    See :func:`multiprocessing.set_start_method()` for more details.\n    \"\"\"\n\n    environment: Optional[Dict[str, str]] = None\n    \"\"\"\n    Environment variables that will be set each time ``tango`` is run.\n    \"\"\"\n\n    _path: Optional[Path] = None\n\n    _DEFAULT_LOCATION: ClassVar[Path] = Path.home() / \".config\" / \"tango.yml\"\n\n    @classmethod\n    def default(cls) -> \"TangoGlobalSettings\":\n        \"\"\"\n        Initialize the config from files by checking the default locations\n        in order, or just return the default if none of the files can be found.\n        \"\"\"\n        for directory in (Path(\".\"), cls._DEFAULT_LOCATION.parent):\n            for extension in (\"yml\", \"yaml\"):\n                path = directory / f\"tango.{extension}\"\n                if path.is_file():\n                    return cls.from_file(path)\n        return cls()\n\n    @classmethod\n    def find_or_default(cls, path: Optional[PathOrStr] = None) -> \"TangoGlobalSettings\":\n        \"\"\"\n        Initialize the config from a given configuration file, or falls back to returning\n        the default configuration if no file is given.\n        \"\"\"\n        if path is not None:\n            path = Path(path)\n            if not path.is_file():\n                raise FileNotFoundError(path)\n            return cls.from_file(path)\n        else:\n            return cls.default()\n\n    @property\n    def path(self) -> Optional[Path]:\n        \"\"\"\n        The path to the file the config was read from.\n        \"\"\"\n        return self._path\n\n    @classmethod\n    def from_file(cls, path: PathOrStr) -> \"TangoGlobalSettings\":\n        \"\"\"\n        Read settings from a file.\n        \"\"\"\n        params = Params.from_file(path)\n        params[\"_path\"] = Path(path).resolve()\n        return cls.from_params(params)\n\n    def to_file(self, path: PathOrStr) -> None:\n        \"\"\"\n        Save the settings 
to a file.\n        \"\"\"\n        data = {\n            k: v for k, v in self.to_params().as_dict(quiet=True).items() if not k.startswith(\"_\")\n        }\n        with open(path, \"w\") as settings_file:\n            yaml.safe_dump(data, settings_file)\n\n    def save(self) -> None:\n        \"\"\"\n        Save the settings to the file it was read from.\n\n        :raises ValueError: If the settings was not read from a file.\n        \"\"\"\n        if self.path is None:\n            raise ValueError(\"No path given, use .to_file() instead\")\n        self.to_file(self.path)\n"
  },
  {
    "path": "tango/step.py",
    "content": "import inspect\nimport itertools\nimport logging\nimport random\nimport re\nimport warnings\nfrom abc import abstractmethod\nfrom copy import deepcopy\nfrom dataclasses import dataclass\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import (\n    TYPE_CHECKING,\n    Any,\n    Callable,\n    ClassVar,\n    Dict,\n    Generic,\n    Iterable,\n    Optional,\n    Set,\n    Type,\n    TypeVar,\n    Union,\n    cast,\n)\n\nfrom tango.common.det_hash import CustomDetHash, det_hash\nfrom tango.common.exceptions import ConfigurationError, StepStateError\nfrom tango.common.from_params import (\n    FromParams,\n    infer_constructor_params,\n    infer_method_params,\n    pop_and_construct_arg,\n)\nfrom tango.common.lazy import Lazy\nfrom tango.common.logging import cli_logger, log_exception\nfrom tango.common.params import Params\nfrom tango.common.registrable import Registrable\nfrom tango.format import DillFormat, Format\n\ntry:\n    from typing import get_args, get_origin  # type: ignore\nexcept ImportError:\n\n    def get_origin(tp):  # type: ignore\n        return getattr(tp, \"__origin__\", None)\n\n    def get_args(tp):  # type: ignore\n        return getattr(tp, \"__args__\", ())\n\n\nif TYPE_CHECKING:\n    from tango.workspace import Workspace\n\n_version_re = re.compile(\"\"\"^[a-zA-Z0-9]+$\"\"\")\n\nT = TypeVar(\"T\")\n\n\n_random_for_step_names = random.Random()\n\n\n@dataclass\nclass StepResources(FromParams):\n    \"\"\"\n    TaskResources describe minimum external hardware requirements which must be available for a\n    step to run.\n    \"\"\"\n\n    machine: Optional[str] = None\n    \"\"\"\n    This is an executor-dependent option.\n\n    With the Beaker executor, for example, you can set this to \"local\" to force\n    the executor to run the step locally instead of on Beaker.\n    \"\"\"\n\n    cpu_count: Optional[float] = None\n    \"\"\"\n    Minimum number of logical CPU cores. 
It may be fractional.\n\n    Examples: ``4``, ``0.5``.\n    \"\"\"\n\n    gpu_count: Optional[int] = None\n    \"\"\"\n    Minimum number of GPUs. It must be non-negative.\n    \"\"\"\n\n    gpu_type: Optional[str] = None\n    \"\"\"\n    The type of GPU that the step requires.\n\n    The exact string you should use to define a GPU type depends on the executor.\n    With the Beaker executor, for example, you should use the same strings you\n    see in the Beaker UI, such as 'NVIDIA A100-SXM-80GB'.\n    \"\"\"\n\n    memory: Optional[str] = None\n    \"\"\"\n    Minimum available system memory as a number with unit suffix.\n\n    Examples: ``2.5GiB``, ``1024m``.\n    \"\"\"\n\n    shared_memory: Optional[str] = None\n    \"\"\"\n    Size of ``/dev/shm`` as a number with unit suffix.\n\n    Examples: ``2.5GiB``, ``1024m``.\n    \"\"\"\n\n\nclass Step(Registrable, Generic[T]):\n    \"\"\"\n    This class defines one step in your experiment. To write your own step, derive from this class\n    and override the :meth:`run()` method. The :meth:`run()` method must have parameters with type hints.\n\n    ``Step.__init__()`` takes all the arguments we want to run the step with. They get passed\n    to :meth:`run()` (almost) as they are. If the arguments are other instances of ``Step``, those\n    will be replaced with the step's results before calling :meth:`run()`. Further, there are several special\n    parameters:\n\n    :param step_name: contains an optional human-readable name for the step. This name is used for\n      error messages and the like, and has no consequence on the actual computation.\n    :param cache_results: specifies whether the results of this step should be cached. If this is\n      ``False``, the step is recomputed every time it is needed. 
If this is not set at all,\n      and :attr:`CACHEABLE` is ``True``, we cache if the step is marked as :attr:`DETERMINISTIC`,\n      and we don't cache otherwise.\n    :param step_format: gives you a way to override the step's default format (which is given in :attr:`FORMAT`).\n    :param step_config: is the original raw part of the experiment config corresponding to this step.\n      This can be accessed via the :attr:`config` property within each step's :meth:`run()` method.\n    :param step_unique_id_override: overrides the construction of the step's unique id using the hash\n      of inputs.\n    :param step_resources: gives you a way to set the minimum compute resources required\n      to run this step. Certain executors require this information.\n    :param step_metadata: use this to specify additional metadata for your step.\n      This is added to the :attr:`METADATA` class variable to form the ``self.metadata`` attribute.\n      Values in ``step_metadata`` take precedence over ``METADATA``.\n    :param step_extra_dependencies: use this to force a dependency on other steps. Normally dependencies\n      between steps are determined by the inputs and outputs of the steps, but you can use this\n      parameter to force that other steps run before this step even if this step doesn't\n      explicitly depend on the outputs of those steps.\n\n    .. important::\n        Overriding the unique id means that the step will always map to this value, regardless of the inputs,\n        and therefore, the step cache will only hold a single copy of the step's output (from the last execution).\n        Thus, in most cases, this should not be used when constructing steps. 
We include this option for the case\n        when the executor creates subprocesses, which also need to access the *same* ``Step`` object.\n    \"\"\"\n\n    DETERMINISTIC: bool = True\n    \"\"\"This describes whether this step can be relied upon to produce the same results every time\n    when given the same inputs. If this is ``False``, you can still cache the output of the step,\n    but the results might be unexpected. Tango will print a warning in this case.\"\"\"\n\n    CACHEABLE: Optional[bool] = None\n    \"\"\"This provides a direct way to turn off caching. For example, a step that reads a HuggingFace\n    dataset doesn't need to be cached, because HuggingFace datasets already have their own caching\n    mechanism. But it's still a deterministic step, and all following steps are allowed to cache.\n    If it is ``None``, the step figures out by itself whether it should be cacheable or not.\"\"\"\n\n    VERSION: Optional[str] = None\n    \"\"\"This is optional, but recommended. Specifying a version gives you a way to tell Tango that\n    a step has changed during development, and should now be recomputed. This doesn't invalidate\n    the old results, so when you revert your code, the old cache entries will stick around and be\n    picked up.\"\"\"\n\n    FORMAT: Format = DillFormat(\"gz\")\n    \"\"\"This specifies the format the results of this step will be serialized in. 
See the documentation\n    for :class:`~tango.format.Format` for details.\"\"\"\n\n    SKIP_ID_ARGUMENTS: Set[str] = set()\n    \"\"\"If your :meth:`run()` method takes some arguments that don't affect the results, list them here.\n    Arguments listed here will not be used to calculate this step's unique ID, and thus changing those\n    arguments does not invalidate the cache.\n\n    For example, you might use this for the batch size in an inference step, where you only care about\n    the model output, not about how many outputs you can produce at the same time.\n    \"\"\"\n\n    SKIP_DEFAULT_ARGUMENTS: Dict[str, Any] = {}\n    \"\"\"Sometimes, you want to add another argument to your :meth:`run()` method, but you don't want to\n    invalidate the cache when this new argument is set to its default value. If that is the case, add\n    the argument to this dictionary with the default value that should be ignored.\"\"\"\n\n    METADATA: Dict[str, Any] = {}\n    \"\"\"\n    Arbitrary metadata about the step.\n    \"\"\"\n\n    _UNIQUE_ID_SUFFIX: Optional[str] = None\n    \"\"\"\n    Used internally for testing.\n    \"\"\"\n\n    def __init__(\n        self,\n        step_name: Optional[str] = None,\n        cache_results: Optional[bool] = None,\n        step_format: Optional[Format] = None,\n        step_config: Optional[Union[Dict[str, Any], Params]] = None,\n        step_unique_id_override: Optional[str] = None,\n        step_resources: Optional[StepResources] = None,\n        step_metadata: Optional[Dict[str, Any]] = None,\n        step_extra_dependencies: Optional[Iterable[\"Step\"]] = None,\n        **kwargs,\n    ):\n        if self.VERSION is not None:\n            assert _version_re.match(\n                self.VERSION\n            ), f\"Invalid characters in version '{self.VERSION}'\"\n\n        run_defaults = {\n            k: v.default\n            for k, v in inspect.signature(self.run).parameters.items()\n            if v.default is not 
inspect.Parameter.empty\n        }\n        self.kwargs = self.massage_kwargs({**run_defaults, **kwargs})\n\n        if step_format is None:\n            self.format = self.FORMAT\n            if isinstance(self.format, type):\n                self.format = self.format()\n        else:\n            self.format = step_format\n\n        self.unique_id_cache = step_unique_id_override\n        if step_name is None:\n            self.name = self.unique_id\n        else:\n            self.name = step_name\n        # TODO: It is bad design to have the step_name in the Step class. The same step can be part of multiple\n        # runs at the same time, and they can have different names in different runs. Step names are\n        # a property of the run, not of the step.\n\n        if cache_results is True:\n            if not self.CACHEABLE:\n                raise ConfigurationError(\n                    f\"Step {self.name} is configured to use the cache, but it's not a cacheable step.\"\n                )\n            if not self.DETERMINISTIC:\n                warnings.warn(\n                    f\"Step {self.name} is going to be cached despite not being deterministic.\",\n                    UserWarning,\n                )\n            self.cache_results = True\n        elif cache_results is False:\n            self.cache_results = False\n        elif cache_results is None:\n            c = (self.DETERMINISTIC, self.CACHEABLE)\n            if c == (False, None):\n                self.cache_results = False\n            elif c == (True, None):\n                self.cache_results = True\n            elif c == (False, False):\n                self.cache_results = False\n            elif c == (True, False):\n                self.cache_results = False\n            elif c == (False, True):\n                warnings.warn(\n                    f\"Step {self.name} is set to be cacheable despite not being deterministic.\",\n                    UserWarning,\n                )\n       
         self.cache_results = True\n            elif c == (True, True):\n                self.cache_results = True\n            else:\n                assert False, \"Step.DETERMINISTIC or step.CACHEABLE are set to an invalid value.\"\n        else:\n            raise ConfigurationError(\n                f\"Step {self.name}'s cache_results parameter is set to an invalid value.\"\n            )\n\n        self._workspace: Optional[\"Workspace\"] = None\n        self.work_dir_for_run: Optional[\n            Path\n        ] = None  # This is set only while the run() method runs.\n        if isinstance(step_config, Params):\n            self._config = step_config.as_dict(quiet=True)\n        else:\n            self._config = step_config\n        assert step_resources is None or isinstance(step_resources, StepResources)\n        self.step_resources = step_resources\n        self.metadata = deepcopy(self.METADATA)\n        if step_metadata:\n            self.metadata.update(step_metadata)\n        self.extra_dependencies = set(step_extra_dependencies) if step_extra_dependencies else set()\n\n    @property\n    def class_name(self) -> str:\n        return self.__class__.__name__\n\n    @classmethod\n    def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:\n        \"\"\"\n        Override this method in your step if you want to change the step's arguments before they are passed to the\n        :meth:`run()` method.\n\n        This can be useful if you want to normalize arguments that are passed to your step. For example,\n        you might not care about the case of a string that's passed in. You can lowercase the string in this\n        method, and the step will function as if it had been created with a lowercase string from the start.\n        This way you can make sure that the step's unique ID does not change when the case of the input changes.\n\n        .. 
note::\n            When the input to a step is another step, this method will see the step in the input, not the other\n            step's result.\n\n        .. warning::\n            This is an advanced feature of Tango that you won't need most of the time.\n\n        By default, this method does nothing and just returns its input unchanged.\n\n        :param kwargs: The original kwargs that were passed to the step during construction.\n        :return: New kwargs that will be passed to the step's :meth:`run()` method.\n        \"\"\"\n        return kwargs\n\n    @property\n    def logger(self) -> logging.Logger:\n        \"\"\"\n        A :class:`logging.Logger` that can be used within the :meth:`run()` method.\n        \"\"\"\n        return logging.getLogger(self.__class__.__name__)\n\n    @classmethod\n    def from_params(  # type: ignore[override]\n        cls: Type[\"Step\"],\n        params: Union[Params, dict, str],\n        constructor_to_call: Optional[Callable[..., \"Step\"]] = None,\n        constructor_to_inspect: Optional[\n            Union[Callable[..., \"Step\"], Callable[[\"Step\"], None]]\n        ] = None,\n        step_name: Optional[str] = None,\n        **extras,\n    ) -> \"Step\":\n        # Why do we need a custom from_params? Step classes have a run() method that takes all the\n        # parameters necessary to perform the step. The __init__() method of the step takes those\n        # same parameters, but each of them could be wrapped in another Step instead of being\n        # supplied directly. 
from_params() doesn't know anything about these shenanigans, so\n        # we have to supply the necessary logic here.\n\n        if constructor_to_call is not None:\n            raise ConfigurationError(\n                f\"{cls.__name__}.from_params cannot be called with a constructor_to_call.\"\n            )\n        if constructor_to_inspect is not None:\n            raise ConfigurationError(\n                f\"{cls.__name__}.from_params cannot be called with a constructor_to_inspect.\"\n            )\n\n        if isinstance(params, str):\n            params = Params({\"type\": params})\n\n        if not isinstance(params, Params):\n            if isinstance(params, dict):\n                params = Params(params)\n            else:\n                raise ConfigurationError(\n                    \"from_params was passed a ``params`` object that was not a ``Params``. This probably \"\n                    \"indicates malformed parameters in a configuration file, where something that \"\n                    \"should have been a dictionary was actually a list, or something else. 
\"\n                    f\"This happened when constructing an object of type {cls}.\"\n                )\n\n        # Build up a raw step config\n        def replace_steps_with_refs(o: Any) -> Any:\n            if isinstance(o, (list, tuple, set)):\n                return o.__class__(replace_steps_with_refs(i) for i in o)\n            elif isinstance(o, (dict, Params)):\n                result = {key: replace_steps_with_refs(value) for key, value in o.items()}\n                if isinstance(o, dict):\n                    return result\n                elif isinstance(o, Params):\n                    return Params(result, history=o.history)\n            elif isinstance(o, Step):\n                return {\"type\": \"ref\", \"ref\": o.name}\n            else:\n                return deepcopy(o)\n\n        raw_step_config = replace_steps_with_refs(params.as_dict(quiet=True))\n\n        as_registrable = cast(Type[Registrable], cls)\n        if \"type\" in params and params[\"type\"] not in as_registrable.list_available():\n            as_registrable.search_modules(params[\"type\"])\n        choice = params.pop_choice(\n            \"type\", choices=as_registrable.list_available(), default_to_first_choice=False\n        )\n        subclass, constructor_name = as_registrable.resolve_class_name(choice)\n        if not issubclass(subclass, Step):\n            # This can happen if `choice` is a fully qualified name.\n            raise ConfigurationError(\n                f\"Tried to make a Step of type {choice}, but ended up with a {subclass}.\"\n            )\n\n        if issubclass(subclass, FunctionalStep):\n            parameters = infer_method_params(subclass, subclass.WRAPPED_FUNC, infer_kwargs=False)\n            if subclass.BIND:\n                if \"self\" not in parameters:\n                    raise ConfigurationError(\n                        f\"Functional step for {subclass.WRAPPED_FUNC} is bound but is missing argument 'self'\"\n                    )\n        
        else:\n                    del parameters[\"self\"]\n        else:\n            parameters = infer_method_params(subclass, subclass.run, infer_kwargs=False)\n            del parameters[\"self\"]\n        init_parameters = infer_constructor_params(subclass)\n        del init_parameters[\"self\"]\n        del init_parameters[\"kwargs\"]\n        parameter_overlap = parameters.keys() & init_parameters.keys()\n        assert len(parameter_overlap) <= 0, (\n            f\"If this assert fails it means that you wrote a Step with a run() method that takes one of the \"\n            f\"reserved parameters ({', '.join(init_parameters.keys())})\"\n        )\n        parameters.update(init_parameters)\n\n        kwargs: Dict[str, Any] = {}\n        accepts_kwargs = False\n        for param_name, param in parameters.items():\n            if param.kind == param.VAR_KEYWORD:\n                # When a class takes **kwargs we store the fact that the method allows extra keys; if\n                # we get extra parameters, instead of crashing, we'll just pass them as-is to the\n                # constructor, and hope that you know what you're doing.\n                accepts_kwargs = True\n                continue\n\n            explicitly_set = param_name in params\n            constructed_arg = pop_and_construct_arg(\n                subclass.__name__, param_name, param.annotation, param.default, params, extras\n            )\n\n            # If the param wasn't explicitly set in `params` and we just ended up constructing\n            # the default value for the parameter, we can just omit it.\n            # Leaving it in can cause issues with **kwargs in some corner cases, where you might end up\n            # with multiple values for a single parameter (e.g., the default value gives you lazy=False\n            # for a dataset reader inside **kwargs, but a particular dataset reader actually hard-codes\n            # lazy=True - the superclass sees both lazy=True and 
lazy=False in its constructor).\n            if explicitly_set or constructed_arg is not param.default:\n                kwargs[param_name] = constructed_arg\n\n        if accepts_kwargs:\n            kwargs.update(params)\n        else:\n            params.assert_empty(subclass.__name__)\n\n        return subclass(step_name=step_name, step_config=raw_step_config, **kwargs)\n\n    @abstractmethod\n    def run(self, **kwargs) -> T:\n        \"\"\"\n        Execute the step's action.\n\n        This method needs to be implemented when creating a ``Step`` subclass, but\n        it shouldn't be called directly. Instead, call :meth:`result()`.\n        \"\"\"\n        raise NotImplementedError()\n\n    def _run_with_work_dir(self, workspace: \"Workspace\", needed_by: Optional[\"Step\"] = None) -> T:\n        if self.work_dir_for_run is not None:\n            raise RuntimeError(\"You can only run a Step's run() method once at a time.\")\n\n        if self.DETERMINISTIC:\n            random.seed(784507111)\n\n        self._workspace = workspace\n\n        if self.cache_results:\n            self.work_dir_for_run = workspace.work_dir(self)\n            dir_for_cleanup = None\n        else:\n            dir_for_cleanup = TemporaryDirectory(prefix=f\"{self.unique_id}-\", suffix=\".step_dir\")\n            self.work_dir_for_run = Path(dir_for_cleanup.name)\n\n        try:\n            self._replace_steps_with_results(self.extra_dependencies, workspace)\n            kwargs = self._replace_steps_with_results(self.kwargs, workspace)\n            self.log_starting(needed_by=needed_by)\n            workspace.step_starting(self)\n\n            try:\n                result = self.run(**kwargs)\n                result = workspace.step_finished(self, result)\n            except BaseException as e:\n                self.log_failure(e)\n                workspace.step_failed(self, e)\n                raise\n\n            self.log_finished()\n            return result\n        finally:\n  
          self._workspace = None\n            self.work_dir_for_run = None\n            if dir_for_cleanup is not None:\n                dir_for_cleanup.cleanup()\n\n    @property\n    def work_dir(self) -> Path:\n        \"\"\"\n        The working directory that a step can use while its :meth:`run()` method runs.\n\n        This is a convenience property for you to call inside your :meth:`run()` method.\n\n        This directory stays around across restarts. You cannot assume that it is empty when your\n        step runs, but you can use it to store information that helps you restart a step if it\n        got killed half-way through the last time it ran.\"\"\"\n        if self.work_dir_for_run is None:\n            raise RuntimeError(\n                \"You can only call this method while the step is running with a working directory. \"\n                \"Did you call '.run()' directly? You should only run a step with '.result()'.\"\n            )\n        return self.work_dir_for_run\n\n    @property\n    def workspace(self) -> \"Workspace\":\n        \"\"\"\n        The :class:`~tango.workspace.Workspace` being used.\n\n        This is a convenience property for you to call inside your :meth:`run()` method.\n        \"\"\"\n        if self._workspace is None:\n            raise RuntimeError(\n                \"You can only call this method while the step is running with a workspace. \"\n                \"Did you call '.run()' directly? You should only run a step with '.result()'.\"\n            )\n        return self._workspace\n\n    @property\n    def config(self) -> Dict[str, Any]:\n        \"\"\"\n        The configuration parameters that were used to construct the step. This can be empty\n        if the step was not constructed from a configuration file.\n        \"\"\"\n        if self._config is None:\n            raise ValueError(f\"No config has been assigned to this step! 
('{self.name}')\")\n        else:\n            return self._config\n\n    def det_hash_object(self) -> Any:\n        return self.unique_id\n\n    @property\n    def resources(self) -> StepResources:\n        \"\"\"\n        Defines the minimum compute resources required to run this step.\n        Certain executors require this information in order to allocate resources for each step.\n\n        You can set this with the ``step_resources`` argument to :class:`Step`\n        or you can override this method to automatically define the required resources.\n        \"\"\"\n        return self.step_resources or StepResources()\n\n    @property\n    def unique_id(self) -> str:\n        \"\"\"Returns the unique ID for this step.\n\n        Unique IDs are of the shape ``$class_name-$version-$hash``, where the hash is the hash of the\n        inputs for deterministic steps, and a random string of characters for non-deterministic ones.\n        \"\"\"\n        if self.unique_id_cache is None:\n            self.unique_id_cache = self.class_name\n            if self.VERSION is not None:\n                self.unique_id_cache += \"-\"\n                self.unique_id_cache += self.VERSION\n\n            self.unique_id_cache += \"-\"\n            if self.DETERMINISTIC:\n                hash_kwargs = {\n                    key: value\n                    for key, value in self.kwargs.items()\n                    if (key not in self.SKIP_ID_ARGUMENTS)\n                    and (\n                        (\n                            key not in self.SKIP_DEFAULT_ARGUMENTS\n                            or self.SKIP_DEFAULT_ARGUMENTS[key] != value\n                        )\n                    )\n                }\n                self.unique_id_cache += det_hash(\n                    (\n                        (self.format.__class__.__module__, self.format.__class__.__qualname__),\n                        self.format.VERSION,\n                        hash_kwargs,\n                    
)\n                )[:32]\n            else:\n                self.unique_id_cache += det_hash(\n                    _random_for_step_names.getrandbits((58**32).bit_length())\n                )[:32]\n            if self._UNIQUE_ID_SUFFIX is not None:\n                self.unique_id_cache += f\"-{self._UNIQUE_ID_SUFFIX}\"\n\n        return self.unique_id_cache\n\n    def __str__(self):\n        return self.unique_id\n\n    def __hash__(self):\n        \"\"\"\n        A step's hash is just its unique ID.\n        \"\"\"\n        return hash(self.unique_id)\n\n    def __eq__(self, other):\n        \"\"\"\n        Determines whether this step is equal to another step. Two steps with the same unique ID are\n        considered identical.\n        \"\"\"\n        if isinstance(other, Step):\n            return self.unique_id == other.unique_id\n        else:\n            return False\n\n    def _replace_steps_with_results(self, o: Any, workspace: \"Workspace\"):\n        if isinstance(o, (Step, StepIndexer)):\n            return o.result(workspace=workspace, needed_by=self)\n        elif isinstance(o, Lazy):\n            return Lazy(\n                o._constructor,\n                params=Params(\n                    self._replace_steps_with_results(o._params.as_dict(quiet=True), workspace)\n                ),\n                constructor_extras=self._replace_steps_with_results(\n                    o._constructor_extras, workspace\n                ),\n            )\n        elif isinstance(o, WithUnresolvedSteps):\n            return o.construct(workspace)\n        elif isinstance(o, (list, tuple, set)):\n            return o.__class__(self._replace_steps_with_results(i, workspace) for i in o)\n        elif isinstance(o, dict):\n            return {\n                key: self._replace_steps_with_results(value, workspace) for key, value in o.items()\n            }\n        else:\n            return o\n\n    def result(\n        self, workspace: Optional[\"Workspace\"] = 
None, needed_by: Optional[\"Step\"] = None\n    ) -> T:\n        \"\"\"Returns the result of this step. If the results are cached, it returns those. Otherwise it\n        runs the step and returns the result from there.\n\n        If necessary, this method will first produce the results of all steps it depends on.\"\"\"\n        if workspace is None:\n            from tango.workspaces import default_workspace\n\n            workspace = default_workspace\n\n        from tango.step_info import StepState\n\n        if not self.cache_results or self not in workspace.step_cache:\n            # Try running the step. It might get completed by a different tango process\n            # if there is a race, so we catch \"StepStateError\" and check if it's \"COMPLETED\"\n            # at that point.\n            try:\n                return self._run_with_work_dir(workspace, needed_by=needed_by)\n            except StepStateError as exc:\n                if exc.step_state != StepState.COMPLETED or not self.cache_results:\n                    raise\n                elif self not in workspace.step_cache:\n                    raise StepStateError(\n                        self, exc.step_state, \"because it's not found in the cache\"\n                    )\n                else:\n                    # Step has been completed (and cached) by a different process, so we're done.\n                    pass\n\n        self.log_cache_hit(needed_by=needed_by)\n        return workspace.step_cache[self]\n\n    def ensure_result(\n        self,\n        workspace: Optional[\"Workspace\"] = None,\n    ) -> None:\n        \"\"\"This makes sure that the result of this step is in the cache. 
It does\n        not return the result.\"\"\"\n        if not self.cache_results:\n            raise RuntimeError(\n                \"It does not make sense to call ensure_result() on a step that's not cacheable.\"\n            )\n\n        if workspace is None:\n            from tango.workspaces import default_workspace\n\n            workspace = default_workspace\n\n        if self in workspace.step_cache:\n            self.log_cache_hit()\n        else:\n            self.result(workspace)\n\n    def _ordered_dependencies(self) -> Iterable[\"Step\"]:\n        def dependencies_internal(o: Any) -> Iterable[Step]:\n            if isinstance(o, Step):\n                yield o\n            elif isinstance(o, Lazy):\n                yield from dependencies_internal(o._params.as_dict(quiet=True))\n            elif isinstance(o, WithUnresolvedSteps):\n                yield from dependencies_internal(o.args)\n                yield from dependencies_internal(o.kwargs)\n            elif isinstance(o, StepIndexer):\n                yield o.step\n            elif isinstance(o, str):\n                return  # Confusingly, str is an Iterable of itself, resulting in infinite recursion.\n            elif isinstance(o, (dict, Params)):\n                yield from dependencies_internal(o.values())\n            elif isinstance(o, Iterable):\n                yield from itertools.chain(*(dependencies_internal(i) for i in o))\n            else:\n                return\n\n        yield from self.extra_dependencies\n        yield from dependencies_internal(self.kwargs.values())\n\n    @property\n    def dependencies(self) -> Set[\"Step\"]:\n        \"\"\"\n        Returns a set of steps that this step depends on. This does not return recursive dependencies.\n        \"\"\"\n        return set(self._ordered_dependencies())\n\n    @property\n    def recursive_dependencies(self) -> Set[\"Step\"]:\n        \"\"\"\n        Returns a set of steps that this step depends on. 
This returns recursive dependencies.\n        \"\"\"\n        seen = set()\n        steps = list(self.dependencies)\n        while len(steps) > 0:\n            step = steps.pop()\n            if step in seen:\n                continue\n            seen.add(step)\n            steps.extend(step.dependencies)\n        return seen\n\n    def log_cache_hit(self, needed_by: Optional[\"Step\"] = None) -> None:\n        if needed_by is not None:\n            cli_logger.info(\n                '[green]\\N{check mark} Found output for step [bold]\"%s\"[/bold] in cache '\n                '(needed by \"%s\")...[/green]',\n                self.name,\n                needed_by.name,\n            )\n        else:\n            cli_logger.info(\n                '[green]\\N{check mark} Found output for step [bold]\"%s\"[/] in cache...[/]',\n                self.name,\n            )\n\n    def log_starting(self, needed_by: Optional[\"Step\"] = None) -> None:\n        if needed_by is not None:\n            cli_logger.info(\n                '[blue]\\N{black circle} Starting step [bold]\"%s\"[/] (needed by \"%s\")...[/]',\n                self.name,\n                needed_by.name,\n            )\n        else:\n            cli_logger.info(\n                '[blue]\\N{black circle} Starting step [bold]\"%s\"[/]...[/]',\n                self.name,\n            )\n\n    def log_finished(self, run_name: Optional[str] = None) -> None:\n        if run_name is not None:\n            cli_logger.info(\n                '[green]\\N{check mark} Finished run for step [bold]\"%s\"[/] (%s)[/]',\n                self.name,\n                run_name,\n            )\n        else:\n            cli_logger.info(\n                '[green]\\N{check mark} Finished step [bold]\"%s\"[/][/]',\n                self.name,\n            )\n\n    def log_failure(self, exception: Optional[BaseException] = None) -> None:\n        if exception is not None:\n            log_exception(exception, logger=self.logger)\n      
  cli_logger.error('[red]\\N{ballot x} Step [bold]\"%s\"[/] failed[/]', self.name)\n\n\nclass FunctionalStep(Step):\n    WRAPPED_FUNC: ClassVar[Callable]\n    BIND: ClassVar[bool] = False\n\n    @property\n    def class_name(self) -> str:\n        return self.WRAPPED_FUNC.__name__\n\n    def run(self, *args, **kwargs):\n        if self.BIND:\n            return self.WRAPPED_FUNC(*args, **kwargs)\n        else:\n            return self.__class__.WRAPPED_FUNC(*args, **kwargs)\n\n\ndef step(\n    name: Optional[str] = None,\n    *,\n    exist_ok: bool = False,\n    bind: bool = False,\n    deterministic: bool = True,\n    cacheable: Optional[bool] = None,\n    version: Optional[str] = None,\n    format: Format = DillFormat(\"gz\"),\n    skip_id_arguments: Optional[Set[str]] = None,\n    metadata: Optional[Dict[str, Any]] = None,\n):\n    \"\"\"\n    A decorator to create a :class:`Step` from a function.\n\n    :param name: A name to register the step under. By default the name of the function is used.\n    :param exist_ok:\n        If True, overwrites any existing step registered under the same ``name``. Else,\n        throws an error if a step is already registered under ``name``.\n    :param bind: If ``True``, the first argument passed to the step function will\n        be the underlying :class:`Step` instance, i.e. the function will be called as an instance method.\n        In this case you must name the first argument 'self' or you will get a\n        :class:`~tango.common.exceptions.ConfigurationError` when instantiating the class.\n\n    See the :class:`Step` class for an explanation of the other parameters.\n\n    Example\n    -------\n\n    .. 
testcode::\n\n        from tango import step\n\n        @step(version=\"001\")\n        def add(a: int, b: int) -> int:\n            return a + b\n\n        @step(bind=True)\n        def bound_step(self) -> None:\n            assert self.work_dir.is_dir()\n    \"\"\"\n\n    def step_wrapper(step_func):\n        @Step.register(name or step_func.__name__, exist_ok=exist_ok)\n        class WrapperStep(FunctionalStep):\n            DETERMINISTIC = deterministic\n            CACHEABLE = cacheable\n            VERSION = version\n            FORMAT = format\n            SKIP_ID_ARGUMENTS = skip_id_arguments or set()\n            METADATA = metadata or {}\n\n            WRAPPED_FUNC = step_func\n            BIND = bind\n\n        return WrapperStep\n\n    return step_wrapper\n\n\nclass StepIndexer(CustomDetHash):\n    def __init__(self, step: Step, key: Union[str, int]):\n        self.step = step\n        self.key = key\n\n    def result(\n        self, workspace: Optional[\"Workspace\"] = None, needed_by: Optional[\"Step\"] = None\n    ) -> Any:\n        return self.step.result(workspace=workspace, needed_by=needed_by)[self.key]\n\n    def det_hash_object(self) -> Any:\n        return self.step.unique_id, self.key\n\n\nclass WithUnresolvedSteps(CustomDetHash):\n    \"\"\"\n    This is a helper class for some scenarios where steps depend on other steps.\n\n    Let's say we have two steps, :class:`ConsumeDataStep` and :class:`ProduceDataStep`. The easiest way to make\n    :class:`ConsumeDataStep` depend on :class:`ProduceDataStep` is to specify ``Produce`` as one of the arguments\n    to the step. This works when ``Consume`` takes the output of ``Produce`` directly, or if it takes\n    it inside a standard Python container, like a list, set, or dictionary.\n\n    But what if the output of :class:`ProduceDataStep` needs to be added to a complex, custom data\n    structure? 
:class:`WithUnresolvedSteps` takes care of this scenario.\n\n    For example, this works without any help:\n\n    .. code-block:: Python\n\n        class ProduceDataStep(Step[MyDataClass]):\n            def run(self, ...) -> MyDataClass:\n                ...\n                return MyDataClass(...)\n\n        class ConsumeDataStep(Step):\n            def run(self, input_data: MyDataClass):\n                ...\n\n        produce = ProduceDataStep()\n        consume = ConsumeDataStep(input_data = produce)\n\n    This scenario needs help:\n\n    .. code-block:: Python\n\n        @dataclass\n        class DataWithTimestamp:\n            data: MyDataClass\n            timestamp: float\n\n        class ProduceDataStep(Step[MyDataClass]):\n            def run(self, ...) -> MyDataClass:\n                ...\n                return MyDataClass(...)\n\n        class ConsumeDataStep(Step):\n            def run(self, input_data: DataWithTimestamp):\n                ...\n\n        produce = ProduceDataStep()\n        consume = ConsumeDataStep(\n            input_data = DataWithTimestamp(produce, time.time())\n        )\n\n    That does not work, because :class:`DataWithTimestamp` needs an object of type :class:`MyDataClass`, but we're\n    giving it an object of type :class:`Step[MyDataClass]`. Instead, we change the last line to this:\n\n    .. code-block:: Python\n\n        consume = ConsumeDataStep(\n            input_data = WithUnresolvedSteps(\n                DataWithTimestamp, produce, time.time()\n            )\n        )\n\n    :class:`WithUnresolvedSteps` will delay calling the constructor of ``DataWithTimestamp`` until\n    the :meth:`run()` method runs. Tango will make sure that the results from the ``produce`` step\n    are available at that time, and replace the step in the arguments with the step's results.\n\n    :param function: The function to call after resolving steps to their results.\n    :param args: The args to pass to the function. 
These may contain steps, which will be resolved before the\n                 function is called.\n    :param kwargs: The kwargs to pass to the function. These may contain steps, which will be resolved before the\n                   function is called.\n    \"\"\"\n\n    def __init__(self, function, *args, **kwargs):\n        self.function = function\n        self.args = args\n        self.kwargs = kwargs\n\n    @classmethod\n    def with_resolved_steps(\n        cls,\n        o: Any,\n        workspace: \"Workspace\",\n    ):\n        \"\"\"\n        Recursively goes through a Python object and replaces all instances of :class:`.Step` with the results of\n        that step.\n\n        :param o: The Python object to go through\n        :param workspace: The workspace in which to resolve all steps\n        :return: A new object that's a copy of the original object, with all instances of :class:`.Step` replaced\n                 with the results of the step.\n        \"\"\"\n        if isinstance(o, (Step, StepIndexer)):\n            return o.result(workspace=workspace)\n        elif isinstance(o, Lazy):\n            return Lazy(\n                o._constructor,\n                params=Params(cls.with_resolved_steps(o._params.as_dict(quiet=True), workspace)),\n                constructor_extras=cls.with_resolved_steps(o._constructor_extras, workspace),\n            )\n        elif isinstance(o, cls):\n            return o.construct(workspace)\n        elif isinstance(o, (dict, Params)):\n            return o.__class__(\n                {key: cls.with_resolved_steps(value, workspace) for key, value in o.items()}\n            )\n        elif isinstance(o, (list, tuple, set)):\n            return o.__class__(cls.with_resolved_steps(item, workspace) for item in o)\n        else:\n            return o\n\n    def construct(self, workspace: \"Workspace\"):\n        \"\"\"\n        Replaces all steps in the args that are stored in this object, and calls the function with 
those args.\n\n        :param workspace: The :class:`.Workspace` in which to resolve all the steps.\n        :return: The result of calling the function.\n        \"\"\"\n        resolved_args = self.with_resolved_steps(self.args, workspace)\n        resolved_kwargs = self.with_resolved_steps(self.kwargs, workspace)\n        return self.function(*resolved_args, **resolved_kwargs)\n\n    def det_hash_object(self) -> Any:\n        return self.function.__qualname__, self.args, self.kwargs\n"
  },
  {
    "path": "tango/step_cache.py",
    "content": "import logging\nfrom abc import abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Any, TypeVar, Union\n\nfrom .common.from_params import FromParams\nfrom .common.registrable import Registrable\nfrom .format import Format\nfrom .step import Step\nfrom .step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\nT = TypeVar(\"T\")\n\n\nclass StepCache(Registrable):\n    \"\"\"\n    This is a mapping from instances of :class:`~tango.step.Step` to the results of that step.\n    Generally :class:`StepCache` implementations are used internally by :class:`~tango.workspace.Workspace`\n    implementations.\n    \"\"\"\n\n    default_implementation = \"memory\"\n    \"\"\"\n    The default implementation is :class:`.MemoryStepCache`.\n    \"\"\"\n\n    def __contains__(self, step: Any) -> bool:\n        \"\"\"This is a generic implementation of ``__contains__``. If you are writing your own\n        ``StepCache``, you might want to write a faster one yourself.\"\"\"\n        if not isinstance(step, (Step, StepInfo)):\n            return False\n        try:\n            self.__getitem__(step)\n            return True\n        except KeyError:\n            return False\n\n    @abstractmethod\n    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:\n        \"\"\"Returns the results for the given step.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def __setitem__(self, step: Step, value: Any) -> None:\n        \"\"\"Writes the results for the given step. 
Throws an exception if the step is already cached.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def __delitem__(self, step_unique_id: Union[Step, StepInfo]) -> None:\n        \"\"\"Removes a step from step cache\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def __len__(self) -> int:\n        \"\"\"Returns the number of results saved in this cache.\"\"\"\n        raise NotImplementedError()\n\n\n@dataclass\nclass CacheMetadata(FromParams):\n    step: str\n    \"\"\"\n    The step name.\n    \"\"\"\n\n    format: Format\n    \"\"\"\n    The format used to serialize the step's result.\n    \"\"\"\n"
  },
  {
    "path": "tango/step_caches/__init__.py",
    "content": "\"\"\"\nBuilt-in :class:`~tango.step_cache.StepCache` implementations.\n\"\"\"\n\nfrom .local_step_cache import LocalStepCache\nfrom .memory_step_cache import MemoryStepCache, default_step_cache\n"
  },
  {
    "path": "tango/step_caches/local_step_cache.py",
    "content": "import collections\nimport logging\nimport os\nimport shutil\nimport warnings\nimport weakref\nfrom pathlib import Path\nfrom typing import Any, MutableMapping, Optional, OrderedDict, Union, cast\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.params import Params\nfrom tango.step import Step\nfrom tango.step_cache import CacheMetadata, StepCache\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\n@StepCache.register(\"local\")\nclass LocalStepCache(StepCache):\n    \"\"\"\n    This is a :class:`.StepCache` that stores its results on disk, in the location given in ``dir``.\n\n    Every cached step gets a directory under ``dir`` with that step's :attr:`~tango.step.Step.unique_id`.\n    In that directory we store the results themselves in some format according to the step's\n    :attr:`~tango.step.Step.FORMAT`, and we also write a ``cache-metadata.json`` file that\n    stores the :class:`.CacheMetadata`.\n\n    The presence of ``cache-metadata.json`` signifies that the cache entry is complete and\n    has been written successfully.\n\n    .. 
tip::\n        Registered as :class:`.StepCache` under the name \"local\".\n\n    \"\"\"\n\n    LRU_CACHE_MAX_SIZE = 8\n    METADATA_FILE_NAME = \"cache-metadata.json\"\n\n    def __init__(self, dir: PathOrStr):\n        self.dir = Path(dir)\n        self.dir.mkdir(parents=True, exist_ok=True)\n\n        # We keep an in-memory cache as well so we don't have to de-serialize stuff\n        # we happen to have in memory already.\n        self.weak_cache: MutableMapping[str, Any]\n        # Not all Python objects can be referenced weakly, and even if they can they\n        # might get removed too quickly, so we also keep an LRU cache.\n        self.strong_cache: OrderedDict[str, Any]\n        self._init_mem_caches()\n\n    def _init_mem_caches(self):\n        self.weak_cache = weakref.WeakValueDictionary()\n        self.strong_cache = collections.OrderedDict()\n\n    def __getstate__(self):\n        \"\"\"\n        We override `__getstate__()` to customize how instances of this class are pickled\n        since we don't want to persist values in the weak and strong in-memory caches\n        during pickling. 
And `WeakValueDictionary` can't be pickled anyway.\n        \"\"\"\n        return {\"dir\": self.dir}\n\n    def __setstate__(self, state):\n        for k, v in state.items():\n            setattr(self, k, v)\n        self._init_mem_caches()\n\n    def _add_to_cache(self, key: str, o: Any) -> None:\n        if hasattr(o, \"__next__\"):\n            # We never cache iterators, because they are mutable, storing their current position.\n            return\n\n        self.strong_cache[key] = o\n        self.strong_cache.move_to_end(key)\n        while len(self.strong_cache) > self.LRU_CACHE_MAX_SIZE:\n            del self.strong_cache[next(iter(self.strong_cache))]\n\n        try:\n            self.weak_cache[key] = o\n        except TypeError:\n            pass  # Many native Python objects cannot be referenced weakly, and they throw TypeError when you try\n\n    def _get_from_cache(self, key: str) -> Optional[Any]:\n        result = self.strong_cache.get(key)\n        if result is not None:\n            self.strong_cache.move_to_end(key)\n            return result\n        try:\n            return self.weak_cache[key]\n        except KeyError:\n            return None\n\n    def _remove_from_cache(self, key: str) -> None:\n        # check and remove from strong cache\n        if key in self.strong_cache:\n            del self.strong_cache[key]\n            assert key not in self.strong_cache\n\n        # check and remove from weak cache\n        if key in self.weak_cache:\n            del self.weak_cache[key]\n            assert key not in self.weak_cache\n\n    def _metadata_path(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:\n        return self.step_dir(step_or_unique_id) / self.METADATA_FILE_NAME\n\n    def __contains__(self, step: object) -> bool:\n        if (isinstance(step, Step) and step.cache_results) or (\n            isinstance(step, StepInfo) and step.cacheable\n        ):\n            key = step.unique_id\n            if key in 
self.strong_cache:\n                return True\n            if key in self.weak_cache:\n                return True\n            return self._metadata_path(\n                cast(Union[Step, StepInfo], step)  # cast is for mypy :/\n            ).exists()\n        else:\n            return False\n\n    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:\n        key = step.unique_id\n        result = self._get_from_cache(key)\n        if result is None:\n            if step not in self:\n                raise KeyError(step)\n            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))\n            result = metadata.format.read(self.step_dir(step))\n            self._add_to_cache(key, result)\n        return result\n\n    def __setitem__(self, step: Step, value: Any) -> None:\n        if not step.cache_results:\n            warnings.warn(\n                f\"Tried to cache step '{step.name}' despite being marked as uncacheable\",\n                UserWarning,\n            )\n            return\n\n        location = self.step_dir(step)\n        location.mkdir(parents=True, exist_ok=True)\n\n        metadata_location = self._metadata_path(step)\n        if metadata_location.exists():\n            raise ValueError(f\"{metadata_location} already exists! 
Will not overwrite.\")\n        temp_metadata_location = metadata_location.with_suffix(\".temp\")\n\n        try:\n            step.format.write(value, location)\n            metadata = CacheMetadata(step=step.unique_id, format=step.format)\n            metadata.to_params().to_file(temp_metadata_location)\n            self._add_to_cache(step.unique_id, value)\n            temp_metadata_location.rename(metadata_location)\n        except:  # noqa: E722\n            try:\n                temp_metadata_location.unlink()\n            except FileNotFoundError:\n                pass\n            raise\n\n    def __delitem__(self, step: Union[Step, StepInfo]) -> None:\n        location = str(self.dir) + \"/\" + str(step.unique_id)\n        try:\n            shutil.rmtree(location)\n            self._remove_from_cache(step.unique_id)\n        except OSError:\n            raise OSError(f\"Step cache folder for '{step.unique_id}' not found. Cannot be deleted.\")\n\n    def __len__(self) -> int:\n        return sum(1 for _ in self.dir.glob(f\"*/{self.METADATA_FILE_NAME}\"))\n\n    def step_dir(self, step_or_unique_id: Union[Step, StepInfo, str]) -> Path:\n        \"\"\"Returns the directory that contains the results of the step.\n\n        You can use this even for a step that's not cached yet. 
In that case it will return the directory where\n        the results will be written.\"\"\"\n        if isinstance(step_or_unique_id, (Step, StepInfo)):\n            cacheable = (\n                step_or_unique_id.cache_results\n                if isinstance(step_or_unique_id, Step)\n                else step_or_unique_id.cacheable\n            )\n            if not cacheable:\n                class_name = (\n                    step_or_unique_id.class_name\n                    if isinstance(step_or_unique_id, Step)\n                    else step_or_unique_id.step_class_name\n                )\n                raise RuntimeError(\n                    f\"Uncacheable steps (like '{class_name}') don't have step directories.\"\n                )\n            unique_id = step_or_unique_id.unique_id\n        else:\n            unique_id = step_or_unique_id\n        return self.dir / unique_id\n"
  },
  {
    "path": "tango/step_caches/memory_step_cache.py",
    "content": "import logging\nimport warnings\nfrom typing import Any, Dict, Union\n\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\n@StepCache.register(\"memory\")\nclass MemoryStepCache(StepCache):\n    \"\"\"\n    This is a :class:`.StepCache` that stores results in memory. It is little more than a Python dictionary.\n\n    .. tip::\n        Registered as :class:`.StepCache` under the name \"memory\".\n    \"\"\"\n\n    def __init__(self):\n        self.cache: Dict[str, Any] = {}\n\n    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:\n        return self.cache[step.unique_id]\n\n    def __setitem__(self, step: Step, value: Any) -> None:\n        if step in self:\n            raise ValueError(f\"{step.unique_id} is already cached! Will not overwrite.\")\n        if step.cache_results:\n            self.cache[step.unique_id] = value\n        else:\n            warnings.warn(\n                f\"Tried to cache step '{step.name}' despite being marked as uncacheable.\",\n                UserWarning,\n            )\n\n    def __delitem__(self, step: Union[Step, StepInfo]) -> None:\n        if step.unique_id in self.cache:\n            del self.cache[step.unique_id]\n        else:\n            raise KeyError(f\"{step.unique_id} not present in the memory cache. Cannot be deleted.\")\n\n    def __contains__(self, step: object) -> bool:\n        if isinstance(step, (Step, StepInfo)):\n            return step.unique_id in self.cache\n        else:\n            return False\n\n    def __len__(self) -> int:\n        return len(self.cache)\n\n\ndefault_step_cache = MemoryStepCache()\n"
  },
  {
    "path": "tango/step_caches/remote_step_cache.py",
    "content": "import logging\nimport os\nimport shutil\nimport tempfile\nfrom abc import abstractmethod\nfrom pathlib import Path\nfrom typing import Any, Union\n\nfrom tango.common.aliases import PathOrStr\nfrom tango.common.exceptions import TangoError\nfrom tango.common.file_lock import FileLock\nfrom tango.common.params import Params\nfrom tango.common.remote_utils import RemoteConstants\nfrom tango.step import Step\nfrom tango.step_cache import CacheMetadata\nfrom tango.step_caches.local_step_cache import LocalStepCache\nfrom tango.step_info import StepInfo\n\nlogger = logging.getLogger(__name__)\n\n\nclass RemoteNotFoundError(TangoError):\n    \"\"\"\n    Classes inheriting from the RemoteStepCache should raise this if a step result object is not found.\n    \"\"\"\n\n\n# This class inherits from `LocalStepCache` to benefit from its in-memory \"weak cache\" and \"strong cache\",\n# but it handles saving artifacts to disk a little differently.\nclass RemoteStepCache(LocalStepCache):\n    \"\"\"\n    This is a :class:`~tango.step_cache.StepCache` that's used by :class:`RemoteWorkspace`.\n    It stores the results of steps on some RemoteWorkspace.\n\n    It also keeps a limited in-memory cache as well as a local backup on disk, so fetching a\n    step's resulting subsequent times should be fast.\n\n    .. 
tip::\n        All remote step caches inherit from this.\n    \"\"\"\n\n    Constants = RemoteConstants\n\n    def __init__(self, local_dir: Path):\n        super().__init__(local_dir)\n\n    @abstractmethod\n    def _step_result_remote(self, step: Union[Step, StepInfo]):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _upload_step_remote(self, step: Step, objects_dir: Path):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _download_step_remote(self, step_result, target_dir: PathOrStr) -> None:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def __len__(self):\n        raise NotImplementedError()\n\n    def _acquire_step_lock_file(self, step: Union[Step, StepInfo], read_only_ok: bool = False):\n        return FileLock(\n            self.step_dir(step).with_suffix(\".lock\"), read_only_ok=read_only_ok\n        ).acquire_with_updates(desc=f\"acquiring step cache lock for '{step.unique_id}'\")\n\n    def __contains__(self, step: Any) -> bool:\n        if isinstance(step, (Step, StepInfo)):\n            cacheable = step.cache_results if isinstance(step, Step) else step.cacheable\n            if not cacheable:\n                return False\n\n            key = step.unique_id\n\n            # First check if we have a copy in memory.\n            if key in self.strong_cache:\n                return True\n            if key in self.weak_cache:\n                return True\n\n            # Then check if we have a copy on disk in our cache directory.\n            with self._acquire_step_lock_file(step, read_only_ok=True):\n                if self.step_dir(step).is_dir():\n                    return True\n\n            # If not, check the remote location.\n            return self._step_result_remote(step) is not None\n        else:\n            return False\n\n    def __getitem__(self, step: Union[Step, StepInfo]) -> Any:\n        key = step.unique_id\n        step_result = self._step_result_remote(step)\n        if 
step_result is None:\n            raise KeyError(step)\n\n        # Try getting the result from our in-memory caches first.\n        result = self._get_from_cache(key)\n        if result is not None:\n            return result\n\n        def load_and_return():\n            metadata = CacheMetadata.from_params(Params.from_file(self._metadata_path(step)))\n            result = metadata.format.read(self.step_dir(step) / self.Constants.STEP_RESULT_DIR)\n            self._add_to_cache(key, result)\n            return result\n\n        # Next check our local on-disk cache.\n        with self._acquire_step_lock_file(step, read_only_ok=True):\n            if self.step_dir(step).is_dir():\n                return load_and_return()\n\n        # Finally, check the remote location for the corresponding dataset.\n        with self._acquire_step_lock_file(step):\n            # Make sure the step wasn't cached since the last time we checked (above).\n            if self.step_dir(step).is_dir():\n                return load_and_return()\n\n            # We'll download the dataset to a temporary directory first, in case something goes wrong.\n            temp_dir = tempfile.mkdtemp(dir=self.dir, prefix=key)\n            try:\n                self._download_step_remote(step_result, target_dir=temp_dir)\n                # Download and extraction was successful, rename temp directory to final step result directory.\n                os.replace(temp_dir, self.step_dir(step))\n            except RemoteNotFoundError:\n                raise KeyError(step)\n            finally:\n                shutil.rmtree(temp_dir, ignore_errors=True)\n\n            return load_and_return()\n\n    def __setitem__(self, step: Step, value: Any) -> None:\n        if not step.cache_results:\n            logger.warning(\"Tried to cache step %s despite being marked as uncacheable.\", step.name)\n            return\n\n        with self._acquire_step_lock_file(step):\n            # We'll write the step's results 
to temporary directory first, and try to upload to\n            # remote workspace from there in case anything goes wrong.\n            temp_dir = Path(tempfile.mkdtemp(dir=self.dir, prefix=step.unique_id))\n            (temp_dir / self.Constants.STEP_RESULT_DIR).mkdir()\n            try:\n                step.format.write(value, temp_dir / self.Constants.STEP_RESULT_DIR)\n                metadata = CacheMetadata(step=step.unique_id, format=step.format)\n                metadata.to_params().to_file(temp_dir / self.METADATA_FILE_NAME)\n                # Create the dataset and upload serialized result to it.\n                self._upload_step_remote(step, temp_dir)\n                # Upload successful, rename temp directory to the final step result directory.\n                if self.step_dir(step).is_dir():\n                    shutil.rmtree(self.step_dir(step), ignore_errors=True)\n                os.replace(temp_dir, self.step_dir(step))\n            finally:\n                shutil.rmtree(temp_dir, ignore_errors=True)\n\n        # Finally, add to in-memory caches.\n        self._add_to_cache(step.unique_id, value)\n"
  },
  {
    "path": "tango/step_graph.py",
    "content": "import logging\nfrom typing import Any, Dict, Iterator, List, Mapping, Set, Type, Union\n\nfrom tango.common import PathOrStr\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.params import Params\nfrom tango.step import Step, StepIndexer\n\nlogger = logging.getLogger(__name__)\n\n\nclass StepGraph(Mapping[str, Step]):\n    \"\"\"\n    Represents an experiment as a directed graph.\n\n    It can be treated as a :class:`~collections.abc.Mapping` of step names (``str``)\n    to :class:`Step`.\n    \"\"\"\n\n    def __init__(self, step_dict: Dict[str, Step]):\n        # TODO: What happens with anonymous steps in here?\n\n        is_ordered = self._is_ordered(step_dict)\n        if not is_ordered:\n            self.parsed_steps = {step.name: step for step in self.ordered_steps(step_dict)}\n        else:\n            self.parsed_steps = {}\n            for step_name, step in step_dict.items():\n                step.name = step_name\n                self.parsed_steps[step_name] = step\n\n        # Sanity-check the graph\n        self._sanity_check()\n\n    @classmethod\n    def _is_ordered(cls, step_dict: Dict[str, Step]):\n        present = set()\n        for _, step in step_dict.items():\n            for dep in step.dependencies:\n                if dep.name not in present:\n                    return False\n            present.add(step.name)\n        return True\n\n    @classmethod\n    def _check_unsatisfiable_dependencies(cls, dependencies: Dict[str, Set[str]]) -> None:\n        # Check whether some of those dependencies can never be satisfied.\n        unsatisfiable_dependencies = {\n            dep\n            for step_deps in dependencies.values()\n            for dep in step_deps\n            if dep not in dependencies.keys()\n        }\n        if len(unsatisfiable_dependencies) > 0:\n            if len(unsatisfiable_dependencies) == 1:\n                dep = next(iter(unsatisfiable_dependencies))\n                raise 
ConfigurationError(\n                    f\"Specified dependency '{dep}' can't be found in the config.\"\n                )\n            else:\n                raise ConfigurationError(\n                    f\"Some dependencies can't be found in the config: {', '.join(unsatisfiable_dependencies)}\"\n                )\n\n    @classmethod\n    def _get_ordered_steps(cls, dependencies: Dict[str, Set[str]]) -> List[str]:\n        done: Set[str] = set()\n        todo = list(dependencies.keys())\n        ordered_steps = list()\n        while len(todo) > 0:\n            new_todo = []\n            for step_name in todo:\n                if len(dependencies[step_name] & done) == len(dependencies[step_name]):\n                    done.add(step_name)\n                    ordered_steps.append(step_name)\n                else:\n                    new_todo.append(step_name)\n            if len(todo) == len(new_todo):\n                raise ConfigurationError(\n                    \"Could not make progress parsing the steps. \"\n                    \"You probably have a circular reference between the steps, \"\n                    \"or a missing dependency.\"\n                )\n            todo = new_todo\n        del dependencies\n        del done\n        del todo\n        return ordered_steps\n\n    def _sanity_check(self) -> None:\n        for step in self.parsed_steps.values():\n            if step.cache_results:\n                nondeterministic_dependencies = [\n                    s for s in step.recursive_dependencies if not s.DETERMINISTIC\n                ]\n                if len(nondeterministic_dependencies) > 0:\n                    nd_step = nondeterministic_dependencies[0]\n                    logger.warning(\n                        f\"Task {step.name} is set to cache results, but depends on non-deterministic \"\n                        f\"step {nd_step.name}. 
This will produce confusing results.\"\n                    )\n\n    @classmethod\n    def from_params(cls: Type[\"StepGraph\"], params: Dict[str, Params]) -> \"StepGraph\":  # type: ignore[override]\n        # Determine the order in which to create steps so that all dependent steps are available when we need them.\n        # This algorithm for resolving step dependencies is O(n^2). Since we're\n        # anticipating the number of steps in a single config to be in the dozens at most (#famouslastwords),\n        # we choose simplicity over cleverness.\n        dependencies = {\n            step_name: cls._find_step_dependencies(step_params)\n            for step_name, step_params in params.items()\n        }\n        cls._check_unsatisfiable_dependencies(dependencies)\n\n        # We need ordered dependencies to construct the steps with refs.\n        ordered_steps = cls._get_ordered_steps(dependencies)\n\n        # Parse the steps\n        step_dict: Dict[str, Step] = {}\n        for step_name in ordered_steps:\n            step_params = params.pop(step_name)\n            if step_name in step_dict:\n                raise ConfigurationError(f\"Duplicate step name {step_name}\")\n\n            step_params = cls._replace_step_dependencies(step_params, step_dict)\n            step_dict[step_name] = Step.from_params(step_params, step_name=step_name)\n\n        return cls(step_dict)\n\n    def sub_graph(self, *step_names: str) -> \"StepGraph\":\n        step_dict: Dict[str, Step] = {}\n        for step_name in step_names:\n            if step_name not in self.parsed_steps:\n                raise KeyError(\n                    f\"{step_name} is not a part of this StepGraph. 
\"\n                    f\"Available steps are: {list(self.parsed_steps.keys())}\"\n                )\n            step_dict.update(\n                {dep.name: dep for dep in self.parsed_steps[step_name].recursive_dependencies}\n            )\n            step_dict[step_name] = self.parsed_steps[step_name]\n        return StepGraph(step_dict)\n\n    @staticmethod\n    def _dict_is_ref(d: Union[dict, Params]) -> bool:\n        keys = set(d.keys())\n        if keys == {\"ref\"}:\n            return True\n        if keys >= {\"type\", \"ref\"} and d[\"type\"] == \"ref\":\n            return True\n        return False\n\n    @classmethod\n    def _find_step_dependencies(cls, o: Any) -> Set[str]:\n        dependencies: Set[str] = set()\n        if isinstance(o, (list, tuple, set)):\n            for item in o:\n                dependencies = dependencies | cls._find_step_dependencies(item)\n        elif isinstance(o, (dict, Params)):\n            if cls._dict_is_ref(o):\n                dependencies.add(o[\"ref\"])\n            else:\n                for value in o.values():\n                    dependencies = dependencies | cls._find_step_dependencies(value)\n        elif o is not None and not isinstance(o, (bool, str, int, float)):\n            raise ValueError(o)\n        return dependencies\n\n    @classmethod\n    def _replace_step_dependencies(cls, o: Any, existing_steps: Mapping[str, Step]) -> Any:\n        if isinstance(o, (list, tuple, set)):\n            return o.__class__(cls._replace_step_dependencies(i, existing_steps) for i in o)\n        elif isinstance(o, (dict, Params)):\n            if cls._dict_is_ref(o):\n                if \"key\" in o:\n                    return StepIndexer(existing_steps[o[\"ref\"]], o[\"key\"])\n                else:\n                    return existing_steps[o[\"ref\"]]\n            else:\n                result = {\n                    key: cls._replace_step_dependencies(value, existing_steps)\n                    for key, 
value in o.items()\n                }\n                if isinstance(o, dict):\n                    return result\n                elif isinstance(o, Params):\n                    return Params(result, history=o.history)\n                else:\n                    raise RuntimeError(f\"Object {o} is of unexpected type {o.__class__}.\")\n        elif o is not None and not isinstance(o, (bool, str, int, float)):\n            raise ValueError(o)\n        return o\n\n    def __getitem__(self, name: str) -> Step:\n        \"\"\"\n        Get the step with the given name.\n        \"\"\"\n        return self.parsed_steps[name]\n\n    def __len__(self) -> int:\n        \"\"\"\n        The number of steps in the experiment.\n        \"\"\"\n        return len(self.parsed_steps)\n\n    def __iter__(self) -> Iterator[str]:\n        \"\"\"\n        The names of the steps in the experiment.\n        \"\"\"\n        return iter(self.parsed_steps)\n\n    @classmethod\n    def ordered_steps(cls, step_dict: Dict[str, Step]) -> List[Step]:\n        \"\"\"\n        Returns the steps in this step graph in an order that can be executed one at a time.\n\n        This does not take into account which steps may be cached. 
It simply returns an executable\n        order of steps.\n        \"\"\"\n        dependencies = {\n            step_name: set([dep.name for dep in step.dependencies])\n            for step_name, step in step_dict.items()\n        }\n        result: List[Step] = []\n        for step_name in cls._get_ordered_steps(dependencies):\n            step_dict[step_name].name = step_name\n            result.append(step_dict[step_name])\n        return result\n\n    def uncacheable_leaf_steps(self) -> Set[Step]:\n        interior_steps: Set[Step] = set()\n        for _, step in self.parsed_steps.items():\n            for dependency in step.dependencies:\n                interior_steps.add(dependency)\n        uncacheable_leaf_steps = {\n            step for step in set(self.values()) - interior_steps if not step.cache_results\n        }\n        return uncacheable_leaf_steps\n\n    @classmethod\n    def from_file(cls, filename: PathOrStr) -> \"StepGraph\":\n        params = Params.from_file(filename)\n        return StepGraph.from_params(params.pop(\"steps\", keep_as_dict=True))\n\n    def to_config(self, include_unique_id: bool = False) -> Dict[str, Dict]:\n        step_dict = {}\n\n        def _to_config(o: Any):\n            if isinstance(o, (list, tuple, set)):\n                return o.__class__(_to_config(i) for i in o)\n            elif isinstance(o, dict):\n                return {key: _to_config(value) for key, value in o.items()}\n            elif isinstance(o, Step):\n                return {\"type\": \"ref\", \"ref\": o.name}\n            elif isinstance(o, StepIndexer):\n                return {\"type\": \"ref\", \"ref\": o.step.name, \"key\": o.key}\n            elif o is not None and not isinstance(o, (bool, str, int, float)):\n                raise ValueError(o)\n            return o\n\n        for step_name, step in self.parsed_steps.items():\n            try:\n                step_dict[step_name] = {\n                    key: _to_config(value) for key, value 
in step.config.items()\n                }\n            except ValueError:  # step.config throws an error.\n                # If the step_graph was not constructed using a config, we attempt to create\n                # the config using the step object.\n                step_dict[step_name] = {\n                    key: _to_config(val) for key, val in step._to_params()[\"kwargs\"].items()\n                }\n                step_dict[step_name][\"type\"] = step.__module__ + \".\" + step.class_name\n\n                # We only add cache_results and format to the config if the values are different from default.\n                if step.cache_results != step.CACHEABLE:\n                    step_dict[step_name][\"cache_results\"] = step.cache_results\n                if step.format != step.FORMAT:\n                    step_dict[step_name][\"step_format\"] = _to_config(step.format._to_params())\n\n            if include_unique_id:\n                step_dict[step_name][\"step_unique_id_override\"] = step.unique_id\n\n        return step_dict\n\n    def to_file(self, filename: PathOrStr, include_unique_id: bool = False) -> None:\n        \"\"\"\n        Note: In normal use cases, `include_unique_id` should always be False.\n        We do not want to save the unique id in the config, because we want the\n        output to change if we modify other kwargs in the config file. We include\n        this flag for `MulticoreExecutor`, to ensure that steps have the same\n        unique id in the main process and the created subprocesses.\n        \"\"\"\n        step_dict = self.to_config(include_unique_id=include_unique_id)\n        params = Params({\"steps\": step_dict})\n        params.to_file(filename)\n\n    def __repr__(self) -> str:\n        result = [f'\"{name}\": {step}' for name, step in self.items()]\n        result = \", \".join(result)\n        return f\"{self.__class__.__name__}({result})\"\n"
  },
  {
    "path": "tango/step_info.py",
    "content": "import getpass\nimport logging\nimport os\nimport platform\nimport socket\nimport sys\nfrom dataclasses import dataclass, field\nfrom datetime import datetime, timedelta\nfrom pathlib import Path\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nimport pytz\n\nfrom .common.from_params import FromParams\nfrom .common.logging import log_exception\nfrom .common.util import StrEnum, jsonify, local_timezone, replace_steps_with_unique_id\nfrom .step import Step\nfrom .version import VERSION\n\nlogger = logging.getLogger(__name__)\n\n\ndef get_pip_packages() -> Optional[List[Tuple[str, str]]]:\n    \"\"\"\n    Get the current working set of pip packages. Equivalent to running ``pip freeze``.\n    \"\"\"\n    # Adapted from the Weights & Biases client library:\n    # github.com/wandb/client/blob/a04722575eee72eece7eef0419d0cea20940f9fe/wandb/sdk/internal/meta.py#L56-L72\n    try:\n        import pkg_resources\n\n        return sorted([(d.key, d.version) for d in iter(pkg_resources.working_set)])\n    except Exception as exc:\n        logger.error(\"Error saving pip packages\")\n        log_exception(exc)\n    return None\n\n\nclass StepState(StrEnum):\n    \"\"\"Describes the possible state a step can be in.\"\"\"\n\n    INCOMPLETE = \"incomplete\"\n    \"\"\"The step has not run yet.\"\"\"\n\n    RUNNING = \"running\"\n    \"\"\"The step is running right now.\"\"\"\n\n    COMPLETED = \"completed\"\n    \"\"\"The step finished running successfully.\"\"\"\n\n    FAILED = \"failed\"\n    \"\"\"The step ran, but failed.\"\"\"\n\n    UNCACHEABLE = \"uncacheable\"\n    \"\"\"The step is uncacheable. 
It will be executed as many times as the results are needed,\n    so we don't keep track of the state.\"\"\"\n\n\n@dataclass\nclass GitMetadata(FromParams):\n    commit: Optional[str] = None\n    \"\"\"\n    The commit SHA of the current repo.\n    \"\"\"\n\n    remote: Optional[str] = None\n    \"\"\"\n    The URL of the primary remote.\n    \"\"\"\n\n    @classmethod\n    def check_for_repo(cls) -> Optional[\"GitMetadata\"]:\n        from git import InvalidGitRepositoryError, Repo\n\n        try:\n            repo = Repo(\".\")\n        except InvalidGitRepositoryError:\n            return None\n\n        return cls(commit=str(repo.commit()), remote=repo.remote().url)\n\n\n@dataclass\nclass TangoMetadata(FromParams):\n    version: str = VERSION\n    \"\"\"\n    The tango release version.\n    \"\"\"\n\n\n@dataclass\nclass PlatformMetadata(FromParams):\n    operating_system: str = field(default_factory=platform.platform)\n    \"\"\"\n    Full operating system name.\n    \"\"\"\n\n    cpu_count: Optional[int] = field(default_factory=os.cpu_count)\n    \"\"\"\n    Number of CPUs on the machine.\n    \"\"\"\n\n    user: str = field(default_factory=getpass.getuser)\n    \"\"\"\n    The user that ran this step.\n    \"\"\"\n\n    host: str = field(default_factory=socket.gethostname)\n    \"\"\"\n    Name of the host machine.\n    \"\"\"\n\n\n@dataclass\nclass EnvironmentMetadata(FromParams):\n    python: str = field(default_factory=platform.python_version)\n    \"\"\"\n    The Python version.\n    \"\"\"\n\n    executable: Path = field(default_factory=lambda: Path(sys.executable))\n    \"\"\"\n    Path to the Python executable.\n    \"\"\"\n\n    command: str = field(default_factory=lambda: \" \".join(sys.argv))\n    \"\"\"\n    The exact command used.\n    \"\"\"\n\n    root: Path = field(default_factory=lambda: Path(os.getcwd()))\n    \"\"\"\n    The root directory from which the Python executable was run.\n    \"\"\"\n\n    packages: Optional[List[Tuple[str, str]]] 
= field(default_factory=get_pip_packages)\n    \"\"\"\n    The current set of Python packages in the Python environment. Each entry is a tuple of strings.\n    The first element is the name of the package, the second element is the version.\n    \"\"\"\n\n    git: Optional[GitMetadata] = field(default_factory=GitMetadata.check_for_repo)\n    \"\"\"\n    The :class:`GitMetadata`.\n    \"\"\"\n\n    tango: Optional[TangoMetadata] = field(default_factory=TangoMetadata)\n    \"\"\"\n    The :class:`TangoMetadata`.\n    \"\"\"\n\n\n@dataclass\nclass StepInfo(FromParams):\n    \"\"\"Stores step information without being the :class:`.Step` itself.\n\n    It's not always possible to get a :class:`.Step` object, because :class:`.Step` objects can't be serialized.\n    But you can always serialize a :class:`.StepInfo` object.\n    \"\"\"\n\n    unique_id: str\n    \"\"\"\n    The unique ID of the step\n    \"\"\"\n\n    step_class_name: str\n    \"\"\"\n    The name of the :class:`.Step` class\n    \"\"\"\n\n    dependencies: Set[str]\n    \"\"\"\n    The unique ids of all the steps that this step depends on\n    \"\"\"\n\n    cacheable: bool\n    \"\"\"\n    Whether or not the step is cacheable.\n    \"\"\"\n\n    step_name: Optional[str] = None\n    \"\"\"\n    The name of the step, if it has one. Anonymous steps are identified only by their unique ID.\n\n    The same step can have different names in different runs. The last run wins, so don't rely\n    on this property in your code. It is just here to aid readability.\n    \"\"\"\n\n    version: Optional[str] = None\n    \"\"\"\n    The version string of the :class:`.Step`, if it has one.\n    \"\"\"\n\n    start_time: Optional[datetime] = None\n    \"\"\"\n    The time (in UTC) that this step started running.\n\n    .. seealso::\n        :meth:`start_time_local()`.\n    \"\"\"\n\n    end_time: Optional[datetime] = None\n    \"\"\"\n    The time (in UTC) that this step stopped running. 
This will be set whether the step succeeded or failed.\n\n    .. seealso::\n        :meth:`end_time_local()`.\n    \"\"\"\n\n    error: Optional[str] = None\n    \"\"\"\n    If the step failed, this is where the error goes.\n\n    .. note::\n        Some ``Workspace`` implementations need to serialize ``StepInfo`` (using pickle or dill, for example),\n        but some exceptions can't be pickled. In those cases ``error`` will just be a string representation\n        of the exception.\n    \"\"\"\n\n    result_location: Optional[str] = None\n    \"\"\"\n    Location of the result. This could be a path or a URL.\n    \"\"\"\n\n    config: Optional[Dict[str, Any]] = None\n    \"\"\"\n    The raw config of the step.\n    \"\"\"\n\n    metadata: Optional[Dict[str, Any]] = None\n    \"\"\"\n    Metadata from the step. This comes from the ``step_metadata``\n    argument to the :class:`~tango.step.Step` class.\n    \"\"\"\n\n    platform: PlatformMetadata = field(default_factory=PlatformMetadata)\n    \"\"\"\n    The :class:`PlatformMetadata`.\n    \"\"\"\n\n    environment: EnvironmentMetadata = field(default_factory=EnvironmentMetadata)\n    \"\"\"\n    The :class:`EnvironmentMetadata`.\n    \"\"\"\n\n    @property\n    def start_time_local(self) -> Optional[datetime]:\n        \"\"\"\n        The time the step started running with respect to the local timezone, if the timezone\n        can be determined.\n        \"\"\"\n        return None if self.start_time is None else self.start_time.astimezone(local_timezone())\n\n    @property\n    def end_time_local(self) -> Optional[datetime]:\n        \"\"\"\n        The time the step stopped running with respect to the local timezone, if the timezone\n        can be determined.\n        \"\"\"\n        return None if self.end_time is None else self.end_time.astimezone(local_timezone())\n\n    @property\n    def duration(self) -> Optional[timedelta]:\n        \"\"\"\n        The time it took to run this step.\n        \"\"\"\n  
      if self.start_time is not None and self.end_time is not None:\n            return self.end_time - self.start_time\n        else:\n            return None\n\n    @property\n    def state(self) -> StepState:\n        \"\"\"\n        Returns the state of the step\n        \"\"\"\n        if self.cacheable:\n            if self.start_time is None and self.end_time is None and self.error is None:\n                return StepState.INCOMPLETE\n            if self.start_time is not None and self.end_time is None and self.error is None:\n                return StepState.RUNNING\n            if self.start_time is not None and self.end_time is not None and self.error is None:\n                return StepState.COMPLETED\n            if self.start_time is not None and self.end_time is not None and self.error is not None:\n                return StepState.FAILED\n        else:\n            return StepState.UNCACHEABLE\n        raise RuntimeError(f\"{self.__class__.__name__} is in an invalid state.\")\n\n    def to_json_dict(self) -> Dict[str, Any]:\n        \"\"\"\n        Generates a JSON-safe, human-readable, dictionary representation of this dataclass.\n        \"\"\"\n        return jsonify(self)\n\n    @classmethod\n    def from_json_dict(cls, json_dict: Dict[str, Any]) -> \"StepInfo\":\n        \"\"\"\n        The inverse of :meth:`to_json_dict()`.\n\n        :param json_dict: A dictionary representation, such as the one produced by :meth:`to_json_dict()`.\n        \"\"\"\n        step_info = cls.from_params(\n            {\n                k: (\n                    datetime.strptime(v, \"%Y-%m-%dT%H:%M:%S\").replace(tzinfo=pytz.utc)\n                    if k in {\"start_time\", \"end_time\"} and v is not None\n                    else v\n                )\n                for k, v in json_dict.items()\n                if k != \"config\"\n            }\n        )\n        step_info.config = json_dict.get(\"config\")\n        return step_info\n\n    @classmethod\n    
def new_from_step(cls, step: Step, **kwargs) -> \"StepInfo\":\n        try:\n            config = step.config\n        except ValueError:\n            config = None\n        return cls(\n            unique_id=step.unique_id,\n            step_name=step.name,\n            step_class_name=step.class_name,\n            version=step.VERSION,\n            dependencies={dep.unique_id for dep in step.dependencies},\n            cacheable=step.cache_results,\n            config=replace_steps_with_unique_id(config),\n            metadata=step.metadata,\n            **kwargs,\n        )\n\n    def refresh(self):\n        \"\"\"\n        Refresh environment and platform metadata.\n        \"\"\"\n        self.platform = PlatformMetadata()\n        self.environment = EnvironmentMetadata()\n"
  },
  {
    "path": "tango/steps/__init__.py",
    "content": "\"\"\"\nBuilt-in :class:`~tango.step.Step` implementations that are not tied to any particular\nintegration.\n\"\"\"\n\n__all__ = [\"DatasetCombineStep\", \"DatasetRemixStep\", \"PrintStep\", \"ShellStep\"]\n\nfrom .dataset_remix import DatasetCombineStep, DatasetRemixStep\nfrom .print import PrintStep\nfrom .shell_step import ShellStep\n"
  },
  {
    "path": "tango/steps/dataset_remix.py",
    "content": "import collections\nimport random\nimport re\nfrom typing import Any, Dict, List, Mapping, Sequence\n\nfrom tango.common.dataset_dict import DatasetDict\nfrom tango.common.sequences import (\n    ConcatenatedSequence,\n    ShuffledSequence,\n    SlicedSequence,\n)\nfrom tango.step import Step\n\n\n@Step.register(\"dataset_remix\")\nclass DatasetRemixStep(Step[DatasetDict]):\n    \"\"\"\n    This step can remix splits in a :class:`~tango.common.dataset_dict.DatasetDict` into new splits.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"dataset_remix\".\n\n    Examples\n    --------\n\n    .. testcode::\n        :hide:\n\n        from tango.common.logging import initialize_logging\n        initialize_logging(enable_cli_logs=True)\n\n    .. testcode::\n\n        input = DatasetDict({\n            \"train\": list(range(10)),\n            \"dev\": list(range(10, 15)),\n        })\n        new_splits = {\n            \"all\": \"train + dev\",\n            \"crossval_train\": \"train[0:5] + train[7:]\",\n            \"crossval_test\": \"train[5:7]\",\n        }\n        remix_step = DatasetRemixStep(input=input, new_splits=new_splits)\n        remixed_dataset = remix_step.result()\n\n    .. testoutput::\n        :hide:\n        :options: +ELLIPSIS\n\n        ...\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = False  # This is so fast it's not worth caching.\n    VERSION = \"001\"\n\n    def run(  # type: ignore\n        self,\n        input: DatasetDict,\n        new_splits: Dict[str, str],\n        keep_old_splits: bool = True,\n        shuffle_before: bool = False,\n        shuffle_after: bool = False,\n        random_seed: int = 1532637578,\n    ) -> DatasetDict:\n        \"\"\"\n        Remixes and shuffles a dataset. 
This is done lazily, so all operations are fast.\n\n        :param input:\n            The input dataset that will be remixed.\n        :param new_splits:\n            Specifies the new splits that the output dataset should have. Keys are the names of the new\n            splits. Values refer to the original splits. You can refer to original splits in the following ways:\n\n            * Mention the original split name to copy it to a new name.\n            * Mention the original split name with Python's slicing syntax to select part of the original\n              split's instances. For example, ``\"train[:1000]\"`` selects the first 1000 instances from the\n              ``\"train\"`` split.\n            * ``\"instances + instances\"`` concatenates the instances into one split.\n\n            You can combine these possibilities.\n        :param keep_old_splits:\n            Whether to keep the splits from the input dataset in addition to the new ones given by\n            ``new_splits``.\n        :param shuffle_before:\n            Whether to shuffle the input splits before creating the new ones.\n\n            If you need shuffled instances and you're not sure the input is properly shuffled, use this.\n        :param shuffle_after:\n            Whether to shuffle the output splits after creating the new ones.\n\n            If you need shuffled instances and you're slicing or concatenating splits, use this.\n\n            If you want to be on the safe side, shuffle both before and after. 
Shuffling is a cheap operation.\n        :param random_seed:\n            Random seed, affects shuffling\n\n        :returns:\n            Returns a new dataset that is appropriately remixed.\n        \"\"\"\n        random.seed(random_seed)\n\n        if shuffle_before:\n            input_splits: Mapping[str, Sequence[Any]] = {\n                split_name: ShuffledSequence(split_instances)\n                for split_name, split_instances in input.splits.items()\n            }\n        else:\n            input_splits = input.splits\n\n        def get_slice(split_name: str) -> Sequence[Any]:\n            # Read from input_splits (not input) so that shuffle_before also applies to the new splits.\n            slice_match = re.match(r\"(.*)\\[(-?[0-9]*:-?[0-9]*)\\]\", split_name)\n            if slice_match is None:\n                return input_splits[split_name]\n            else:\n                split_name = slice_match[1]\n                slice_args = [int(a) if len(a) > 0 else None for a in slice_match[2].split(\":\")]\n                return SlicedSequence(input_splits[split_name], slice(*slice_args))\n\n        def parse_split_spec(split_spec: str):\n            parts = [get_slice(name.strip()) for name in split_spec.split(\"+\")]\n            if len(parts) == 1:\n                return parts[0]\n            else:\n                return ConcatenatedSequence(*parts)\n\n        if keep_old_splits:\n            result = dict(input_splits.items())\n        else:\n            result = {}\n        result.update(\n            {\n                new_split_name: parse_split_spec(new_split_spec)\n                for new_split_name, new_split_spec in new_splits.items()\n            }\n        )\n\n        if shuffle_after:\n            result = {\n                split_name: ShuffledSequence(split_instances)\n                for split_name, split_instances in result.items()\n            }\n\n        return DatasetDict(splits=result, metadata=input.metadata)\n\n\n@Step.register(\"dataset_combine\")\nclass DatasetCombineStep(Step[DatasetDict]):\n    \"\"\"\n    This step combines multiple 
:class:`~tango.common.dataset_dict.DatasetDict` s into one.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"dataset_combine\".\n\n    Examples\n    --------\n\n    .. testcode::\n        :hide:\n\n        from tango.common.logging import initialize_logging\n        initialize_logging(enable_cli_logs=True)\n\n    .. testcode::\n\n        input1 = DatasetDict({\n            \"train\": list(range(10)),\n            \"dev\": list(range(10, 15)),\n        })\n        input2 = DatasetDict({\n            \"train\": list(range(15, 25)),\n            \"val\": list(range(25, 30)),\n        })\n        combined = DatasetCombineStep(inputs=[input1, input2])\n        combined_dataset = combined.result()\n\n    .. testoutput::\n        :hide:\n        :options: +ELLIPSIS\n\n        ...\n\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = False  # This is so fast it's not worth caching.\n    VERSION = \"001\"\n\n    def run(  # type: ignore\n        self,\n        inputs: List[DatasetDict],\n        shuffle: bool = False,\n        random_seed: int = 1532637578,\n    ) -> DatasetDict:\n        \"\"\"\n        Combines multiple datasets into one. This is done lazily, so all operations are fast.\n\n        If a split is present in more than one input dataset, the output dataset will have a split that's\n        the concatenation of the input splits.\n\n        :param inputs:\n            The list of input datasets that will be combined.\n        :param shuffle:\n            Whether to shuffle the combined datasets. 
If you don't do this, the new splits will contain first\n            all the instances from one dataset, and then all the instances from another dataset.\n        :param random_seed:\n            Random seed, affects shuffling\n\n        :returns:\n            Returns a new dataset that is the combination of the input datasets.\n        \"\"\"\n\n        split_to_datasets: Dict[str, List[Sequence]] = collections.defaultdict(lambda: [])\n        for input in inputs:\n            for split_name, sequence in input.items():\n                split_to_datasets[split_name].append(sequence)\n        result: Dict[str, Sequence] = {\n            split_name: ConcatenatedSequence(*sequences)\n            for split_name, sequences in split_to_datasets.items()\n        }\n\n        if shuffle:\n            random.seed(random_seed)\n            result = {\n                split_name: ShuffledSequence(split_instances)\n                for split_name, split_instances in result.items()\n            }\n\n        return DatasetDict(result, {})\n"
  },
  {
    "path": "tango/steps/print.py",
    "content": "import logging\nfrom typing import Any\n\nfrom tango.common.logging import cli_logger\nfrom tango.step import Step\n\n\n@Step.register(\"print\")\nclass PrintStep(Step):\n    \"\"\"\n    This step just logs or prints its input and also returns what it prints.\n    \"\"\"\n\n    DETERMINISTIC = True\n    CACHEABLE = False  # so fast it's not worth caching\n\n    def run(self, input: Any) -> str:  # type: ignore[override]\n        \"\"\"\n        Print out the input.\n        \"\"\"\n        out = str(input)\n        if self.logger.isEnabledFor(logging.INFO):\n            self.logger.info(out)\n        elif cli_logger.isEnabledFor(logging.INFO):\n            cli_logger.info(out)\n        else:\n            print(out)\n        return out\n"
  },
  {
    "path": "tango/steps/shell_step.py",
    "content": "import os\nimport subprocess\nfrom typing import List, Optional, Union\n\nfrom tango.common import PathOrStr, RegistrableFunction, make_registrable\nfrom tango.step import Step\n\n\n@make_registrable(exist_ok=True)\ndef check_path_existence(path: PathOrStr):\n    assert os.path.exists(path), f\"Output not found at {path}!\"\n\n\n@Step.register(\"shell_step\")\nclass ShellStep(Step):\n    \"\"\"\n    This step runs a shell command, and returns the standard output as a string.\n\n    .. tip::\n\n        Registered as a :class:`~tango.step.Step` under the name \"shell_step\".\n\n    :param shell_command: The shell command to run.\n    :param output_path: The step makes no assumptions about the command being run. If your command produces some\n        output, you can optionally specify the output path for recording the output location, and optionally\n        validating it. See `validate_output` argument for this.\n    :param validate_output: If an expected `output_path` has been specified, you can choose to validate that the\n        step produced the correct output. By default, it will just check if the `output_path` exists, but you can\n        pass any other validating function. For example, if your command is a script generating a model output,\n        you can check if the model weights can be loaded.\n    :param kwargs: Other kwargs to be passed to `subprocess.run()`. 
If you need to take advantage of environment\n        variables, set `shell = True`.\n    \"\"\"\n\n    def run(  # type: ignore[override]\n        self,\n        shell_command: Union[str, List[str]],\n        output_path: Optional[PathOrStr] = None,\n        validate_output: RegistrableFunction = check_path_existence,\n        **kwargs,\n    ):\n        output = self.run_command(shell_command, **kwargs)\n        self.logger.info(output)\n        if output_path is not None:\n            validate_output(output_path)\n            self.logger.info(f\"Output found at: {output_path}\")\n\n        return str(output.decode(\"utf-8\"))\n\n    def run_command(self, command: Union[str, List[str]], **kwargs):\n        import shlex\n\n        if kwargs.get(\"shell\", None):\n            if isinstance(command, list):\n                command = shlex.join(command)\n        else:\n            if isinstance(command, str):\n                command = shlex.split(command)\n        self.logger.info(f\"Command: {command}\")\n        process = subprocess.run(command, capture_output=True, **kwargs)\n        if process.returncode != 0:\n            raise RuntimeError(f\"The command failed with the following errors: {process.stderr}\")\n        return process.stdout\n"
  },
  {
    "path": "tango/version.py",
    "content": "_MAJOR = \"1\"\n_MINOR = \"3\"\n_PATCH = \"2\"\n# This is mainly for pre-releases which have the suffix \"rc[0-9]+\".\n_SUFFIX = \"\"\n\nVERSION_SHORT = \"{0}.{1}\".format(_MAJOR, _MINOR)\nVERSION = \"{0}.{1}.{2}{3}\".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)\n"
  },
  {
    "path": "tango/workspace.py",
    "content": "import logging\nfrom abc import abstractmethod\nfrom contextlib import contextmanager\nfrom dataclasses import dataclass\nfrom datetime import datetime\nfrom pathlib import Path\nfrom tempfile import TemporaryDirectory\nfrom typing import (\n    Any,\n    ContextManager,\n    Dict,\n    Generator,\n    Iterable,\n    List,\n    Optional,\n    TypeVar,\n    Union,\n    cast,\n)\nfrom urllib.parse import ParseResult, urlparse\n\nimport pytz\n\nfrom .common import Registrable\nfrom .common.from_params import FromParams\nfrom .common.util import StrEnum, jsonify, utc_now_datetime\nfrom .step import Step\nfrom .step_cache import StepCache\nfrom .step_info import StepInfo, StepState\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\")\n\n\n@dataclass\nclass Run(FromParams):\n    \"\"\"\n    Stores information about a single Tango run.\n    \"\"\"\n\n    name: str\n    \"\"\"\n    The name of the run\n    \"\"\"\n\n    steps: Dict[str, StepInfo]\n    \"\"\"\n    A mapping from step names to :class:`~tango.step_info.StepInfo`, for all the target steps in the run.\n\n    This only contains the targets of a run. Usually, that means it contains all named steps.\n    Un-named dependencies (or dependencies that are not targets) are not contained in ``steps``.\n    \"\"\"\n\n    start_date: datetime\n    \"\"\"\n    The time at which the run was registered in the workspace.\n    \"\"\"\n\n    def to_json_dict(self) -> Dict[str, Any]:\n        return jsonify(self)\n\n    @classmethod\n    def from_json_dict(cls, json_dict: Dict[str, Any]) -> \"Run\":\n        params = {**json_dict}\n        params[\"start_date\"] = datetime.strptime(params[\"start_date\"], \"%Y-%m-%dT%H:%M:%S\").replace(\n            tzinfo=pytz.utc\n        )\n        params[\"steps\"] = {k: StepInfo.from_json_dict(v) for k, v in params[\"steps\"].items()}\n        return cls.from_params(params)\n\n\n@dataclass\nclass RunInfo(FromParams):\n    \"\"\"\n    Stores partial data about a run. 
This is the type that you get back from\n    :meth:`Workspace.search_registered_runs()`. The data here is a subset of\n    the data in the :class:`Run` type because not all workspaces can fetch all\n    of the data in the :class:`Run` type efficiently.\n    \"\"\"\n\n    name: str\n    \"\"\"\n    The name of the run.\n    \"\"\"\n\n    steps: Optional[Dict[str, str]] = None\n    \"\"\"\n    The steps within the run. An optional mapping of step name to step unique ID.\n    \"\"\"\n\n    start_date: Optional[datetime] = None\n    \"\"\"\n    The time at which the run was registered in the workspace.\n    \"\"\"\n\n\nclass RunSort(StrEnum):\n    START_DATE = \"start_date\"\n    NAME = \"name\"\n\n\nclass StepInfoSort(StrEnum):\n    UNIQUE_ID = \"unique_id\"\n    START_TIME = \"start_time\"\n\n\nclass Workspace(Registrable):\n    \"\"\"\n    A workspace is a place for Tango to put the results of steps, intermediate results, and various other pieces\n    of metadata. If you don't want to worry about all that, do nothing and Tango will use the default\n    :class:`.LocalWorkspace` that puts everything into a directory of your choosing.\n\n    If you want to do fancy things like store results in the cloud, share state across machines, etc., this is your\n    integration point.\n\n    If you got here solely because you want to share results between machines, consider that\n    :class:`.LocalWorkspace` works fine on an NFS drive.\n    \"\"\"\n\n    default_implementation = \"local\"\n\n    #\n    # As a general rule, workspaces can never return `Step`, only `StepInfo`, because we can't reliably serialize\n    # objects of type `Step`. 
Doing that would require serializing the code that runs the step, and we can't\n    # do that.\n    #\n\n    def __init__(self):\n        self._delayed_cleanup_temp_dirs: List[TemporaryDirectory] = []\n\n    def __getstate__(self):\n        \"\"\"\n        We override `__getstate__()` to customize how instances of this class are pickled\n        since we don't want to persist certain attributes.\n        \"\"\"\n        out = {k: v for k, v in self.__dict__.items() if k not in {\"_delayed_cleanup_temp_dirs\"}}\n        out[\"_delayed_cleanup_temp_dirs\"] = []\n        return out\n\n    @property\n    @abstractmethod\n    def url(self) -> str:\n        \"\"\"\n        Get a URL for the workspace that can be used to instantiate the same workspace\n        using :meth:`.from_url()`.\n        \"\"\"\n        raise NotImplementedError\n\n    @classmethod\n    def from_url(cls, url: str) -> \"Workspace\":\n        \"\"\"\n        Initialize a :class:`Workspace` from a workspace URL or path, e.g. ``local:///tmp/workspace``\n        would give you a :class:`~tango.workspaces.LocalWorkspace` in the directory ``/tmp/workspace``.\n\n        For :class:`~tango.workspaces.LocalWorkspace`, you can also just pass in a plain\n        path, e.g. ``/tmp/workspace``.\n\n        .. 
tip::\n            Registered as a workspace constructor under the name \"from_url\".\n\n        \"\"\"\n        parsed = urlparse(url)\n        workspace_type = parsed.scheme or \"local\"\n        workspace_cls = cast(Workspace, cls.by_name(workspace_type))\n        return workspace_cls.from_parsed_url(parsed)\n\n    @classmethod\n    @abstractmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> \"Workspace\":\n        \"\"\"\n        Subclasses should override this so that they can be initialized from a URL.\n\n        :param parsed_url: The parsed URL object.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    @abstractmethod\n    def step_cache(self) -> StepCache:\n        \"\"\"\n        A :class:`.StepCache` to store step results in\n        \"\"\"\n        raise NotImplementedError()\n\n    def work_dir(self, step: Step) -> Path:\n        \"\"\"Steps that can be restarted (like a training job that gets interrupted half-way through)\n        must save their state somewhere. 
A :class:`.StepCache` can help by providing a suitable location\n        in this method.\n\n        By default, the step dir is a temporary directory that gets cleaned up after every run.\n        This effectively disables restartability of steps.\"\"\"\n\n        # TemporaryDirectory cleans up the directory automatically when the TemporaryDirectory object\n        # gets garbage collected, so we hold on to it in the Workspace.\n        dir = TemporaryDirectory(prefix=f\"{step.unique_id}-\", suffix=\".step_dir\")\n        self._delayed_cleanup_temp_dirs.append(dir)\n        return Path(dir.name)\n\n    @abstractmethod\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        \"\"\"\n        Returns a :class:`~tango.step_info.StepInfo` for a given step.\n\n        :raises KeyError: If the corresponding step info cannot be found or created.\n            This should never happen if you pass a :class:`~tango.step.Step` object to this method\n            since a :class:`~tango.step_info.StepInfo` can always be created from a\n            :class:`~tango.step.Step`.\n        \"\"\"\n        raise NotImplementedError()\n\n    def search_step_info(\n        self,\n        *,\n        sort_by: Optional[StepInfoSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        state: Optional[StepState] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[StepInfo]:\n        \"\"\"\n        Search through steps in the workspace.\n\n        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily\n        need to implement all `sort_by` or filter operations. 
They should only implement those\n        that can be done efficiently.\n\n        :param sort_by: The field to sort the results by.\n        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.\n        :param match: Only return steps with a unique ID matching this string.\n        :param state: Only return steps that are in the given state.\n        :param start: Start from a certain index in the results.\n        :param stop: Stop at a certain index in the results.\n\n        :raises NotImplementedError: If a workspace doesn't support an efficient implementation\n            for the given sorting/filtering criteria.\n        \"\"\"\n        steps = [\n            step\n            for run in self.registered_runs().values()\n            for step in run.steps.values()\n            if (match is None or match in step.unique_id) and (state is None or step.state == state)\n        ]\n\n        if sort_by == StepInfoSort.START_TIME:\n            now = utc_now_datetime()\n            steps = sorted(\n                steps,\n                key=lambda step: step.start_time or now,\n                reverse=sort_descending,\n            )\n        elif sort_by == StepInfoSort.UNIQUE_ID:\n            steps = sorted(steps, key=lambda step: step.unique_id, reverse=sort_descending)\n        elif sort_by is not None:\n            raise NotImplementedError\n\n        return steps[slice(start, stop)]\n\n    def num_steps(self, *, match: Optional[str] = None, state: Optional[StepState] = None) -> int:\n        \"\"\"\n        Get the total number of registered steps.\n\n        :param match: Only count steps with a unique ID matching this string.\n        :param state: Only count steps that are in the given state.\n        \"\"\"\n        return len(self.search_step_info(match=match, state=state))\n\n    @abstractmethod\n    def step_starting(self, step: Step) -> None:\n        \"\"\"\n        The :class:`.Step` class calls this when a step is 
about to start running.\n\n        :param step: The step that is about to start.\n\n        :raises StepStateError: If the step is in an unexpected state (e.g. RUNNING).\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def step_finished(self, step: Step, result: T) -> T:\n        \"\"\"\n        The :class:`.Step` class calls this when a step has finished running.\n\n        :param step: The step that finished.\n\n        :raises StepStateError: If the step is in an unexpected state (i.e. anything other than RUNNING).\n\n        This method is given the result of the step's :meth:`.Step.run` method. It is expected to return that\n        result. This gives it the opportunity to make changes to the result if necessary. For example, if the\n        :meth:`.Step.run` method returns an iterator, that iterator would be consumed when it's written to the\n        cache. So this method can handle the situation and return something other than the now-consumed iterator.\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def step_failed(self, step: Step, e: BaseException) -> None:\n        \"\"\"\n        The :class:`.Step` class calls this when a step has failed.\n\n        :param step: The step that failed.\n        :param e: The exception thrown by the step's :meth:`.Step.run` method.\n\n        :raises StepStateError: If the step is in an unexpected state (i.e. anything other than RUNNING).\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:\n        \"\"\"\n        Register a run in the workspace. A run is a set of target steps that a user wants to execute.\n\n        :param targets: The steps that the user wants to execute. This could come from a :class:`.StepGraph`.\n        :param name: A name for the run. Runs must have unique names. 
If not given, this method invents a name and\n                     returns it.\n        :return: The run object\n        \"\"\"\n        raise NotImplementedError()\n\n    def search_registered_runs(\n        self,\n        *,\n        sort_by: Optional[RunSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[RunInfo]:\n        \"\"\"\n        Search through registered runs in the workspace.\n\n        This method is primarily meant to be used to implement a UI, and workspaces don't necessarily\n        need to implement all `sort_by` or filter operations. They should only implement those\n        that can be done efficiently.\n\n        .. note::\n            The data type returned in the list here is :class:`RunInfo`, which\n            contains a subset of the data in the :class:`Run` type.\n\n        :param sort_by: The field to sort the results by.\n        :param sort_descending: Sort the results in descending order of the ``sort_by`` field.\n        :param match: Only return results with a name matching this string.\n        :param start: Start from a certain index in the results.\n        :param stop: Stop at a certain index in the results.\n\n        :raises NotImplementedError: If a workspace doesn't support an efficient implementation\n            for the given sorting/filtering criteria.\n        \"\"\"\n        runs = [\n            run for run in self.registered_runs().values() if match is None or match in run.name\n        ]\n\n        if sort_by == RunSort.START_DATE:\n            runs = sorted(runs, key=lambda run: run.start_date, reverse=sort_descending)\n        elif sort_by == RunSort.NAME:\n            runs = sorted(runs, key=lambda run: run.name, reverse=sort_descending)\n        elif sort_by is not None:\n            raise NotImplementedError\n\n        return [\n            RunInfo(\n                name=run.name,\n               
 start_date=run.start_date,\n                steps={k: s.unique_id for k, s in run.steps.items()},\n            )\n            for run in runs[slice(start, stop)]\n        ]\n\n    def num_registered_runs(self, *, match: Optional[str] = None) -> int:\n        \"\"\"\n        Get the number of registered runs.\n\n        :param match: Only count runs with a name matching this string.\n        \"\"\"\n        return len(self.search_registered_runs(match=match))\n\n    @abstractmethod\n    def registered_runs(self) -> Dict[str, Run]:\n        \"\"\"\n        Returns all runs in the workspace\n\n        :return: A dictionary mapping run names to :class:`Run` objects\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def registered_run(self, name: str) -> Run:\n        \"\"\"\n        Returns the run with the given name\n\n        :return: A :class:`Run` object representing the named run\n\n        :raises KeyError: If there is no run with the given name.\n        \"\"\"\n        raise NotImplementedError()\n\n    def step_result_for_run(self, run_name: str, step_name: str) -> Any:\n        \"\"\"\n        Get the result of a step from a run.\n\n        :raises KeyError: If there is no run or step with the given name.\n        \"\"\"\n        run = self.registered_run(run_name)\n        step_info = run.steps[step_name]\n        try:\n            return self.step_cache[step_info]\n        except KeyError:\n            raise KeyError(f\"Step result for '{step_name}' not found in workspace\")\n\n    def step_result(self, step_name: str) -> Any:\n        \"\"\"\n        Get the result of a step from the latest run with a step by that name.\n\n        :raises KeyError: If there is no run with the given step.\n        \"\"\"\n        runs = sorted(self.registered_runs().values(), key=lambda run: run.start_date, reverse=True)\n        for run in runs:\n            if step_name in run.steps:\n                return 
self.step_cache[run.steps[step_name]]\n        raise KeyError(f\"No step named '{step_name}' found in previous runs\")\n\n    @abstractmethod\n    def remove_step(self, step_unique_id: str):\n        \"\"\"\n        Removes the cached step with the given unique ID.\n\n        :raises KeyError: If there is no step with the given unique ID.\n        \"\"\"\n        raise NotImplementedError()\n\n    def capture_logs_for_run(self, name: str) -> ContextManager[None]:\n        \"\"\"\n        Should return a context manager that can be used to capture the logs for a run.\n\n        By default, this doesn't do anything.\n\n        Examples\n        --------\n\n        The :class:`.LocalWorkspace` implementation uses this method to capture logs\n        to a file in the workspace's directory using the :func:`~tango.common.logging.file_handler()`\n        context manager, similar to this:\n\n        .. testcode::\n\n            from tango.common.logging import file_handler\n            from tango.workspace import Workspace\n\n            class MyLocalWorkspace(Workspace):\n                def capture_logs_for_run(self, name: str):\n                    return file_handler(\"/path/to/workspace/\" + name + \".log\")\n\n        \"\"\"\n\n        @contextmanager\n        def do_nothing() -> Generator[None, None, None]:\n            yield None\n\n        return do_nothing()\n\n\nWorkspace.register(\"from_url\", constructor=\"from_url\")(Workspace)  # type: ignore\n"
  },
  {
    "path": "tango/workspaces/__init__.py",
    "content": "\"\"\"\nBuilt-in :class:`~tango.workspace.Workspace` implementations.\n\"\"\"\n\nfrom .local_workspace import LocalWorkspace\nfrom .memory_workspace import MemoryWorkspace, default_workspace\n"
  },
  {
    "path": "tango/workspaces/local_workspace.py",
    "content": "import json\nimport logging\nimport os\nfrom datetime import datetime\nfrom pathlib import Path\nfrom typing import Dict, Iterable, Iterator, List, Optional, Set, TypeVar, Union\nfrom urllib.parse import ParseResult\n\nimport dill\nimport petname\nfrom sqlitedict import SqliteDict\n\nfrom tango.common import PathOrStr\nfrom tango.common.exceptions import StepStateError\nfrom tango.common.file_lock import FileLock\nfrom tango.common.logging import file_handler\nfrom tango.common.util import exception_to_string, utc_now_datetime\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_caches import LocalStepCache\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, StepInfoSort, Workspace\n\nlogger = logging.getLogger(__name__)\n\nT = TypeVar(\"T\")\n\n\n@Workspace.register(\"local\")\nclass LocalWorkspace(Workspace):\n    \"\"\"\n    This is a :class:`.Workspace` that keeps all its data in a local directory. This works great for single-machine\n    jobs, or for multiple machines in a cluster if they can all access the same NFS drive.\n\n    :param dir: The directory to store all the data in\n\n    The directory will have three subdirectories, ``cache/`` for the step cache, ``runs/`` for the runs,\n    and ``latest/`` for the results of the latest run. For the format of the ``cache/`` directory,\n    refer to :class:`.LocalStepCache`. The ``runs/`` directory will contain one subdirectory for each\n    registered run. Each one of those contains a symlink from the name of the step to the results directory\n    in the step cache. Note that :class:`.LocalWorkspace` creates these symlinks even for steps that have not\n    finished yet. You can tell the difference because either the symlink points to a directory that doesn't exist,\n    or it points to a directory in the step cache that doesn't contain results.\n\n    .. 
tip::\n\n        Registered as a :class:`~tango.workspace.Workspace` under the name \"local\".\n\n        You can also instantiate this workspace from a URL with the scheme ``local://``.\n        For example, ``Workspace.from_url(\"local:///tmp/workspace\")`` gives you a :class:`LocalWorkspace`\n        in the directory ``/tmp/workspace``.\n\n    \"\"\"\n\n    def __init__(self, dir: PathOrStr):\n        super().__init__()\n        self.dir = Path(dir)\n        self.dir.mkdir(parents=True, exist_ok=True)\n        self.cache = LocalStepCache(self.dir / \"cache\")\n        self.locks: Dict[Step, FileLock] = {}\n        self.runs_dir = self.dir / \"runs\"\n        self.runs_dir.mkdir(parents=True, exist_ok=True)\n        self.step_info_file = self.dir / \"stepinfo.sqlite\"\n        self.latest_dir = self.dir / \"latest\"\n\n        # Check the version of the local workspace\n        try:\n            with open(self.dir / \"settings.json\", \"r\") as settings_file:\n                settings = json.load(settings_file)\n        except FileNotFoundError:\n            settings = {\"version\": 1}\n\n        # Upgrade to version 2\n        if settings[\"version\"] == 1:\n            with SqliteDict(self.step_info_file) as d:\n                for stepinfo_file in self.cache.dir.glob(\"*/stepinfo.dill\"):\n                    with stepinfo_file.open(\"rb\") as f:\n                        stepinfo = dill.load(f)\n\n                    # The `StepInfo` class changed from one version to the next. The deserialized version\n                    # ends up being a `StepInfo` instance that is missing the `cacheable` member. This\n                    # hack adds it in.\n                    kwargs = stepinfo.__dict__\n                    kwargs[\n                        \"cacheable\"\n                    ] = True  # Only cacheable steps were saved in v1. 
That's what v2 fixes.\n                    d[stepinfo.unique_id] = StepInfo(**kwargs)\n                d.commit()\n            for stepinfo_file in self.cache.dir.glob(\"*/stepinfo.dill\"):\n                stepinfo_file.unlink()\n\n            settings[\"version\"] = 2\n            with open(self.dir / \"settings.json\", \"w\") as settings_file:\n                json.dump(settings, settings_file)\n\n    def __getstate__(self):\n        \"\"\"\n        We override `__getstate__()` to customize how instances of this class are pickled\n        since we don't want to persist certain attributes.\n        \"\"\"\n        out = super().__getstate__()\n        out[\"locks\"] = {}\n        return out\n\n    @property\n    def url(self) -> str:\n        return \"local://\" + str(self.dir)\n\n    @classmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> \"Workspace\":\n        workspace_dir: Path\n        if parsed_url.netloc:\n            workspace_dir = Path(parsed_url.netloc)\n            if parsed_url.path:\n                workspace_dir = workspace_dir / parsed_url.path.lstrip(\"/\")\n        elif parsed_url.path:\n            workspace_dir = Path(parsed_url.path)\n        else:\n            workspace_dir = Path(\".\")\n        return cls(workspace_dir.resolve())\n\n    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:\n        return self.cache.step_dir(step_or_unique_id)\n\n    @property\n    def step_cache(self) -> StepCache:\n        return self.cache\n\n    def work_dir(self, step: Step) -> Path:\n        result = self.step_dir(step) / \"work\"\n        result.mkdir(parents=True, exist_ok=True)\n        return result\n\n    @classmethod\n    def guess_step_dir_state(cls, dir: Path) -> Set[StepState]:\n        \"\"\"\n        Returns the possible states of a given step dir, to the best of our knowledge.\n\n        :param dir: the step dir to examine\n        :return: a set of possible states for the step\n        \"\"\"\n\n        # If 
the directory doesn't exist, the step is incomplete or uncacheable.\n        if not dir.exists():\n            return {StepState.INCOMPLETE, StepState.UNCACHEABLE}\n\n        # If the lock file exists and is locked, the step is running.\n        lock_file = dir / \"lock\"\n        if lock_file.exists():\n            lock = FileLock(lock_file)\n            try:\n                lock.acquire(0.1)\n                lock.release()\n            except TimeoutError:\n                return {StepState.RUNNING}\n\n        # If the directory is empty except for the work dir and the lock file, the step is running, incomplete,\n        # or failed. But it can't be running because then the lockfile would be locked, so it can only be\n        # incomplete or failed.\n        for dir_entry in dir.iterdir():\n            if dir_entry.name == \"work\" and dir_entry.is_dir():\n                continue\n            if dir_entry.name == \"lock\" and dir_entry.is_file():\n                continue\n            break\n        else:\n            return {StepState.INCOMPLETE, StepState.FAILED}\n\n        return set(StepState)\n\n    @staticmethod\n    def _fix_step_info(step_info: StepInfo) -> None:\n        \"\"\"\n        Tragically we need to run a fix-up step over StepInfo objects that are freshly read from\n        the database. 
This is for backwards compatibility.\n\n        This function operates on the `step_info` object in place.\n        \"\"\"\n        if isinstance(step_info.error, BaseException):\n            step_info.error = exception_to_string(step_info.error)\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        with SqliteDict(self.step_info_file) as d:\n\n            def find_or_add_step_info(step_or_unique_id: Union[Step, str]) -> StepInfo:\n                if isinstance(step_or_unique_id, Step):\n                    unique_id = step_or_unique_id.unique_id\n                else:\n                    unique_id = step_or_unique_id\n\n                try:\n                    step_info = d[unique_id]\n                except KeyError:\n                    if not isinstance(step_or_unique_id, Step):\n                        raise\n                    step = step_or_unique_id\n\n                    for dep in step.dependencies:\n                        find_or_add_step_info(dep)\n\n                    step_info = StepInfo.new_from_step(step)\n                    d[unique_id] = step_info\n                    del step\n\n                # Perform some sanity checks. 
Sqlite and the file system can get out of sync\n                # when a process dies suddenly.\n                step_dir = self.step_dir(unique_id)\n                step_state_guesses = self.guess_step_dir_state(step_dir) or step_info.state\n                if step_info.state not in step_state_guesses:\n                    if step_info.state == StepState.RUNNING:\n                        # We think the step is running, but it can't possibly be running, so we go ahead and\n                        # assume the step is incomplete.\n                        step_info.start_time = None\n                        step_info.end_time = None\n                        d[unique_id] = step_info\n                    else:\n                        possible_states = \", \".join(s.value for s in step_state_guesses)\n                        raise IOError(\n                            f\"The step '{unique_id}' is marked as being {step_info.state.value}, but we \"\n                            f\"determined it can only be one of {{{possible_states}}}. If you are positive \"\n                            f\"this is a screw-up, delete the directory at '{step_dir}' and try again.\"\n                        )\n\n                return step_info\n\n            result = find_or_add_step_info(step_or_unique_id)\n            d.commit()\n            self._fix_step_info(result)\n            return result\n\n    def _step_lock_file(self, step_or_unique_id: Union[Step, str]) -> Path:\n        step_dir = self.step_dir(step_or_unique_id)\n        step_dir.mkdir(parents=True, exist_ok=True)\n        return step_dir / \"lock\"\n\n    def step_starting(self, step: Step) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return\n\n        # Gather the existing step info first. Step info automatically fixes itself if steps are\n        # marked as \"running\" but are not locked. This happens, for example, when a process\n        # gets killed. 
To make sure this works, we have to get the step info before we start\n        # messing with locks.\n        step_info = self.step_info(step)\n        if step_info.state not in {StepState.INCOMPLETE, StepState.FAILED}:\n            raise StepStateError(\n                step,\n                step_info.state,\n                context=\"If you are certain the step is not running somewhere else, delete the lock \"\n                f\"file at {self._step_lock_file(step)}.\",\n            )\n\n        if step_info.state == StepState.FAILED:\n            # Refresh environment metadata since it might be out-of-date now.\n            step_info.refresh()\n\n        lock = FileLock(self._step_lock_file(step), read_only_ok=True)\n        lock.acquire_with_updates(desc=f\"acquiring lock for '{step.name}'\")\n        self.locks[step] = lock\n\n        try:\n            step_info.start_time = utc_now_datetime()\n            step_info.end_time = None\n            step_info.error = None\n            step_info.result_location = None\n            with SqliteDict(self.step_info_file) as d:\n                d[step.unique_id] = step_info\n                d.commit()\n        except:  # noqa: E722\n            lock.release()\n            del self.locks[step]\n            raise\n\n    def step_finished(self, step: Step, result: T) -> T:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return result\n\n        lock = self.locks[step]\n\n        step_info = self.step_info(step)\n        if step_info.state != StepState.RUNNING:\n            raise StepStateError(step, step_info.state)\n\n        self.step_cache[step] = result\n        if hasattr(result, \"__next__\"):\n            assert isinstance(result, Iterator)\n            # Caching the iterator will consume it, so we write it to the cache and then read from the cache\n            # for the return value.\n            result = self.step_cache[step]\n\n        # Mark the step as 
finished\n        step_info.end_time = utc_now_datetime()\n        step_info.result_location = str(self.step_dir(step).absolute())\n        with SqliteDict(self.step_info_file) as d:\n            d[step.unique_id] = step_info\n            d.commit()\n\n        lock.release()\n        del self.locks[step]\n\n        return result\n\n    def step_failed(self, step: Step, e: BaseException) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return\n\n        lock = self.locks[step]\n\n        try:\n            step_info = self.step_info(step)\n            if step_info.state != StepState.RUNNING:\n                raise StepStateError(step, step_info.state)\n            step_info.end_time = utc_now_datetime()\n            step_info.error = exception_to_string(e)\n            with SqliteDict(self.step_info_file) as d:\n                d[step.unique_id] = step_info\n                d.commit()\n        finally:\n            lock.release()\n            del self.locks[step]\n\n    def remove_step(self, step_unique_id: str) -> None:\n        \"\"\"\n        Removes the step with the given unique ID from the step info database and the step cache.\n\n        :raises KeyError: If no step with the given unique ID is found.\n        \"\"\"\n        with SqliteDict(self.step_info_file) as d:\n            try:\n                step_info = self.step_info(step_unique_id)\n                del d[step_unique_id]\n                d.commit()\n                del self.cache[step_info]\n            except KeyError:\n                raise KeyError(f\"No step named '{step_unique_id}' found\")\n\n    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:\n        # sanity check targets\n        targets = list(targets)\n        if name is None:\n            while name is None or (self.runs_dir / name).exists():\n                name = petname.generate()\n        run_dir = self.runs_dir / name\n\n        # clean 
any existing run directory\n        if run_dir.exists():\n            for filename in run_dir.iterdir():\n                filename.unlink()\n        else:\n            run_dir.mkdir(parents=True, exist_ok=True)\n\n        # write step info for all steps\n        all_steps = set(targets)\n        for step in targets:\n            all_steps |= step.recursive_dependencies\n\n        self._save_registered_run(name, all_steps)\n\n        # write targets\n        for target in targets:\n            if target.cache_results:\n                target_path = self.step_dir(target)\n                (run_dir / target.name).symlink_to(os.path.relpath(target_path, run_dir))\n\n        self.latest_dir.unlink(missing_ok=True)\n        self.latest_dir.symlink_to(run_dir)\n\n        return self.registered_run(name)\n\n    def registered_runs(self) -> Dict[str, Run]:\n        return {\n            str(run_dir.name): self.registered_run(run_dir.name)\n            for run_dir in self.runs_dir.iterdir()\n            if run_dir.is_dir()\n        }\n\n    def search_step_info(\n        self,\n        *,\n        sort_by: Optional[StepInfoSort] = None,\n        sort_descending: bool = True,\n        match: Optional[str] = None,\n        state: Optional[StepState] = None,\n        start: int = 0,\n        stop: Optional[int] = None,\n    ) -> List[StepInfo]:\n        with SqliteDict(self.step_info_file, flag=\"r\") as d:\n            steps = [\n                step\n                for step in d.values()\n                if (match is None or match in step.unique_id)\n                and (state is None or step.state == state)\n            ]\n\n        if sort_by == StepInfoSort.START_TIME:\n            now = utc_now_datetime()\n            steps = sorted(\n                steps,\n                key=lambda step: step.start_time or now,\n                reverse=sort_descending,\n            )\n        elif sort_by == StepInfoSort.UNIQUE_ID:\n            steps = sorted(steps, key=lambda step: 
step.unique_id, reverse=sort_descending)\n        elif sort_by is not None:\n            raise NotImplementedError\n\n        return steps[slice(start, stop)]\n\n    def registered_run(self, name: str) -> Run:\n        run_dir = self.runs_dir / name\n        if not run_dir.is_dir():\n            raise KeyError(name)\n        steps_for_run = self._load_registered_run(name)\n        return Run(name, steps_for_run, datetime.fromtimestamp(run_dir.stat().st_ctime))\n\n    def _run_step_info_file(self, name: str) -> Path:\n        return self.runs_dir / name / \"stepinfo.json\"\n\n    def _save_registered_run(self, name: str, all_steps: Iterable[Step]) -> None:\n        step_unique_ids = {}\n        with SqliteDict(self.step_info_file) as d:\n            for step in all_steps:\n                try:\n                    step_info = d[step.unique_id]\n                    step_info.name = step.name\n                    d[step.unique_id] = step_info\n                except KeyError:\n                    d[step.unique_id] = StepInfo.new_from_step(step)\n                step_unique_ids[step.name] = step.unique_id\n\n            d.commit()\n\n            run_step_info_file = self._run_step_info_file(name)\n            with open(run_step_info_file, \"w\") as file_ref:\n                json.dump(step_unique_ids, file_ref)\n\n    def _load_registered_run(self, name: str) -> Dict[str, StepInfo]:\n        run_step_info_file = self._run_step_info_file(name)\n        try:\n            with open(run_step_info_file, \"r\") as file_ref:\n                step_ids = json.load(file_ref)\n        except FileNotFoundError:\n            # for backwards compatibility\n            run_dir = self.runs_dir / name\n            step_ids = {}\n            for step_symlink in run_dir.iterdir():\n                if not step_symlink.is_symlink():\n                    continue\n                step_name = str(step_symlink.name)\n                unique_id = str(step_symlink.resolve().name)\n               
 step_ids[step_name] = unique_id\n\n        with SqliteDict(self.step_info_file, flag=\"r\") as d:\n            steps_for_run = {}\n            for step_name, unique_id in step_ids.items():\n                step_info = d[unique_id]\n                assert isinstance(step_info, StepInfo)\n                self._fix_step_info(step_info)\n                steps_for_run[step_name] = step_info\n            return steps_for_run\n\n    def run_dir(self, name: str) -> Path:\n        \"\"\"\n        Returns the directory where a given run is stored.\n\n        :param name: The name of the run\n        :return: The directory where the results of the run are stored\n\n        If the run does not exist, this returns the directory where it will be stored if you call\n        :meth:`register_run()` with that name.\n        \"\"\"\n        return self.runs_dir / name\n\n    def capture_logs_for_run(self, name: str):\n        return file_handler(self.run_dir(name) / \"out.log\")\n"
  },
  {
    "path": "tango/workspaces/memory_workspace.py",
    "content": "import copy\nfrom typing import Dict, Iterable, Iterator, Optional, TypeVar, Union\nfrom urllib.parse import ParseResult\n\nimport petname\n\nfrom tango.common.exceptions import StepStateError\nfrom tango.common.util import exception_to_string, utc_now_datetime\nfrom tango.step import Step\nfrom tango.step_cache import StepCache\nfrom tango.step_caches import default_step_cache\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, Workspace\n\nT = TypeVar(\"T\")\n\n\n@Workspace.register(\"memory\")\nclass MemoryWorkspace(Workspace):\n    \"\"\"\n    This is a workspace that keeps all its data in memory. This is useful for debugging or for quick jobs, but of\n    course you don't get any caching across restarts.\n\n    .. tip::\n\n        Registered as a :class:`~tango.workspace.Workspace` under the name \"memory\".\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.unique_id_to_info: Dict[str, StepInfo] = {}\n        self.runs: Dict[str, Run] = {}\n\n    @property\n    def url(self) -> str:\n        return \"memory://\"\n\n    @classmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> \"Workspace\":\n        return cls()\n\n    @property\n    def step_cache(self) -> StepCache:\n        return default_step_cache\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        unique_id = (\n            step_or_unique_id.unique_id\n            if isinstance(step_or_unique_id, Step)\n            else step_or_unique_id\n        )\n        try:\n            return self.unique_id_to_info[unique_id]\n        except KeyError:\n            if isinstance(step_or_unique_id, Step):\n                step = step_or_unique_id\n                return StepInfo.new_from_step(step)\n            else:\n                raise KeyError()\n\n    def step_starting(self, step: Step) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n  
          return\n\n        self.unique_id_to_info[step.unique_id] = StepInfo.new_from_step(\n            step, start_time=utc_now_datetime()\n        )\n\n    def step_finished(self, step: Step, result: T) -> T:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return result\n\n        existing_step_info = self.unique_id_to_info[step.unique_id]\n        if existing_step_info.state != StepState.RUNNING:\n            raise StepStateError(step, existing_step_info.state)\n        existing_step_info.end_time = utc_now_datetime()\n\n        if step.cache_results:\n            self.step_cache[step] = result\n            if hasattr(result, \"__next__\"):\n                assert isinstance(result, Iterator)\n                # Caching the iterator will consume it, so we write it to the cache and then read from the cache\n                # for the return value.\n                return self.step_cache[step]\n        return result\n\n    def step_failed(self, step: Step, e: BaseException) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return\n\n        assert e is not None\n        existing_step_info = self.unique_id_to_info[step.unique_id]\n        if existing_step_info.state != StepState.RUNNING:\n            raise StepStateError(step, existing_step_info.state)\n        existing_step_info.end_time = utc_now_datetime()\n        existing_step_info.error = exception_to_string(e)\n\n    def remove_step(self, step_unique_id: str) -> None:\n        \"\"\"\n        Removes the step info and the cached result for the step with the given unique ID.\n\n        :raises KeyError: If no step with that unique ID is found in this workspace.\n        \"\"\"\n        try:\n            step_info = self.step_info(step_unique_id)\n            del self.unique_id_to_info[step_unique_id]\n            del self.step_cache[step_info]\n        except KeyError:\n            raise 
KeyError(f\"{step_unique_id} step info not found, step cache cannot be deleted\")\n\n    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:\n        if name is None:\n            name = petname.generate()\n        steps: Dict[str, StepInfo] = {}\n        for step in targets:\n            step_info = StepInfo.new_from_step(step)\n            self.unique_id_to_info[step.unique_id] = step_info\n            steps[step.unique_id] = step_info\n        run = Run(name, steps, utc_now_datetime())\n        self.runs[name] = run\n        return run\n\n    def registered_runs(self) -> Dict[str, Run]:\n        return copy.deepcopy(self.runs)\n\n    def registered_run(self, name: str) -> Run:\n        return copy.deepcopy(self.runs[name])\n\n\ndefault_workspace = MemoryWorkspace()\n"
  },
  {
    "path": "tango/workspaces/remote_workspace.py",
    "content": "import logging\nimport tempfile\nimport warnings\nfrom abc import abstractmethod\nfrom contextlib import contextmanager\nfrom pathlib import Path\nfrom typing import Dict, Generator, Iterable, Iterator, Optional, Tuple, TypeVar, Union\nfrom urllib.parse import ParseResult\n\nfrom tango.common.exceptions import StepStateError\nfrom tango.common.logging import file_handler\nfrom tango.common.remote_utils import RemoteConstants\nfrom tango.common.util import exception_to_string, tango_cache_dir, utc_now_datetime\nfrom tango.step import Step\nfrom tango.step_caches.remote_step_cache import RemoteStepCache\nfrom tango.step_info import StepInfo, StepState\nfrom tango.workspace import Run, Workspace\n\nT = TypeVar(\"T\")\n\nlogger = logging.getLogger(__name__)\n\n\nclass RemoteWorkspace(Workspace):\n    \"\"\"\n    This is a :class:`~tango.workspace.Workspace` that stores step artifacts on some remote storage location.\n\n    .. tip::\n        All remote workspaces inherit from this.\n    \"\"\"\n\n    Constants = RemoteConstants\n    NUM_CONCURRENT_WORKERS: int = 9\n\n    @property\n    @abstractmethod\n    def cache(self) -> RemoteStepCache:\n        raise NotImplementedError()\n\n    @property\n    @abstractmethod\n    def steps_dir_name(self) -> str:\n        raise NotImplementedError()\n\n    @property\n    @abstractmethod\n    def locks(self) -> Dict:\n        raise NotImplementedError()\n\n    @property\n    def steps_dir(self) -> Path:\n        return tango_cache_dir() / self.steps_dir_name\n\n    @property\n    @abstractmethod\n    def url(self) -> str:\n        raise NotImplementedError()\n\n    @classmethod\n    @abstractmethod\n    def from_parsed_url(cls, parsed_url: ParseResult) -> Workspace:\n        raise NotImplementedError()\n\n    @property\n    def step_cache(self) -> RemoteStepCache:\n        return self.cache\n\n    def step_dir(self, step_or_unique_id: Union[Step, str]) -> Path:\n        unique_id = (\n            step_or_unique_id 
if isinstance(step_or_unique_id, str) else step_or_unique_id.unique_id\n        )\n        path = self.steps_dir / unique_id\n        path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    def work_dir(self, step: Step) -> Path:\n        path = self.step_dir(step) / \"work\"\n        path.mkdir(parents=True, exist_ok=True)\n        return path\n\n    def step_info(self, step_or_unique_id: Union[Step, str]) -> StepInfo:\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _remote_lock(self, step: Step):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _step_location(self, step: Step) -> str:\n        raise NotImplementedError()\n\n    def step_starting(self, step: Step) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return\n\n        # Get local file lock + remote dataset lock.\n        lock = self._remote_lock(step)\n        lock.acquire()\n        self.locks[step] = lock\n\n        step_info = self.step_info(step)\n        if step_info.state == StepState.RUNNING:\n            # Since we've acquired the step lock we know this step can't be running\n            # elsewhere. But the step state can still say it's running if the last\n            # run exited before the step info could be updated.\n            warnings.warn(\n                f\"Step info for step '{step.unique_id}' is invalid - says step is running \"\n                \"although it shouldn't be. 
Ignoring and overwriting step start time.\",\n                UserWarning,\n            )\n        elif step_info.state not in {StepState.INCOMPLETE, StepState.FAILED, StepState.UNCACHEABLE}:\n            self.locks.pop(step).release()\n            raise StepStateError(\n                step,\n                step_info.state,\n                context=f\"If you are certain the step is not running somewhere else, delete the step \"\n                f\"datasets at {self._step_location(step)}\",\n            )\n\n        if step_info.state == StepState.FAILED:\n            # Refresh the environment metadata since it might be out-of-date now.\n            step_info.refresh()\n\n        # Update StepInfo to mark as running.\n        try:\n            step_info.start_time = utc_now_datetime()\n            step_info.end_time = None\n            step_info.error = None\n            step_info.result_location = None\n            self._update_step_info(step_info)\n        except:  # noqa: E722\n            self.locks.pop(step).release()\n            raise\n\n    def step_finished(self, step: Step, result: T) -> T:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return result\n\n        step_info = self.step_info(step)\n        if step_info.state != StepState.RUNNING:\n            raise StepStateError(step, step_info.state)\n\n        # Update step info and save step execution metadata.\n        # This needs to be done *before* adding the result to the cache, since adding\n        # the result to the cache will commit the step dataset, making it immutable.\n        step_info.end_time = utc_now_datetime()\n        step_info.result_location = self._step_location(step)\n        self._update_step_info(step_info)\n\n        self.cache[step] = result\n        if hasattr(result, \"__next__\"):\n            assert isinstance(result, Iterator)\n            # Caching the iterator will consume it, so we write it to the cache and then 
read from the cache\n            # for the return value.\n            result = self.cache[step]\n\n        self.locks.pop(step).release()\n\n        return result\n\n    def step_failed(self, step: Step, e: BaseException) -> None:\n        # We don't do anything with uncacheable steps.\n        if not step.cache_results:\n            return\n\n        try:\n            step_info = self.step_info(step)\n            if step_info.state != StepState.RUNNING:\n                raise StepStateError(step, step_info.state)\n            step_info.end_time = utc_now_datetime()\n            step_info.error = exception_to_string(e)\n            self._update_step_info(step_info)\n        finally:\n            self.locks.pop(step).release()\n\n    def remove_step(self, step_unique_id: str) -> None:\n        \"\"\"\n        Removes the step info and the cached result for the step with the given unique ID.\n\n        :raises KeyError: If no step with that unique ID is found.\n        \"\"\"\n        try:\n            step_info = self.step_info(step_unique_id)\n            # remove remote objects\n            self._remove_step_info(step_info)\n\n            # remove cache info\n            del self.cache[step_info]\n        except KeyError:\n            raise KeyError(f\"No step named '{step_unique_id}' found.\")\n        return None\n\n    def _get_run_step_info(self, targets: Iterable[Step]) -> Tuple[Dict, Dict]:\n        import concurrent.futures\n\n        all_steps = set(targets)\n        for step in targets:\n            all_steps |= step.recursive_dependencies\n\n        steps: Dict[str, StepInfo] = {}\n        run_data: Dict[str, str] = {}\n\n        # Collect step info.\n        with concurrent.futures.ThreadPoolExecutor(\n            thread_name_prefix=\"RemoteWorkspace._get_run_step_info()-\"\n        ) as executor:\n            step_info_futures = []\n            for step in all_steps:\n                if step.name is None:\n                    continue\n            
    step_info_futures.append(executor.submit(self.step_info, step))\n            for future in concurrent.futures.as_completed(step_info_futures):\n                step_info = future.result()\n                assert step_info.step_name is not None\n                steps[step_info.step_name] = step_info\n                run_data[step_info.step_name] = step_info.unique_id\n\n        return steps, run_data\n\n    @abstractmethod\n    def _save_run(\n        self, steps: Dict[str, StepInfo], run_data: Dict[str, str], name: Optional[str] = None\n    ) -> Run:\n        raise NotImplementedError()\n\n    def register_run(self, targets: Iterable[Step], name: Optional[str] = None) -> Run:\n        steps, run_data = self._get_run_step_info(targets)\n        run = self._save_run(steps, run_data, name)\n        return run\n\n    @abstractmethod\n    def _save_run_log(self, name: str, log_file: Path):\n        raise NotImplementedError()\n\n    @contextmanager\n    def capture_logs_for_run(self, name: str) -> Generator[None, None, None]:\n        with tempfile.TemporaryDirectory() as tmp_dir_name:\n            log_file = Path(tmp_dir_name) / \"out.log\"\n            try:\n                with file_handler(log_file):\n                    yield None\n            finally:\n                self._save_run_log(name, log_file)\n\n    @abstractmethod\n    def _update_step_info(self, step_info: StepInfo):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _remove_step_info(self, step_info: StepInfo):\n        raise NotImplementedError()\n"
  },
  {
    "path": "test_fixtures/__init__.py",
    "content": ""
  },
  {
    "path": "test_fixtures/beaker/nvidia_smi.yml",
    "content": "# Used to test that GPUs in a cluster are available. Submit this to beaker with:\n# $ beaker experiment create test_fixtures/beaker/nvidia_smi.yml --workspace ai2/tango-testing --name tango-test-1\nversion: v2-alpha\ndescription: NvidiaSMI\ntasks:\n  - name: nvidia-smi\n    image:\n      docker: nvidia/cuda:11.0-base\n    command: [nvidia-smi]\n    result:\n      path: '/unused'\n    resources:\n      gpuCount: 2\n    context:\n      cluster: ai2/tango-gpu-tests\n      priority: normal\n"
  },
  {
    "path": "test_fixtures/common/params_example.jsonnet",
    "content": "{\n    \"model\": {\n        \"type\": \"classifier\",\n        \"num_classes\": 3,\n        \"layers\": [\n            {\n                \"type\": \"ff\",\n                \"activation\": \"relu\",\n            },\n            {\n                \"type\": \"ff\",\n                \"activation\": \"softmax\",\n            },\n        ]\n    },\n    \"data_path\": \"data.txt\",\n}\n"
  },
  {
    "path": "test_fixtures/common/params_example.yaml",
    "content": "model:\n  type: classifier\n  num_classes: 3\n  layers:\n    - type: ff\n      activation: relu\n    - type: ff\n      activation: softmax\n  data_path: data.txt\n"
  },
  {
    "path": "test_fixtures/experiment/hello_world.jsonnet",
    "content": "{\n    \"steps\": {\n        \"hello\": {\"type\": \"string\", \"result\": \"Hello\"},\n        \"hello_world\": {\n            \"type\": \"concat_strings\",\n            \"string1\": {\"type\": \"ref\", \"ref\": \"hello\"},\n            \"string2\": \"World!\",\n            \"join_with\": \", \",\n        },\n    },\n}\n"
  },
  {
    "path": "test_fixtures/experiment/logging_check.jsonnet",
    "content": "{\n    \"steps\": {\n        \"stringA\": {\"type\": \"logging-step\", \"string\": \"This is a logging test.\", \"num_log_lines\": 5},\n        \"stringB\": {\n            \"type\": \"concat_strings\",\n            \"string1\": {\"type\": \"ref\", \"ref\": \"stringA\"},\n            \"string2\": \"This is being logged.\"\n        },\n        \"stringC\": {\"type\": \"logging-step\", \"string\": \"This is also a logging test.\", \"num_log_lines\": 5},\n        \"final_string\": {\n            \"type\": \"logging-step\",\n            \"string\": {\"type\": \"ref\", \"ref\": \"stringB\"},\n            \"num_log_lines\": 3\n        },\n        \"multiprocessing_result\": {\n            \"type\": \"multiprocessing_step\",\n        }\n    }\n}"
  },
  {
    "path": "test_fixtures/experiment/multiprocessing.jsonnet",
    "content": "{\n    \"steps\": {\n        \"result\": {\n            \"type\": \"multiprocessing_step\",\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/experiment/noisy.jsonnet",
    "content": "{\n    steps: {\n        hello_world: { type: \"string\", result: \"Hello, World!\" },\n        noisy_step: { type: \"noisy_step\" },\n    }\n}\n"
  },
  {
    "path": "test_fixtures/experiment/random.jsonnet",
    "content": "{\n    \"steps\": {\n        \"rand_string1\": {\"type\": \"random_string\", \"length\": 5},\n        \"rand_string2\": {\"type\": \"random_string\", \"length\": 5},\n        \"string1\": {\n            \"type\": \"concat_strings\",\n            \"string1\": {\"type\": \"ref\", \"ref\": \"rand_string1\"},\n            \"string2\": {\"type\": \"ref\", \"ref\": \"rand_string2\"},\n        },\n        \"string2\": {\n            \"type\": \"string\",\n            \"result\": \"foo\",\n        },\n        \"final_string\": {\n            \"type\": \"concat_strings\",\n            \"string1\": {\"type\": \"ref\", \"ref\": \"string1\"},\n            \"string2\": {\"type\": \"ref\", \"ref\": \"string2\"},\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/__init__.py",
    "content": ""
  },
  {
    "path": "test_fixtures/integrations/common/__init__.py",
    "content": "import torch\nfrom torch.utils.data import IterableDataset\n\nfrom tango import Step\nfrom tango.common import DatasetDict, IterableDatasetDict\n\n\n@Step.register(\"generate_data\")\nclass GenerateData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self) -> DatasetDict:  # type: ignore[override]\n        torch.manual_seed(1)\n        return DatasetDict(\n            {\n                \"train\": [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(64)],\n                \"validation\": [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(32)],\n            }\n        )\n\n\nclass RandomIterableDataset(IterableDataset):\n    def __init__(self, data):\n        self.data = data\n\n    def __iter__(self):\n        return iter(self.data)\n\n\n@Step.register(\"generate_streaming_data\")\nclass GenerateStreamingData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self) -> IterableDatasetDict:  # type: ignore[override]\n        torch.manual_seed(1)\n        return IterableDatasetDict(\n            {\n                \"train\": RandomIterableDataset(\n                    [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(64)]\n                ),\n                \"validation\": RandomIterableDataset(\n                    [{\"x\": torch.rand(10), \"y\": torch.rand(1)} for _ in range(32)]\n                ),\n            }\n        )\n"
  },
  {
    "path": "test_fixtures/integrations/datasets/config.json",
    "content": "{\n    \"steps\": {\n        \"train_data\": {\n            \"type\": \"datasets::load\",\n            \"path\": \"lhoestq/test\",\n            \"split\": \"train\"\n        },\n        \"dev_data\": {\n            \"type\": \"datasets::load\",\n            \"path\": \"lhoestq/test\",\n            \"split\": \"validation\"\n        },\n        \"all_data\": {\n            \"type\": \"datasets::concatenate\",\n            \"datasets\": [\n                {\n                    \"type\": \"ref\",\n                    \"ref\": \"train_data\"\n                },\n                {\n                    \"type\": \"ref\",\n                    \"ref\": \"dev_data\"\n                }\n            ]\n        },\n        \"mixed_data\": {\n            \"type\": \"datasets::interleave\",\n            \"datasets\": [\n                {\n                    \"type\": \"ref\",\n                    \"ref\": \"train_data\"\n                },\n                {\n                    \"type\": \"ref\",\n                    \"ref\": \"dev_data\"\n                }\n            ],\n            \"probabilities\": [0.9, 0.1]\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/fairscale/__init__.py",
    "content": ""
  },
  {
    "path": "test_fixtures/integrations/fairscale/components.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom tango import Step\nfrom tango.common import DatasetDict\nfrom tango.integrations.torch import Model\nfrom tango.integrations.torch.util import set_seed_all\n\n\nclass FeedForward(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(4, 4)\n        self.activation = nn.ReLU()\n\n    def forward(self, x):\n        return self.activation(self.linear(x))\n\n\n@Model.register(\"simple_regression_model\", exist_ok=True)\nclass SimpleRegressionModel(Model):\n    def __init__(self):\n        super().__init__()\n        self.blocks = nn.Sequential(*[FeedForward() for _ in range(3)])\n        self.regression_head = nn.Linear(4, 1)\n        self.loss_fcn = nn.MSELoss()\n\n    def forward(self, x, y):\n        output = self.blocks(x)\n        output = self.regression_head(output)\n        loss = self.loss_fcn(output, y)\n        return {\"loss\": loss}\n\n\n@Step.register(\"simple_regression_data\")\nclass SimpleRegressionDataStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self, seed: int = 317) -> DatasetDict:  # type: ignore\n        set_seed_all(seed)\n\n        def get_data(n: int):\n            return [{\"x\": torch.randn(4), \"y\": torch.randn(1)} for _ in range(n)]\n\n        dataset_dict = DatasetDict(splits={\"train\": get_data(32), \"dev\": get_data(16)})\n        return dataset_dict\n"
  },
  {
    "path": "test_fixtures/integrations/fairscale/config.jsonnet",
    "content": "local pretrained_model = \"sshleifer/tiny-gpt2\";\n\n####################\n# Trainer settings #\n####################\n\nlocal training_steps = 4;\nlocal validate_every = 4;\n\nlocal devices = 2;\nlocal grad_accum = 1;\nlocal batch_size = 2;\n\nlocal activation_checkpointing = true;\nlocal amp = false;\nlocal fsdp = true;\nlocal cpu_offloading = false;  # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow.\n\n######################\n# Optimizer settings #\n######################\n\nlocal warmup_steps = 2;\nlocal learning_rate = 0.005;\n\n\nlocal fsdp_config = {\n    reshard_after_forward: true,\n    move_params_to_cpu: cpu_offloading,\n    move_grads_to_cpu: cpu_offloading,\n    mixed_precision: amp,\n};\n\nlocal training_engine = {\n    type: \"fairscale\",\n    optimizer: {\n        type: \"torch::AdamW\",\n        lr: learning_rate,\n        betas: [0.9, 0.95],\n        eps: 1e-6,\n    },\n    amp: amp,\n    fsdp_config: fsdp_config,\n};\n\nlocal dataloader = {\n  batch_size: batch_size,\n  sampler: {\n    type: \"torch::DistributedSampler\",\n    shuffle: true,\n    drop_last: true,\n  },\n};\n\n{\n    steps: {\n        regression_data: {\n            type: \"simple_regression_data\",\n        },\n        trained_model: {\n            type: \"torch::train\",\n            model: {\n                type: \"fairscale::with_wrapped_modules\",\n                model: {\n                    type: \"simple_regression_model\",\n                },\n                modules_to_wrap: [\"blocks\\\\.[0-9]+\"],\n                fsdp_config: fsdp_config,\n                activation_checkpointing: activation_checkpointing,\n            },\n            training_engine: training_engine,\n            dataset_dict: { type: \"ref\", ref: \"regression_data\" },\n            train_dataloader: dataloader,\n            validation_split: \"dev\",\n            grad_accum: grad_accum,\n            train_steps: 
training_steps,\n            validate_every: training_steps,\n            validation_steps: 2,\n            checkpoint_every: training_steps,\n            log_every: 1,\n            device_count: devices,\n        },\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/flax/__init__.py",
    "content": ""
  },
  {
    "path": "test_fixtures/integrations/flax/config.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data_full\": {\n            \"type\": \"datasets::load\",\n            \"path\": \"iohadrubin/mini_xsum\",\n        },\n        \"data\": {\n            \"type\": \"datasets::dataset_remix\",\n            \"input\": {\"type\": \"ref\", \"ref\": \"data_full\"},\n            \"new_splits\": {\"train\": \"train[:20]\", \"validation\": \"validation[:20]\"},\n        },\n        \"tokenize\": {\n            \"type\": \"tokenize_data\",\n            \"dataset\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\"\n            }\n        },\n        \"train\": {\n            \"type\": \"flax::train\",\n            \"model\": {\n                \"type\" : \"transformers::FlaxAutoModelForSeq2SeqLM::from_pretrained\",\n                \"pretrained_model_name_or_path\" : \"t5-small\"\n            },\n            \"dataset\": {\n                \"type\": \"ref\",\n                \"ref\": \"tokenize\"\n            },\n            \"optimizer\": {\n                \"type\" : \"optax::adamw\",\n                \"learning_rate\" : 2e-5\n            },\n            \"train_dataloader\": {\n                \"batch_size\": 16,\n                \"drop_last\": true\n            },\n            \"wrapper\": {\n                \"type\": \"xsum_wrapper\"\n            },\n            \"train_split\": \"train\",\n            \"validation_split\" : \"validation\",\n            \"validate_every\" : 1,\n            \"validation_dataloader\": {\n                \"batch_size\": 16,\n                \"drop_last\": true\n            },\n            \"train_epoch\": 1,\n            \"checkpoint_every\": 1,\n            \"log_every\": 1\n        }\n    }\n}"
  },
  {
    "path": "test_fixtures/integrations/flax/xsum.py",
    "content": "import jax.numpy as jnp\nimport numpy as np\nimport optax\nfrom flax.training.common_utils import onehot\nfrom transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSeq2SeqLM\n\nfrom tango.integrations.flax import FlaxWrapper\nfrom tango.step import Step\n\n\"\"\"\nA minimal xsum t5-small config for testing.\n\"\"\"\n\n\n@Step.register(\"tokenize_data\")\nclass PreProcessing(Step):\n    DETERMINISTIC = False\n\n    def run(self, dataset):\n        tokenizer = AutoTokenizer.from_pretrained(\"t5-small\")\n        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\n        model_module = __import__(model.__module__, fromlist=[\"shift_tokens_tight\"])\n        shift_tokens_right_fn = getattr(model_module, \"shift_tokens_right\")\n        config = AutoConfig.from_pretrained(\"t5-small\")\n\n        MAX_SOURCE_LENGTH = 512\n        MAX_TGT_LENGTH = 64\n\n        def preprocess_function(examples):\n            inputs = examples[\"document\"]\n            targets = examples[\"summary\"]\n            inputs = [inp for inp in inputs]\n            model_inputs = tokenizer(\n                inputs,\n                max_length=MAX_SOURCE_LENGTH,\n                padding=\"max_length\",\n                truncation=True,\n                return_tensors=\"np\",\n            )\n\n            # Setup the tokenizer for targets\n            with tokenizer.as_target_tokenizer():\n                labels = tokenizer(\n                    targets,\n                    max_length=MAX_TGT_LENGTH,\n                    padding=\"max_length\",\n                    truncation=True,\n                    return_tensors=\"np\",\n                )\n\n            model_inputs[\"labels\"] = labels[\"input_ids\"]\n            decoder_input_ids = shift_tokens_right_fn(\n                labels[\"input_ids\"], config.pad_token_id, config.decoder_start_token_id\n            )\n            model_inputs[\"decoder_input_ids\"] = np.asarray(decoder_input_ids)\n\n          
  # We need decoder_attention_mask so we can ignore pad tokens from loss\n            model_inputs[\"decoder_attention_mask\"] = labels[\"attention_mask\"]\n\n            return model_inputs\n\n        column_names = dataset[\"train\"].column_names\n\n        dataset = dataset.map(\n            preprocess_function,\n            batched=True,\n            remove_columns=column_names,\n            desc=\"Running tokenizer on dataset\",\n        )\n\n        return dataset\n\n\n@FlaxWrapper.register(\"xsum_wrapper\")  # type: ignore\nclass TransformerWrapper(FlaxWrapper):\n    def train_metrics(self, state, batch, labels):\n        # return empty dict if no other metrics to compute\n        return {}\n\n    def loss_helper(self, logits, labels, batch):\n        label_smoothing_factor = 0\n        padding_mask = batch[\"decoder_attention_mask\"]\n        vocab_size = logits.shape[-1]\n        confidence = 1.0 - label_smoothing_factor\n        low_confidence = (1.0 - confidence) / (vocab_size - 1)\n        normalizing_constant = -(\n            confidence * jnp.log(confidence)\n            + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)\n        )\n        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)\n\n        loss = optax.softmax_cross_entropy(logits, soft_labels)\n        loss = loss - normalizing_constant\n\n        # ignore padded tokens from loss\n        loss = loss * padding_mask\n        loss = loss.sum() / padding_mask.sum()\n        return loss\n\n    def train_loss(self, params, state, batch, dropout_rng, labels):\n        logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n        loss = self.loss_helper(logits, labels, batch)\n        return loss\n\n    def val_metrics(self, batch, logits, labels):\n        loss = self.loss_helper(logits, labels, batch)\n        metrics = {\"loss\": loss}\n        return metrics\n\n    def eval_metrics(self, batch, logits, 
labels):\n        loss = self.loss_helper(logits, labels, batch)\n        metrics = {\"loss\": loss}\n        return metrics\n"
  },
  {
    "path": "test_fixtures/integrations/torch/__init__.py",
    "content": "import torch.nn as nn\n\nfrom tango.integrations.torch import Model\n\n\n@Model.register(\"basic_regression\")\nclass BasicRegression(Model):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(10, 1)\n        self.sigmoid = nn.Sigmoid()\n        self.mse = nn.MSELoss()\n\n    def forward(self, x, y=None):\n        pred = self.sigmoid(self.linear(x))\n        out = {\"pred\": pred}\n        if y is not None:\n            out[\"loss\"] = self.mse(pred, y)\n        return out\n\n    def _to_params(self):\n        return {}\n"
  },
  {
    "path": "test_fixtures/integrations/torch/eval.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data\": {\n            \"type\": \"generate_data\",\n        },\n        \"eval\": {\n            \"type\": \"torch::eval\",\n            \"model\": {\n                \"type\": \"basic_regression\",\n            },\n            \"dataset_dict\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\"\n            },\n            \"dataloader\": {\n                \"batch_size\": 8,\n                \"shuffle\": true\n            },\n            \"test_split\": \"validation\",\n            \"log_every\": 1\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/torch/train.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data\": {\n            \"type\": \"generate_data\",\n        },\n        \"train\": {\n            \"type\": \"torch::train\",\n            \"model\": {\n                \"type\": \"basic_regression\",\n            },\n            \"training_engine\": {\n                \"optimizer\": {\n                    \"type\": \"torch::Adam\",\n                },\n            },\n            \"dataset_dict\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\"\n            },\n            \"train_dataloader\": {\n                \"batch_size\": 8,\n                \"shuffle\": true\n            },\n            \"validation_split\": \"validation\",\n            \"validation_dataloader\": {\n                \"batch_size\": 8,\n                \"shuffle\": false\n            },\n            \"train_steps\": 100,\n            \"validate_every\": 10,\n            \"checkpoint_every\": 10,\n            \"log_every\": 1\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/torch/train_dist.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data\": {\n            \"type\": \"generate_data\",\n        },\n        \"train\": {\n            \"type\": \"torch::train\",\n            \"model\": {\n                \"type\": \"basic_regression\",\n            },\n            \"training_engine\": {\n                \"optimizer\": {\n                    \"type\": \"torch::Adam\",\n                },\n            },\n            \"dataset_dict\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\",\n            },\n            \"train_dataloader\": {\n                \"batch_size\": 8,\n                \"sampler\": {\n                    \"type\": \"torch::DistributedSampler\",\n                    \"shuffle\": true,\n                    \"drop_last\": true,\n                }\n            },\n            \"validation_split\": \"validation\",\n            \"validation_dataloader\": {\n                \"batch_size\": 8,\n                \"sampler\": {\n                    \"type\": \"torch::DistributedSampler\",\n                    \"shuffle\": true,\n                    \"drop_last\": true,\n                }\n            },\n            \"train_steps\": 100,\n            \"validate_every\": 10,\n            \"checkpoint_every\": 10,\n            \"log_every\": 1,\n            \"device_count\": 2,\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/integrations/torch/train_streaming.jsonnet",
    "content": "{\n    \"steps\": {\n        \"data\": {\n            \"type\": \"generate_streaming_data\",\n        },\n        \"train\": {\n            \"type\": \"torch::train\",\n            \"model\": {\n                \"type\": \"basic_regression\",\n            },\n            \"training_engine\": {\n                \"optimizer\": {\n                    \"type\": \"torch::Adam\",\n                },\n            },\n            \"dataset_dict\": {\n                \"type\": \"ref\",\n                \"ref\": \"data\"\n            },\n            \"train_dataloader\": {\n                \"batch_size\": 8,\n                \"shuffle\": true\n            },\n            \"train_steps\": 100,\n            \"checkpoint_every\": 10,\n            \"log_every\": 1\n        }\n    }\n}\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/cache-metadata.json",
    "content": "{\n    \"step\": \"AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"cadd\",\n        \"a\": {\n            \"type\": \"ref\",\n            \"ref\": \"CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\"\n        },\n        \"b\": {\n            \"type\": \"ref\",\n            \"ref\": \"MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf\"\n        }\n    },\n    \"duration\": 0.0007,\n    \"finished_at\": 1642546363.9658601,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546363.965193,\n    \"step\": \"AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json",
    "content": "{\n    \"step\": \"CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"ccos\",\n        \"x\": [\n            3.1415926535,\n            0\n        ]\n    },\n    \"duration\": 0.0004,\n    \"finished_at\": 1642546350.3743181,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546350.373902,\n    \"step\": \"CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/CosineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/cache-metadata.json",
    "content": "{\n    \"step\": \"ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"cexp\",\n        \"x\": {\n            \"type\": \"ref\",\n            \"ref\": \"MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae\"\n        }\n    },\n    \"duration\": 0.0006,\n    \"finished_at\": 1642546361.347647,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546361.347095,\n    \"step\": \"ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/cache-metadata.json",
    "content": "{\n    \"step\": \"MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"cmul\",\n        \"a\": [\n            0,\n            1\n        ],\n        \"b\": {\n            \"type\": \"ref\",\n            \"ref\": \"SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\"\n        }\n    },\n    \"duration\": 0.0004,\n    \"finished_at\": 1642546358.776982,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546358.776602,\n    \"step\": \"MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-2ZG7wPj9WLn5PgpYyPVHw9Qg7VM1mhwf/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/cache-metadata.json",
    "content": "{\n    \"step\": \"MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"cmul\",\n        \"a\": [\n            0,\n            1\n        ],\n        \"b\": [\n            3.1415926535,\n            0\n        ]\n    },\n    \"duration\": 0.0005,\n    \"finished_at\": 1642546353.7523232,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546353.751795,\n    \"step\": \"MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/MultiplyStep-4SRzHCCqYGs2PLeT8LeL5ukrCWGJoiae/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/cache-metadata.json",
    "content": "{\n    \"step\": \"SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"csin\",\n        \"x\": [\n            3.1415926535,\n            0\n        ]\n    },\n    \"duration\": 0.0004,\n    \"finished_at\": 1642546356.265017,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546356.264595,\n    \"step\": \"SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SineStep-5aes9CUTRmkz5gJ5J6JSRbJZ4qkFu4kk/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/cache-metadata.json",
    "content": "{\n    \"step\": \"SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz\"\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/conda-environment.yaml",
    "content": "name: tango\nchannels:\n  - pytorch\n  - defaults\ndependencies:\n  - appnope=0.1.2=py38hecd8cb5_1001\n  - backcall=0.2.0=pyhd3eb1b0_0\n  - blas=1.0=mkl\n  - bzip2=1.0.8=h1de35cc_0\n  - ca-certificates=2021.10.26=hecd8cb5_2\n  - certifi=2021.10.8=py38hecd8cb5_0\n  - decorator=5.1.0=pyhd3eb1b0_0\n  - ffmpeg=4.3=h0a44026_0\n  - freetype=2.11.0=hd8bbffd_0\n  - gettext=0.21.0=h7535e17_0\n  - giflib=5.2.1=haf1e3a3_0\n  - gmp=6.2.1=h23ab428_2\n  - gnutls=3.6.15=hed9c0bf_0\n  - icu=58.2=h0a44026_3\n  - intel-openmp=2021.4.0=hecd8cb5_3538\n  - ipython=7.29.0=py38h01d92e1_0\n  - jedi=0.18.0=py38hecd8cb5_1\n  - jpeg=9d=h9ed2024_0\n  - lame=3.100=h1de35cc_0\n  - lcms2=2.12=hf1fd2bf_0\n  - libcxx=12.0.0=h2f01273_0\n  - libffi=3.3=hb1e8313_2\n  - libiconv=1.16=h1de35cc_0\n  - libidn2=2.3.2=h9ed2024_0\n  - libpng=1.6.37=ha441bb4_0\n  - libtasn1=4.16.0=h9ed2024_0\n  - libtiff=4.2.0=h87d7836_0\n  - libunistring=0.9.10=h9ed2024_0\n  - libuv=1.40.0=haf1e3a3_0\n  - libwebp=1.2.0=hacca55c_0\n  - libwebp-base=1.2.0=h9ed2024_0\n  - libxml2=2.9.12=hcdb78fc_0\n  - llvm-openmp=12.0.0=h0dcd299_1\n  - lz4-c=1.9.3=h23ab428_1\n  - matplotlib-inline=0.1.2=pyhd3eb1b0_2\n  - mkl=2021.4.0=hecd8cb5_637\n  - mkl-service=2.4.0=py38h9ed2024_0\n  - mkl_fft=1.3.1=py38h4ab4a9b_0\n  - mkl_random=1.2.2=py38hb2f4e1b_0\n  - ncurses=6.3=hca72f7f_1\n  - nettle=3.7.3=h230ac6f_1\n  - numpy=1.21.2=py38h4b4dc7a_0\n  - numpy-base=1.21.2=py38he0bd621_0\n  - olefile=0.46=pyhd3eb1b0_0\n  - openh264=2.1.0=hd9629dc_0\n  - openssl=1.1.1l=h9ed2024_0\n  - parso=0.8.2=pyhd3eb1b0_0\n  - pexpect=4.8.0=pyhd3eb1b0_3\n  - pickleshare=0.7.5=pyhd3eb1b0_1003\n  - pillow=8.4.0=py38h98e4679_0\n  - pip=21.2.4=py38hecd8cb5_0\n  - prompt-toolkit=3.0.20=pyhd3eb1b0_0\n  - ptyprocess=0.7.0=pyhd3eb1b0_2\n  - pygments=2.10.0=pyhd3eb1b0_0\n  - python=3.8.12=h88f2d9e_0\n  - pytorch=1.10.0=py3.8_0\n  - readline=8.1=h9ed2024_0\n  - setuptools=58.0.4=py38hecd8cb5_0\n  - six=1.16.0=pyhd3eb1b0_0\n  - sqlite=3.36.0=hce871da_0\n  - 
tk=8.6.11=h7bc2e8c_0\n  - torchaudio=0.10.0=py38_cpu\n  - torchvision=0.11.1=py38_cpu\n  - traitlets=5.1.0=pyhd3eb1b0_0\n  - typing_extensions=3.10.0.2=pyh06a4308_0\n  - wcwidth=0.2.5=pyhd3eb1b0_0\n  - wheel=0.37.0=pyhd3eb1b0_1\n  - xz=5.2.5=h1de35cc_0\n  - zlib=1.2.11=h1de35cc_3\n  - zstd=1.4.9=h322a384_0\n  - pip:\n    - absl-py==0.15.0\n    - aiohttp==3.8.0\n    - aiosignal==1.2.0\n    - alabaster==0.7.12\n    - astunparse==1.6.3\n    - async-timeout==4.0.0\n    - attrs==21.2.0\n    - babel==2.9.1\n    - base58==2.1.1\n    - beautifulsoup4==4.10.0\n    - black==21.12b0\n    - bleach==4.1.0\n    - boto3==1.19.12\n    - botocore==1.22.12\n    - cached-path==1.0.0\n    - cachetools==4.2.4\n    - charset-normalizer==2.0.7\n    - click==8.0.3\n    - click-help-colors==0.9.1\n    - codecov==2.1.12\n    - colorama==0.4.4\n    - configparser==5.1.0\n    - coverage==6.1.1\n    - datasets==1.15.1\n    - dill==0.3.4\n    - docker-pycreds==0.4.0\n    - docutils==0.17.1\n    - filelock==3.4.0\n    - flake8==4.0.1\n    - flaky==3.7.0\n    - flatbuffers==2.0\n    - frozenlist==1.2.0\n    - fsspec==2021.11.0\n    - furo==2022.1.2\n    - future==0.18.2\n    - gast==0.4.0\n    - gitdb==4.0.9\n    - gitpython==3.1.24\n    - glob2==0.7\n    - google-api-core==2.2.2\n    - google-auth==2.3.3\n    - google-auth-oauthlib==0.4.6\n    - google-cloud-core==2.1.0\n    - google-cloud-storage==1.42.3\n    - google-crc32c==1.3.0\n    - google-pasta==0.2.0\n    - google-resumable-media==2.1.0\n    - googleapis-common-protos==1.53.0\n    - grpcio==1.41.1\n    - h5py==3.6.0\n    - huggingface-hub==0.1.1\n    - idna==3.3\n    - imagesize==1.2.0\n    - importlib-metadata==4.8.1\n    - iniconfig==1.1.1\n    - isort==5.10.1\n    - jinja2==3.0.2\n    - jmespath==0.10.0\n    - joblib==1.1.0\n    - jsonnet==0.17.0\n    - keras==2.7.0\n    - keras-preprocessing==1.1.2\n    - keyring==23.2.1\n    - libclang==12.0.0\n    - livereload==2.6.3\n    - markdown==3.3.4\n    - markdown-it-py==1.1.0\n    - 
markupsafe==2.0.1\n    - mccabe==0.6.1\n    - mdit-py-plugins==0.3.0\n    - more-itertools==8.10.0\n    - multidict==5.2.0\n    - multiprocess==0.70.12.2\n    - mypy==0.931\n    - mypy-extensions==0.4.3\n    - myst-parser==0.16.1\n    - nltk==3.6.7\n    - oauthlib==3.1.1\n    - opt-einsum==3.3.0\n    - overrides==6.1.0\n    - packaging==21.2\n    - pandas==1.3.4\n    - pathspec==0.9.0\n    - pathtools==0.1.2\n    - petname==2.6\n    - pkginfo==1.7.1\n    - platformdirs==2.4.0\n    - pluggy==1.0.0\n    - promise==2.3\n    - protobuf==3.19.1\n    - psutil==5.8.0\n    - py==1.11.0\n    - pyarrow==6.0.0\n    - pyasn1==0.4.8\n    - pyasn1-modules==0.2.8\n    - pycodestyle==2.8.0\n    - pydeprecate==0.3.1\n    - pyflakes==2.4.0\n    - pyparsing==2.4.7\n    - pytest==6.2.5\n    - pytest-cov==3.0.0\n    - pytest-sphinx==0.3.1\n    - python-dateutil==2.8.2\n    - pytorch-lightning==1.5.1\n    - pytz==2021.3\n    - pyyaml==6.0\n    - readme-renderer==30.0\n    - regex==2021.11.2\n    - requests==2.26.0\n    - requests-oauthlib==1.3.0\n    - requests-toolbelt==0.9.1\n    - rfc3986==1.5.0\n    - rouge-score==0.0.4\n    - rsa==4.7.2\n    - s3transfer==0.5.0\n    - sacremoses==0.0.46\n    - sentencepiece==0.1.96\n    - sentry-sdk==1.4.3\n    - shortuuid==1.0.1\n    - smmap==5.0.0\n    - snowballstemmer==2.1.0\n    - soupsieve==2.3\n    - sphinx==4.3.1\n    - sphinx-autobuild==2021.3.14\n    - sphinx-copybutton==0.4.0\n    - sphinxcontrib-applehelp==1.0.2\n    - sphinxcontrib-devhelp==1.0.2\n    - sphinxcontrib-htmlhelp==2.0.0\n    - sphinxcontrib-jsmath==1.0.1\n    - sphinxcontrib-qthelp==1.0.3\n    - sphinxcontrib-serializinghtml==1.1.5\n    - sqlitedict==1.7.0\n    - subprocess32==3.5.4\n    - tensorboard==2.7.0\n    - tensorboard-data-server==0.6.1\n    - tensorboard-plugin-wit==1.8.0\n    - tensorflow==2.7.0\n    - tensorflow-estimator==2.7.0\n    - tensorflow-io-gcs-filesystem==0.23.1\n    - termcolor==1.1.0\n    - tokenizers==0.10.3\n    - toml==0.10.2\n    - 
tomli==1.2.2\n    - torchmetrics==0.6.0\n    - tornado==6.1\n    - tqdm==4.62.3\n    - transformers==4.12.3\n    - twine==3.5.0\n    - types-pyyaml==6.0.0\n    - types-setuptools==57.4.2\n    - typing-utils==0.1.0\n    - urllib3==1.26.7\n    - wandb==0.12.6\n    - webencodings==0.5.1\n    - werkzeug==2.0.2\n    - wrapt==1.13.3\n    - xxhash==2.0.2\n    - yarl==1.7.2\n    - yaspin==2.1.0\n    - zipp==3.6.0\nprefix: /opt/miniconda3/envs/tango\n"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/executor-metadata.json",
    "content": "{\n    \"config\": {\n        \"type\": \"csub\",\n        \"a\": {\n            \"type\": \"ref\",\n            \"ref\": \"AdditionStep-34AiXoyiPKADMUnhcBzFYd6JeMcgx4DP\"\n        },\n        \"b\": {\n            \"type\": \"ref\",\n            \"ref\": \"ExponentiateStep-Rf73w34zWJcBrQafpAkxDvXR4mq3MXC9\"\n        }\n    },\n    \"duration\": 0.0005,\n    \"finished_at\": 1642546366.57007,\n    \"git\": {\n        \"commit\": \"8e09b66caffbff20fd0b1504c961932b97417e8d\",\n        \"remote\": \"https://github.com/allenai/tango.git\"\n    },\n    \"platform\": {\n        \"cpu_count\": 16,\n        \"executable\": \"/opt/miniconda3/envs/tango/bin/python\",\n        \"host\": \"ip-192-168-1-194.us-west-2.compute.internal\",\n        \"operating_system\": \"macOS-10.16-x86_64-i386-64bit\",\n        \"python\": \"3.8.12\",\n        \"root\": \"/Users/dirkg/Documents/tango/examples/euler\",\n        \"user\": \"dirkg\"\n    },\n    \"started_at\": 1642546366.569589,\n    \"step\": \"SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz\",\n    \"tango\": {\n        \"command\": \"/opt/miniconda3/envs/tango/bin/tango run euler_general.jsonnet -d workspace --include-package complex_arithmetic\",\n        \"version\": \"0.4.0rc4\"\n    }\n}"
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/lock",
    "content": ""
  },
  {
    "path": "test_fixtures/v1_local_workspace/cache/SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz/requirements.txt",
    "content": "absl-py==0.15.0\nai2-tango==0.4.0rc1\naiohttp==3.8.0\naiosignal==1.2.0\nalabaster==0.7.12\nappnope==0.1.2\nastunparse==1.6.3\nasync-timeout==4.0.0\nattrs==21.2.0\nbabel==2.9.1\nbackcall==0.2.0\nbase58==2.1.1\nbeautifulsoup4==4.10.0\nblack==21.12b0\nbleach==4.1.0\nboto3==1.19.12\nbotocore==1.22.12\ncached-path==1.0.0\ncachetools==4.2.4\ncertifi==2021.10.8\ncharset-normalizer==2.0.7\nclick-help-colors==0.9.1\nclick==8.0.3\ncodecov==2.1.12\ncolorama==0.4.4\nconfigparser==5.1.0\ncoverage==6.1.1\ndatasets==1.15.1\ndecorator==5.1.0\ndill==0.3.4\ndocker-pycreds==0.4.0\ndocutils==0.17.1\nfilelock==3.4.0\nflake8==4.0.1\nflaky==3.7.0\nflatbuffers==2.0\nfrozenlist==1.2.0\nfsspec==2021.11.0\nfuro==2022.1.2\nfuture==0.18.2\ngast==0.4.0\ngitdb==4.0.9\ngitpython==3.1.24\nglob2==0.7\ngoogle-api-core==2.2.2\ngoogle-auth-oauthlib==0.4.6\ngoogle-auth==2.3.3\ngoogle-cloud-core==2.1.0\ngoogle-cloud-storage==1.42.3\ngoogle-crc32c==1.3.0\ngoogle-pasta==0.2.0\ngoogle-resumable-media==2.1.0\ngoogleapis-common-protos==1.53.0\ngrpcio==1.41.1\nh5py==3.6.0\nhuggingface-hub==0.1.1\nidna==3.3\nimagesize==1.2.0\nimportlib-metadata==4.8.1\niniconfig==1.1.1\nipython==7.29.0\nisort==5.10.1\njedi==0.18.0\njinja2==3.0.2\njmespath==0.10.0\njoblib==1.1.0\njsonnet==0.17.0\nkeras-preprocessing==1.1.2\nkeras==2.7.0\nkeyring==23.2.1\nlibclang==12.0.0\nlivereload==2.6.3\nmarkdown-it-py==1.1.0\nmarkdown==3.3.4\nmarkupsafe==2.0.1\nmatplotlib-inline==0.1.2\nmccabe==0.6.1\nmdit-py-plugins==0.3.0\nmkl-fft==1.3.1\nmkl-random==1.2.2\nmkl-service==2.4.0\nmore-itertools==8.10.0\nmultidict==5.2.0\nmultiprocess==0.70.12.2\nmypy-extensions==0.4.3\nmypy==0.931\nmyst-parser==0.16.1\nnltk==3.6.7\nnumpy==1.21.2\noauthlib==3.1.1\nolefile==0.46\nopt-einsum==3.3.0\noverrides==6.1.0\npackaging==21.2\npandas==1.3.4\nparso==0.8.2\npathspec==0.9.0\npathtools==0.1.2\npetname==2.6\npexpect==4.8.0\npickleshare==0.7.5\npillow==8.4.0\npip==21.2.4\npkginfo==1.7.1\nplatformdirs==2.4.0\npluggy==1.0.0\npromise==2.3\nprompt-t
oolkit==3.0.20\nprotobuf==3.19.1\npsutil==5.8.0\nptyprocess==0.7.0\npy==1.11.0\npyarrow==6.0.0\npyasn1-modules==0.2.8\npyasn1==0.4.8\npycodestyle==2.8.0\npydeprecate==0.3.1\npyflakes==2.4.0\npygments==2.10.0\npyparsing==2.4.7\npytest-cov==3.0.0\npytest-sphinx==0.3.1\npytest==6.2.5\npython-dateutil==2.8.2\npytorch-lightning==1.5.1\npytz==2021.3\npyyaml==6.0\nreadme-renderer==30.0\nregex==2021.11.2\nrequests-oauthlib==1.3.0\nrequests-toolbelt==0.9.1\nrequests==2.26.0\nrfc3986==1.5.0\nrouge-score==0.0.4\nrsa==4.7.2\ns3transfer==0.5.0\nsacremoses==0.0.46\nsentencepiece==0.1.96\nsentry-sdk==1.4.3\nsetuptools==58.0.4\nshortuuid==1.0.1\nsix==1.16.0\nsmmap==5.0.0\nsnowballstemmer==2.1.0\nsoupsieve==2.3\nsphinx-autobuild==2021.3.14\nsphinx-copybutton==0.4.0\nsphinx==4.3.1\nsphinxcontrib-applehelp==1.0.2\nsphinxcontrib-devhelp==1.0.2\nsphinxcontrib-htmlhelp==2.0.0\nsphinxcontrib-jsmath==1.0.1\nsphinxcontrib-qthelp==1.0.3\nsphinxcontrib-serializinghtml==1.1.5\nsqlitedict==1.7.0\nsubprocess32==3.5.4\ntensorboard-data-server==0.6.1\ntensorboard-plugin-wit==1.8.0\ntensorboard==2.7.0\ntensorflow-estimator==2.7.0\ntensorflow-io-gcs-filesystem==0.23.1\ntensorflow==2.7.0\ntermcolor==1.1.0\ntokenizers==0.10.3\ntoml==0.10.2\ntomli==1.2.2\ntorch==1.10.0\ntorchaudio==0.10.0\ntorchmetrics==0.6.0\ntorchvision==0.11.1\ntornado==6.1\ntqdm==4.62.3\ntraitlets==5.1.0\ntransformers==4.12.3\ntwine==3.5.0\ntypes-pyyaml==6.0.0\ntypes-setuptools==57.4.2\ntyping-extensions==3.10.0.2\ntyping-utils==0.1.0\nurllib3==1.26.7\nwandb==0.12.6\nwcwidth==0.2.5\nwebencodings==0.5.1\nwerkzeug==2.0.2\nwheel==0.37.0\nwrapt==1.13.3\nxxhash==2.0.2\nyarl==1.7.2\nyaspin==2.1.0\nzipp==3.6.0"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/common/__init__.py",
    "content": ""
  },
  {
    "path": "tests/common/dataset_dict_test.py",
    "content": "from tango.common.dataset_dict import DatasetDict\n\n\ndef test_dataset_dict():\n    dataset_dict = DatasetDict(splits={\"train\": list(range(10)), \"test\": list(range(5))})\n    assert len(dataset_dict) == 2\n    assert \"train\" in dataset_dict\n    assert \"test\" in dataset_dict\n    assert len(dataset_dict[\"train\"]) == 10\n    assert len(dataset_dict[\"test\"]) == 5\n    assert set(dataset_dict) == set(dataset_dict.keys()) == {\"train\", \"test\"}\n"
  },
  {
    "path": "tests/common/det_hash_test.py",
    "content": "from tango.common.det_hash import DetHashWithVersion, det_hash\n\n\ndef test_normal_det_hash():\n    class C:\n        VERSION = 1\n\n        def __init__(self, x: int):\n            self.x = x\n\n    c1_1 = C(10)\n    c2_1 = C(10)\n    c3_1 = C(20)\n    assert det_hash(c1_1) == det_hash(c2_1)\n    assert det_hash(c3_1) != det_hash(c2_1)\n\n    class C:\n        VERSION = 2\n\n        def __init__(self, x: int):\n            self.x = x\n\n    c1_2 = C(10)\n    c2_2 = C(10)\n    c3_2 = C(20)\n    assert det_hash(c1_2) == det_hash(c2_2)\n    assert det_hash(c3_2) != det_hash(c2_2)\n    assert det_hash(c1_2) == det_hash(c1_1)  # because the version isn't taken into account\n    assert det_hash(c3_2) == det_hash(c3_1)  # because the version isn't taken into account\n\n\ndef test_versioned_det_hash():\n    class C(DetHashWithVersion):\n        VERSION = \"1\"\n\n        def __init__(self, x: int):\n            self.x = x\n\n    c1_1 = C(10)\n    c2_1 = C(10)\n    c3_1 = C(20)\n    assert det_hash(c1_1) == det_hash(c2_1)\n    assert det_hash(c3_1) != det_hash(c2_1)\n\n    class C(DetHashWithVersion):\n        VERSION = \"2\"\n\n        def __init__(self, x: int):\n            self.x = x\n\n    c1_2 = C(10)\n    c2_2 = C(10)\n    c3_2 = C(20)\n    assert det_hash(c1_2) == det_hash(c2_2)\n    assert det_hash(c3_2) != det_hash(c2_2)\n    assert det_hash(c1_2) != det_hash(c1_1)  # because the version is taken into account\n    assert det_hash(c3_2) != det_hash(c3_1)  # because the version is taken into account\n"
  },
  {
    "path": "tests/common/from_params_pep_563_test.py",
    "content": "\"\"\"\nThis tests `FromParams` functionality with https://www.python.org/dev/peps/pep-0563/.\n\"\"\"\n\nfrom __future__ import annotations\n\nfrom tango.common.from_params import FromParams, infer_method_params\nfrom tango.common.lazy import Lazy\n\n\nclass Foo(FromParams):\n    def __init__(self, x: int):\n        self.x = x\n\n\nclass Bar(FromParams):\n    def __init__(self, foo: Lazy[Foo]):\n        self.foo = foo.construct()\n\n\nclass Baz(FromParams):\n    def __init__(self, bar: Lazy[Bar]):\n        self.bar = bar.construct()\n\n\ndef test_infer_method_params():\n    parameters = infer_method_params(Baz, Baz.__init__)\n    assert not isinstance(parameters[\"bar\"].annotation, str)\n\n\ndef test_from_params():\n    baz = Baz.from_params({\"bar\": {\"foo\": {\"x\": 1}}})\n    assert baz.bar.foo.x == 1\n"
  },
  {
    "path": "tests/common/from_params_test.py",
    "content": "import sys\nfrom copy import deepcopy\nfrom dataclasses import dataclass\nfrom numbers import Number\nfrom typing import (\n    Dict,\n    Generic,\n    Iterable,\n    List,\n    Mapping,\n    Optional,\n    Set,\n    Tuple,\n    TypeVar,\n    Union,\n)\n\nimport pytest\n\nfrom tango.common import det_hash\nfrom tango.common.det_hash import DetHashWithVersion\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.from_params import (\n    FromParams,\n    create_kwargs,\n    is_base_registrable,\n    remove_optional,\n    takes_arg,\n)\nfrom tango.common.lazy import Lazy\nfrom tango.common.params import Params\nfrom tango.common.registrable import Registrable\nfrom tango.common.testing import TangoTestCase\nfrom tango.step import Step\n\n\nclass TestFromParams(TangoTestCase):\n    def test_takes_arg(self):\n        def bare_function(some_input: int) -> int:\n            return some_input + 1\n\n        assert takes_arg(bare_function, \"some_input\")\n        assert not takes_arg(bare_function, \"some_other_input\")\n\n        class SomeClass:\n            total = 0\n\n            def __init__(self, constructor_param: str) -> None:\n                self.constructor_param = constructor_param\n\n            def check_param(self, check: str) -> bool:\n                return self.constructor_param == check\n\n            @classmethod\n            def set_total(cls, new_total: int) -> None:\n                cls.total = new_total\n\n        assert takes_arg(SomeClass, \"self\")\n        assert takes_arg(SomeClass, \"constructor_param\")\n        assert not takes_arg(SomeClass, \"check\")\n\n        assert takes_arg(SomeClass.check_param, \"check\")\n        assert not takes_arg(SomeClass.check_param, \"other_check\")\n\n        assert takes_arg(SomeClass.set_total, \"new_total\")\n        assert not takes_arg(SomeClass.set_total, \"total\")\n\n    def test_remove_optional(self):\n        optional_type = Optional[Dict[str, str]]\n        
bare_type = remove_optional(optional_type)  # type: ignore\n        bare_bare_type = remove_optional(bare_type)\n\n        assert bare_type == Dict[str, str]\n        assert bare_bare_type == Dict[str, str]\n\n        assert remove_optional(Optional[str]) == str  # type: ignore[arg-type]\n        assert remove_optional(str) == str\n\n    @pytest.mark.parametrize(\"input_type\", [dict, Params])\n    def test_from_params(self, input_type):\n        params = {\"my_int\": 10}\n        my_class = MyClass.from_params(input_type(params), my_bool=True)\n\n        assert isinstance(my_class, MyClass)\n        assert my_class.my_int == 10\n        assert my_class.my_bool\n\n    def test_create_kwargs(self):\n        kwargs = create_kwargs(\n            MyClass, MyClass, Params({\"my_int\": 5}), dict(my_bool=True, my_float=4.4)\n        )\n\n        # my_float should not be included because it's not a param of the MyClass constructor\n        assert kwargs == {\"my_int\": 5, \"my_bool\": True}\n\n    def test_extras(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int, name: str) -> None:\n                self.size = size\n                self.name = name\n\n        @A.register(\"c\")\n        class C(A):\n            def __init__(self, size: int, name: str) -> None:\n                self.size = size\n                self.name = name\n\n            # custom from params\n            @classmethod\n            def from_params(cls, params: Params, size: int, **extras) -> \"C\":  # type: ignore\n                name = params.pop(\"name\")\n                return cls(size=size, name=name)\n\n        # Check that extras get passed, even though A doesn't need them.\n        params = Params({\"type\": \"b\", \"size\": 10})\n        b: B = A.from_params(params, name=\"extra\")  # type: ignore[assignment]\n\n        assert b.name == \"extra\"\n        assert b.size == 10\n\n        # Check 
that extra extras don't get passed.\n        params = Params({\"type\": \"b\", \"size\": 10})\n        b = A.from_params(params, name=\"extra\", unwanted=True)  # type: ignore[assignment]\n\n        assert b.name == \"extra\"  # type: ignore[attr-defined]\n        assert b.size == 10  # type: ignore[attr-defined]\n\n        # Now the same with a custom from_params.\n        params = Params({\"type\": \"c\", \"name\": \"extra_c\"})\n        c: C = A.from_params(params, size=20)  # type: ignore[assignment]\n        assert c.name == \"extra_c\"\n        assert c.size == 20\n\n        # Check that extra extras don't get passed.\n        params = Params({\"type\": \"c\", \"name\": \"extra_c\"})\n        c = A.from_params(params, size=20, unwanted=True)  # type: ignore[assignment]\n\n        assert c.name == \"extra_c\"  # type: ignore[attr-defined]\n        assert c.size == 20  # type: ignore[attr-defined]\n\n    def test_variable_length_tuple(self):\n        class Foo(FromParams):\n            def __init__(self, x: Tuple[Optional[int], ...]):\n                self.x = x\n\n        assert Foo.from_params({\"x\": [None, 1, 2, 3]}).x == (None, 1, 2, 3)\n        assert Foo.from_params({\"x\": [1, 2]}).x == (1, 2)\n        assert Foo.from_params({\"x\": [1]}).x == (1,)\n\n    def test_union(self):\n        class A(FromParams):\n            def __init__(self, a: Union[int, List[int]]) -> None:\n                self.a = a\n\n        class B(FromParams):\n            def __init__(self, b: Union[A, List[A]]) -> None:\n                # Really you would want to be sure that `self.b` has a consistent type, but for\n                # this test we'll ignore that.\n                self.b = b\n\n        params = Params({\"a\": 3})\n        a = A.from_params(params)\n        assert a.a == 3\n\n        params = Params({\"a\": [3, 4, 5]})\n        a = A.from_params(params)\n        assert a.a == [3, 4, 5]\n\n        params = Params({\"b\": {\"a\": 3}})\n        b = 
B.from_params(params)\n        assert isinstance(b.b, A)\n        assert b.b.a == 3\n\n        params = Params({\"b\": [{\"a\": 3}, {\"a\": [4, 5]}]})\n        b = B.from_params(params)\n        assert isinstance(b.b, list)\n        assert b.b[0].a == 3\n        assert b.b[1].a == [4, 5]\n\n    def test_non_params_object_with_params(self):\n        bar = Bar.from_params({\"foo\": Foo(a=1)})\n        assert bar.foo.a == 1\n\n    def test_crazy_nested_union(self):\n        class A(FromParams):\n            def __init__(self, a: Union[int, List[int]]) -> None:\n                self.a = a\n\n        class B(FromParams):\n            def __init__(self, b: Union[A, List[A]]) -> None:\n                # Really you would want to be sure that `self.b` has a consistent type, but for\n                # this test we'll ignore that.\n                self.b = b\n\n        class C(FromParams):\n            def __init__(self, c: Union[A, B, Dict[str, A]]) -> None:\n                # Really you would want to be sure that `self.c` has a consistent type, but for\n                # this test we'll ignore that.\n                self.c = c\n\n        # This is a contrived, ugly example (why would you want to duplicate names in a nested\n        # structure like this??), but it demonstrates a potential bug when dealing with mutable\n        # parameters.  
If you're not careful about keeping the parameters un-mutated in two\n        # separate places, you'll end up with a B, or with a dict that's missing the 'b' key.\n        params = Params({\"c\": {\"a\": {\"a\": 3}, \"b\": {\"a\": [4, 5]}}})\n        c = C.from_params(params)\n        assert isinstance(c.c, dict)\n        assert c.c[\"a\"].a == 3\n        assert c.c[\"b\"].a == [4, 5]\n\n    def test_union_of_castable_types(self):\n        class IntFloat(FromParams):\n            def __init__(self, a: Union[int, float]) -> None:\n                self.a = a\n\n        class FloatInt(FromParams):\n            def __init__(self, a: Union[float, int]) -> None:\n                self.a = a\n\n        float_param_str = '{\"a\": 1.0}'\n        int_param_str = '{\"a\": 1}'\n        import json\n\n        for expected_type, param_str in [(int, int_param_str), (float, float_param_str)]:\n            for cls in [IntFloat, FloatInt]:\n                c = cls.from_params(Params(json.loads(param_str)))\n                assert type(c.a) == expected_type  # type: ignore[attr-defined]\n\n    def test_invalid_type_conversions(self):\n        class A(FromParams):\n            def __init__(self, a: int) -> None:\n                self.a = a\n\n        with pytest.raises(TypeError):\n            A.from_params(Params({\"a\": \"1\"}))\n        with pytest.raises(TypeError):\n            A.from_params(Params({\"a\": 1.0}))\n\n    def test_dict(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int) -> None:\n                self.size = size\n\n        class C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, items: Dict[str, A]) -> None:\n                self.items = items\n\n        params = Params(\n            {\n                \"type\": \"d\",\n                \"items\": {\"first\": {\"type\": \"b\", \"size\": 1}, \"second\": 
{\"type\": \"b\", \"size\": 2}},\n            }\n        )\n        d: D = C.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(d.items, dict)\n        assert len(d.items) == 2\n        assert all(isinstance(key, str) for key in d.items.keys())\n        assert all(isinstance(value, B) for value in d.items.values())\n        assert d.items[\"first\"].size == 1  # type: ignore[attr-defined]\n        assert d.items[\"second\"].size == 2  # type: ignore[attr-defined]\n\n    def test_dict_not_params(self):\n        class A(FromParams):\n            def __init__(self, counts: Dict[str, int]) -> None:\n                self.counts = counts\n\n        params = Params({\"counts\": {\"a\": 10, \"b\": 20}})\n        a = A.from_params(params)\n\n        assert isinstance(a.counts, dict)\n        assert not isinstance(a.counts, Params)\n\n    def test_list(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int) -> None:\n                self.size = size\n\n        class C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, items: List[A]) -> None:\n                self.items = items\n\n        params = Params(\n            {\"type\": \"d\", \"items\": [{\"type\": \"b\", \"size\": 1}, {\"type\": \"b\", \"size\": 2}]}\n        )\n        d: D = C.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(d.items, list)\n        assert len(d.items) == 2\n        assert all(isinstance(item, B) for item in d.items)\n        assert d.items[0].size == 1  # type: ignore[attr-defined]\n        assert d.items[1].size == 2  # type: ignore[attr-defined]\n\n    def test_tuple(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int) -> None:\n                self.size = size\n\n        class 
C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, name: str) -> None:\n                self.name = name\n\n        class E(Registrable):\n            pass\n\n        @E.register(\"f\")\n        class F(E):\n            def __init__(self, items: Tuple[A, C]) -> None:\n                self.items = items\n\n        params = Params(\n            {\"type\": \"f\", \"items\": [{\"type\": \"b\", \"size\": 1}, {\"type\": \"d\", \"name\": \"item2\"}]}\n        )\n        f: F = E.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(f.items, tuple)\n        assert len(f.items) == 2\n        assert isinstance(f.items[0], B)\n        assert isinstance(f.items[1], D)\n        assert f.items[0].size == 1\n        assert f.items[1].name == \"item2\"\n\n    def test_set(self):\n        class A(Registrable):\n            def __init__(self, name: str) -> None:\n                self.name = name\n\n            def __eq__(self, other):\n                return self.name == other.name\n\n            def __hash__(self):\n                return hash(self.name)\n\n        @A.register(\"b\")\n        class B(A):\n            pass\n\n        class C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, items: Set[A]) -> None:\n                self.items = items\n\n        params = Params(\n            {\n                \"type\": \"d\",\n                \"items\": [\n                    {\"type\": \"b\", \"name\": \"item1\"},\n                    {\"type\": \"b\", \"name\": \"item2\"},\n                    {\"type\": \"b\", \"name\": \"item2\"},\n                ],\n            }\n        )\n        d: D = C.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(d.items, set)\n        assert len(d.items) == 2\n        assert all(isinstance(item, B) for item in d.items)\n        assert any(item.name == \"item1\" for item 
in d.items)\n        assert any(item.name == \"item2\" for item in d.items)\n\n    def test_kwargs_with_multiple_inheritance(self):\n        # Basic idea: have two identical classes, differing only in the order of their multiple\n        # inheritance, and make sure that passing kwargs up to the super class works in both cases.\n        class A(Registrable):\n            def __init__(self, a: int):\n                self.a = a\n\n        @A.register(\"b1\")  # type: ignore\n        class B1(A, Number):\n            def __init__(self, b: float, **kwargs):\n                super().__init__(**kwargs)\n                self.b = b\n\n        @A.register(\"b2\")  # type: ignore\n        class B2(Number, A):\n            def __init__(self, b: float, **kwargs):\n                super().__init__(**kwargs)\n                self.b = b\n\n        b1 = B1.from_params(Params({\"a\": 4, \"b\": 5}))\n        assert b1.b == 5\n        assert b1.a == 4\n\n        b2 = B2.from_params(Params({\"a\": 4, \"b\": 5}))\n        assert b2.b == 5\n        assert b2.a == 4\n\n    def test_instantiating_with_multiple_inheritance(self):\n        class A(Registrable):\n            def __init__(self, a: int):\n                self.a = a\n\n        @A.register(\"b\")  # type: ignore\n        class B(A, Number):\n            def __init__(self, b: float, **kwargs):\n                super().__init__(**kwargs)\n                self.b = b\n\n        assert not is_base_registrable(B)\n\n        @B.register(\"c\")\n        class C(B):\n            def __init__(self, c: float, **kwargs):\n                super().__init__(**kwargs)\n                self.c = c\n\n        # make sure we can instantiate B directly.\n        b = B.from_params({\"b\": 1.0, \"a\": 1})\n        assert isinstance(b, B)\n\n        # and also make sure we can instantiate subclasses of B.\n        c = B.from_params({\"type\": \"c\", \"c\": 2.0, \"b\": 1.0, \"a\": 1})\n        assert isinstance(c, C)\n\n    def 
test_only_infer_superclass_params_if_unknown(self):\n        class BaseClass(Registrable):\n            def __init__(self):\n                self.x = None\n                self.a = None\n                self.rest = None\n\n        @BaseClass.register(\"a\")\n        class A(BaseClass):\n            def __init__(self, a: int, x: int, **kwargs):\n                super().__init__()\n                self.x = x\n                self.a = a\n                self.rest = kwargs\n\n        @BaseClass.register(\"b\")\n        class B(A):\n            def __init__(self, a: str, x: int = 42, **kwargs):\n                super().__init__(x=x, a=-1, raw_a=a, **kwargs)\n\n        params = Params({\"type\": \"b\", \"a\": \"123\"})\n        # The param `x` should not be required as it has default value in `B`\n        # The correct type of the param `a` should be inferred from `B` as well.\n        instance = BaseClass.from_params(params)\n        assert instance.x == 42\n        assert instance.a == -1\n        assert len(instance.rest) == 1  # type: ignore\n        assert isinstance(instance.rest[\"raw_a\"], str)  # type: ignore\n        assert instance.rest[\"raw_a\"] == \"123\"  # type: ignore\n\n    def test_kwargs_are_passed_to_deeper_superclasses(self):\n        class BaseClass(Registrable):\n            def __init__(self):\n                self.a = None\n                self.b = None\n                self.c = None\n\n        @BaseClass.register(\"a\")\n        class A(BaseClass):\n            def __init__(self, a: str):\n                super().__init__()\n                self.a = a\n\n        @BaseClass.register(\"b\")\n        class B(A):\n            def __init__(self, b: str, **kwargs):\n                super().__init__(**kwargs)\n                self.b = b\n\n        @BaseClass.register(\"c\")\n        class C(B):\n            def __init__(self, c, **kwargs):\n                super().__init__(**kwargs)\n                self.c = c\n\n        params = Params({\"type\": 
\"c\", \"a\": \"a_value\", \"b\": \"b_value\", \"c\": \"c_value\"})\n\n        instance = BaseClass.from_params(params)\n        assert instance.a == \"a_value\"\n        assert instance.b == \"b_value\"\n        assert instance.c == \"c_value\"\n\n    def test_lazy_construction_can_happen_multiple_times(self):\n        test_string = \"this is a test\"\n        extra_string = \"extra string\"\n\n        class ConstructedObject(FromParams):\n            def __init__(self, string: str, extra: str):\n                self.string = string\n                self.extra = extra\n\n        class Testing(FromParams):\n            def __init__(self, lazy_object: Lazy[ConstructedObject]):\n                first_time = lazy_object.construct(extra=extra_string)\n                second_time = lazy_object.construct(extra=extra_string)\n                assert first_time.string == test_string\n                assert first_time.extra == extra_string\n                assert second_time.string == test_string\n                assert second_time.extra == extra_string\n\n        Testing.from_params(Params({\"lazy_object\": {\"string\": test_string}}))\n\n    def test_lazy_and_from_params_can_be_pickled(self):\n        import pickle\n\n        baz = Baz.from_params(Params({\"bar\": {\"foo\": {\"a\": 2}}}))\n        pickle.dumps(baz)\n\n    def test_optional_vs_required_lazy_objects(self):\n        class ConstructedObject(FromParams):\n            def __init__(self, a: int):\n                self.a = a\n\n        class Testing(FromParams):\n            def __init__(\n                self,\n                lazy1: Lazy[ConstructedObject],\n                lazy2: Lazy[ConstructedObject] = Lazy(ConstructedObject),\n                lazy3: Lazy[ConstructedObject] = None,\n                lazy4: Optional[Lazy[ConstructedObject]] = Lazy(ConstructedObject),\n            ) -> None:\n                self.lazy1 = lazy1.construct()\n                self.lazy2 = lazy2.construct(a=2)\n                
self.lazy3 = None if lazy3 is None else lazy3.construct()\n                self.lazy4 = None if lazy4 is None else lazy4.construct(a=1)\n\n        test1 = Testing.from_params(Params({\"lazy1\": {\"a\": 1}}))\n        assert test1.lazy1.a == 1\n        assert test1.lazy2.a == 2\n        assert test1.lazy3 is None\n        assert test1.lazy4 is not None\n\n        test2 = Testing.from_params(Params({\"lazy1\": {\"a\": 1}, \"lazy2\": {\"a\": 3}}))\n        assert test2.lazy1.a == 1\n        assert test2.lazy2.a == 3\n        assert test2.lazy3 is None\n        assert test2.lazy4 is not None\n\n        test3 = Testing.from_params(Params({\"lazy1\": {\"a\": 1}, \"lazy3\": {\"a\": 3}, \"lazy4\": None}))\n        assert test3.lazy1.a == 1\n        assert test3.lazy2.a == 2\n        assert test3.lazy3 is not None\n        assert test3.lazy3.a == 3\n        assert test3.lazy4 is None\n\n        with pytest.raises(ConfigurationError, match='Missing key \"lazy1\" for Testing'):\n            Testing.from_params(Params({}))\n\n    def test_wrapper_kwargs_passed_down(self):\n        class BaseObject:\n            def __init__(self, x: int = 1):\n                self.x = x\n\n        class BaseWrapper(BaseObject, FromParams):\n            def __init__(self, y: int = 2, **kwargs):\n                super().__init__(**kwargs)\n                self.y = y\n\n        o = BaseWrapper.from_params(Params({\"y\": 3}), x=2)\n        assert o.x == 2\n\n    def test_iterable(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int) -> None:\n                self.size = size\n\n        class C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, items: Iterable[A]) -> None:\n                self.items = items\n\n        params = Params(\n            {\"type\": \"d\", \"items\": [{\"type\": \"b\", \"size\": 1}, {\"type\": \"b\", \"size\": 
2}]}\n        )\n        d: D = C.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(d.items, Iterable)\n        items = list(d.items)\n        assert len(items) == 2\n        assert all(isinstance(item, B) for item in items)\n        assert items[0].size == 1  # type: ignore\n        assert items[1].size == 2  # type: ignore\n\n    def test_mapping(self):\n        class A(Registrable):\n            pass\n\n        @A.register(\"b\")\n        class B(A):\n            def __init__(self, size: int) -> None:\n                self.size = size\n\n        class C(Registrable):\n            pass\n\n        @C.register(\"d\")\n        class D(C):\n            def __init__(self, items: Mapping[str, A]) -> None:\n                self.items = items\n\n        params = Params(\n            {\n                \"type\": \"d\",\n                \"items\": {\"first\": {\"type\": \"b\", \"size\": 1}, \"second\": {\"type\": \"b\", \"size\": 2}},\n            }\n        )\n        d: D = C.from_params(params)  # type: ignore[assignment]\n\n        assert isinstance(d.items, Mapping)\n        assert len(d.items) == 2\n        assert all(isinstance(key, str) for key in d.items.keys())\n        assert all(isinstance(value, B) for value in d.items.values())\n        assert d.items[\"first\"].size == 1  # type: ignore\n        assert d.items[\"second\"].size == 2  # type: ignore\n\n    def test_custom_abc_mapping(self):\n        from collections import abc\n\n        class CustomMapping(abc.Mapping):\n            def __init__(self, data: Dict[str, int]):\n                self.data = data\n\n            def __getitem__(self, key):\n                return self.data[key]\n\n            def __iter__(self):\n                return iter(self.data)\n\n            def __len__(self):\n                return len(self.data)\n\n        class ClassWithCustomMapping(FromParams):\n            def __init__(self, mapping: CustomMapping):\n                self.mapping = 
mapping\n\n        o = ClassWithCustomMapping.from_params({\"mapping\": {\"data\": {\"a\": 1}}})\n        assert isinstance(o.mapping, CustomMapping)\n        assert o.mapping[\"a\"] == 1\n\n    def test_extra_parameters_are_not_allowed_when_there_is_no_constructor(self):\n        class A(FromParams):\n            pass\n\n        with pytest.raises(ConfigurationError, match=\"Extra parameters\"):\n            A.from_params(Params({\"some_spurious\": \"key\", \"value\": \"pairs\"}))\n\n    def test_explicit_kwargs_always_passed_to_constructor(self):\n        class Base(FromParams):\n            def __init__(self, lazy: bool = False, x: int = 0) -> None:\n                self.lazy = lazy\n                self.x = x\n\n        class A(Base):\n            def __init__(self, **kwargs) -> None:\n                assert \"lazy\" in kwargs\n                super().__init__(**kwargs)\n\n        A.from_params(Params({\"lazy\": False}))\n\n        class B(Base):\n            def __init__(self, **kwargs) -> None:\n                super().__init__(lazy=True, **kwargs)\n\n        b = B.from_params(Params({}))\n        assert b.lazy is True\n\n    def test_raises_when_there_are_no_implementations(self):\n        class A(Registrable):\n            pass\n\n        with pytest.raises(ConfigurationError, match=\"not in acceptable choices for type\"):\n            A.from_params(\"nonexistent_class\")\n\n        with pytest.raises(ConfigurationError, match='key \"type\" is required'):\n            A.from_params(Params({\"some_spurious\": \"key\", \"value\": \"pairs\"}))\n\n        with pytest.raises(ConfigurationError, match='key \"type\" is required'):\n            A.from_params(Params({}))\n\n        # Some paths through the code are different if there is a constructor here versus not.  
We\n        # don't actually go through this logic anymore, but it's here as a regression test.\n        class B(Registrable):\n            def __init__(self):\n                pass\n\n        with pytest.raises(ConfigurationError, match=\"not in acceptable choices for type\"):\n            B.from_params(\"nonexistent_class\")\n\n        with pytest.raises(ConfigurationError, match='key \"type\" is required'):\n            B.from_params(Params({\"some_spurious\": \"key\", \"value\": \"pairs\"}))\n\n        with pytest.raises(ConfigurationError, match='key \"type\" is required'):\n            B.from_params(Params({}))\n\n    def test_from_params_raises_error_on_wrong_parameter_name_in_optional_union(self):\n        class NestedClass(FromParams):\n            def __init__(self, varname: Optional[str] = None):\n                self.varname = varname\n\n        class WrapperClass(FromParams):\n            def __init__(self, nested_class: Optional[Union[str, NestedClass]] = None):\n                if isinstance(nested_class, str):\n                    nested_class = NestedClass(varname=nested_class)\n                self.nested_class = nested_class\n\n        with pytest.raises(ConfigurationError):\n            WrapperClass.from_params(Params({\"nested_class\": {\"wrong_varname\": \"varstring\"}}))\n\n    def test_from_params_handles_base_class_kwargs(self):\n        class Foo(FromParams):\n            def __init__(self, a: int, b: str = None, **kwargs) -> None:\n                self.a = a\n                self.b = b\n                for key, value in kwargs.items():\n                    setattr(self, key, value)\n\n        foo = Foo.from_params(Params({\"a\": 2, \"b\": \"hi\"}))\n        assert foo.a == 2\n        assert foo.b == \"hi\"\n\n        foo = Foo.from_params(Params({\"a\": 2, \"b\": \"hi\", \"c\": {\"2\": \"3\"}}))\n        assert foo.a == 2\n        assert foo.b == \"hi\"\n        assert foo.c == {\"2\": \"3\"}  # type: ignore[attr-defined]\n\n        class 
Bar(Foo):\n            def __init__(self, a: int, b: str, d: int, **kwargs) -> None:\n                super().__init__(a, b=b, **kwargs)\n                self.d = d\n\n        bar = Bar.from_params(Params({\"a\": 2, \"b\": \"hi\", \"c\": {\"2\": \"3\"}, \"d\": 0}))\n        assert bar.a == 2\n        assert bar.b == \"hi\"\n        assert bar.c == {\"2\": \"3\"}  # type: ignore[attr-defined]\n        assert bar.d == 0\n\n        class Baz(Foo):\n            def __init__(self, a: int, b: Optional[str] = \"a\", **kwargs) -> None:\n                super().__init__(a, b=b, **kwargs)\n\n        baz = Baz.from_params(Params({\"a\": 2, \"b\": None}))\n        assert baz.b is None\n\n        baz = Baz.from_params(Params({\"a\": 2}))\n        assert baz.b == \"a\"\n\n    def test_from_params_base_class_kwargs_crashes_if_params_not_handled(self):\n        class Bar(FromParams):\n            def __init__(self, c: str = None) -> None:\n                self.c = c\n\n        class Foo(Bar):\n            def __init__(self, a: int, b: str = None, **kwargs) -> None:\n                super().__init__(**kwargs)\n                self.a = a\n                self.b = b\n\n        foo = Foo.from_params(Params({\"a\": 2, \"b\": \"hi\", \"c\": \"some value\"}))\n        assert foo.a == 2\n        assert foo.b == \"hi\"\n        assert foo.c == \"some value\"\n\n        with pytest.raises(TypeError, match=\"invalid_key\"):\n            Foo.from_params(Params({\"a\": 2, \"b\": \"hi\", \"invalid_key\": \"some value\"}))\n\n    def test_from_params_handles_kwargs_in_non_from_params_registered_class(self):\n        class Bar(Registrable):\n            pass\n\n        class Baz:\n            def __init__(self, a: int) -> None:\n                self.a = a\n\n        @Bar.register(\"foo\")\n        class Foo(Baz):\n            def __init__(self, a: int, b: str = None, **kwargs) -> None:\n                super().__init__(a)\n                self.b = b\n                for key, value in 
kwargs.items():\n                    setattr(self, key, value)\n\n        foo: Foo = Bar.from_params(Params({\"type\": \"foo\", \"a\": 2, \"b\": \"hi\"}))  # type: ignore[assignment]\n        assert foo.a == 2\n        assert foo.b == \"hi\"\n\n        foo = Bar.from_params(  # type: ignore[assignment]\n            Params({\"type\": \"foo\", \"a\": 2, \"b\": \"hi\", \"c\": {\"2\": \"3\"}})\n        )\n        assert foo.a == 2  # type: ignore[attr-defined]\n        assert foo.b == \"hi\"  # type: ignore[attr-defined]\n        assert foo.c == {\"2\": \"3\"}  # type: ignore[attr-defined]\n\n    def test_from_params_passes_extras_to_non_from_params_registered_class(self):\n        class Bar(Registrable):\n            pass\n\n        class Baz:\n            def __init__(self, a: int, c: Dict[str, str] = None, extra: str = \"idk\") -> None:\n                self.a = a\n                self.c = c\n                self.extra = extra\n\n        @Bar.register(\"foo\")\n        class Foo(Baz):\n            def __init__(self, a: int, b: str = None, **kwargs) -> None:\n                super().__init__(a, **kwargs)\n                self.b = b\n\n        foo: Foo = Bar.from_params(Params({\"type\": \"foo\", \"a\": 2, \"b\": \"hi\"}))  # type: ignore[assignment]\n        assert foo.a == 2\n        assert foo.b == \"hi\"\n        assert foo.c is None\n\n        foo = Bar.from_params(  # type: ignore[assignment]\n            Params({\"type\": \"foo\", \"a\": 2, \"b\": \"hi\", \"c\": {\"2\": \"3\"}}), extra=\"4\"\n        )\n        assert foo.a == 2  # type: ignore[attr-defined]\n        assert foo.b == \"hi\"  # type: ignore[attr-defined]\n        assert foo.c == {\"2\": \"3\"}  # type: ignore[attr-defined]\n        assert foo.extra == \"4\"  # type: ignore[attr-defined]\n\n    def test_from_params_child_has_kwargs_base_implicit_constructor(self):\n        class Foo(FromParams):\n            pass\n\n        class Bar(Foo):\n            def __init__(self, a: int, **kwargs) -> 
None:\n                self.a = a\n\n        bar = Bar.from_params(Params({\"a\": 2}))\n        assert bar.a == 2\n\n    def test_from_params_has_args(self):\n        class Foo(FromParams):\n            def __init__(self, a: int, *args) -> None:\n                self.a = a\n\n        foo = Foo.from_params(Params({\"a\": 2}))\n        assert foo.a == 2\n\n    def test_from_params_with_dataclass(self):\n        @dataclass\n        class Foo(FromParams):\n            x: int\n            y: str\n\n        assert Foo.from_params({\"x\": 1, \"y\": \"2\"}).x == 1\n        with pytest.raises(TypeError):\n            Foo.from_params({\"x\": 1, \"y\": 2})\n\n    def test_to_params(self):\n        @dataclass\n        class Bar(FromParams):\n            z: bool\n\n        @dataclass\n        class Foo(FromParams):\n            x: int\n            bar: Bar\n\n        params_dict = {\"x\": 1, \"bar\": {\"z\": True}}\n        foo = Foo.from_params(deepcopy(params_dict))\n        assert foo.bar.z\n        params = foo.to_params()\n        assert params.as_dict() == params_dict\n\n    def test_to_params_needs_custom_to_params(self):\n        @dataclass\n        class Bar:\n            z: bool\n\n        @dataclass\n        class Foo(FromParams):\n            x: int\n            bar: Bar\n\n        foo = Foo.from_params({\"x\": 1}, bar=Bar(z=True))\n        with pytest.raises(NotImplementedError):\n            foo.to_params()\n\n    @pytest.mark.skipif(sys.version_info < (3, 9), reason=\"requires python 3.9 or higher\")\n    def test_type_hinting_generics_from_std_collections(self):\n        class Item(FromParams):\n            def __init__(self, a: int) -> None:\n                self.a = a\n\n        class ClassWithStdGenerics(FromParams):\n            def __init__(self, x: list[Item], y: dict[str, Item]) -> None:  # type: ignore[syntax]\n                self.x = x\n                self.y = y\n\n        o = ClassWithStdGenerics.from_params({\"x\": [{\"a\": 1}], \"y\": {\"b\": 
{\"a\": 1}}})\n        assert isinstance(o.x, list)\n        assert isinstance(o.x[0], Item)\n        assert isinstance(o.y[\"b\"], Item)\n\n    def test_with_non_from_params_generics(self):\n        T = TypeVar(\"T\")\n\n        class Item(Generic[T]):\n            def __init__(self, x: T):\n                self.x = x\n\n        class ClassWithGenerics(FromParams):\n            def __init__(self, item: Item[T]):\n                self.item = item\n\n        o = ClassWithGenerics.from_params({\"item\": {\"x\": 1}})\n        assert isinstance(o.item, Item)\n\n    @pytest.mark.skipif(sys.version_info < (3, 10), reason=\"requires python 3.10 or higher\")\n    def test_with_union_pipe(self):\n        class Item(FromParams):\n            def __init__(self, a: int) -> None:\n                self.a = a\n\n        class ClassWithUnionType(FromParams):\n            def __init__(self, x: Item | str):  # type: ignore[syntax]\n                self.x = x\n\n        o = ClassWithUnionType.from_params({\"x\": {\"a\": 1}})\n        assert isinstance(o.x, Item)\n\n    def test_from_params_with_function(self):\n        \"\"\"\n        Tests that a function registered as a constructor for a registrable class\n        will properly construct arguments.\n        \"\"\"\n\n        class MyRegistrableClass(Registrable):\n            def __init__(self, a: int, b: int):\n                self.a = a\n                self.b = b\n\n        @dataclass\n        class OptionsClass(FromParams):\n            a: int\n            b: int\n\n        @MyRegistrableClass.register(\"func_constructor\")  # type: ignore\n        def constructor(options: OptionsClass) -> MyRegistrableClass:\n            assert isinstance(options, OptionsClass)\n            return MyRegistrableClass(options.a, options.b)\n\n        MyRegistrableClass.from_params({\"type\": \"func_constructor\", \"options\": {\"a\": 1, \"b\": 2}})\n\n    def test_from_params_passes_no_extra_args_in_factory_construction(self):\n        class 
InnerBase(Registrable):\n            pass\n\n        from typing import Callable\n\n        def innerbase_with_x_factory(cls) -> Callable[..., InnerBase]:\n            def factory(x: int, **kwargs) -> InnerBase:\n                return cls(x=x, **kwargs)\n\n            return factory\n\n        class Inner(InnerBase):\n            def __init__(self, x: int):\n                self.x = x\n\n        InnerBase.register(\"inner\")(innerbase_with_x_factory(Inner))  # type: ignore[arg-type]\n\n        class OuterBase(Registrable):\n            default_implementation = \"default\"\n\n            def __init__(self, y: str, i: InnerBase, c: int):\n                self.i = i\n                self.y = y\n                self.c = c\n\n        OuterBase.register(\"default\")(OuterBase)\n\n        config = {\"c\": 4, \"i\": {\"type\": \"inner\", \"x\": 5}}\n\n        outer_lazy = Lazy(OuterBase, Params(config))\n        outer = outer_lazy.construct(y=\"placeholder\")\n        assert outer.i.x == 5  # type: ignore[attr-defined]\n\n    def test_lazy_from_params_with_version(self):\n        class Gizmo(Registrable):\n            pass\n\n        @Gizmo.register(\"widget\")\n        class WidgetGizmo(Gizmo, DetHashWithVersion):\n            VERSION = \"001\"\n\n            def __init__(self, x: int):\n                self.x = x\n\n            @classmethod\n            def default(cls):\n                return WidgetGizmo(0)\n\n        Gizmo.register(\"default_widget\", \"default\")(WidgetGizmo)\n\n        lazy = Lazy(Gizmo, params=Params({\"type\": \"widget\", \"x\": 1}))\n\n        hash_before = det_hash(lazy)\n        WidgetGizmo.VERSION = \"001\"\n        assert hash_before == det_hash(lazy)\n        WidgetGizmo.VERSION = \"002\"\n        assert hash_before != det_hash(lazy)\n        assert lazy.construct().x == 1  # type: ignore[attr-defined]\n\n        default_lazy = Lazy(\n            Gizmo,\n            params=Params(\n                {\n                    \"type\": 
\"default_widget\",\n                }\n            ),\n        )\n        assert hash_before != det_hash(default_lazy)\n        assert det_hash(lazy) != det_hash(default_lazy)\n        hash_before = det_hash(default_lazy)\n        WidgetGizmo.VERSION = \"003\"\n        assert hash_before != det_hash(default_lazy)\n        assert default_lazy.construct().x == 0  # type: ignore[attr-defined]\n\n    def test_from_params_that_takes_step_directly(self):\n        class FakeStepBase(Step):\n            def run(self, test_input: int) -> int:  # type: ignore\n                return test_input\n\n        @FakeStepBase.register(\"fake_step\")\n        class FakeStep(FakeStepBase):\n            def run(self, test_input: int) -> int:  # type: ignore\n                return test_input * 2\n\n        class FromParamsWithStepInput(FromParams):\n            def __init__(self, fake_step: FakeStepBase):\n                self.fake_step = fake_step\n\n        o = FromParamsWithStepInput.from_params(\n            {\"fake_step\": {\"type\": \"fake_step\", \"test_input\": 1}}\n        )\n        assert isinstance(o.fake_step, FakeStep)\n\n\nclass MyClass(FromParams):\n    def __init__(self, my_int: int, my_bool: bool = False) -> None:\n        self.my_int = my_int\n        self.my_bool = my_bool\n\n\nclass Foo(FromParams):\n    def __init__(self, a: int = 1) -> None:\n        self.a = a\n\n\nclass Bar(FromParams):\n    def __init__(self, foo: Foo) -> None:\n        self.foo = foo\n\n\nclass Baz(FromParams):\n    def __init__(self, bar: Lazy[Bar]) -> None:\n        self._bar = bar\n\n    @property\n    def bar(self):\n        return self._bar.construct()\n"
  },
  {
    "path": "tests/common/params_test.py",
    "content": "import json\nimport os\nimport re\nfrom collections import OrderedDict\n\nimport pytest\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.params import (\n    Params,\n    infer_and_cast,\n    remove_keys_from_params,\n    with_overrides,\n)\nfrom tango.common.testing import TangoTestCase\n\n\nclass TestParams(TangoTestCase):\n    @pytest.mark.parametrize(\"extension\", [\"jsonnet\", \"yaml\"])\n    def test_load_from_file(self, extension):\n        filename = self.FIXTURES_ROOT / \"common\" / f\"params_example.{extension}\"\n        params = Params.from_file(filename)\n        assert params[\"model\"][\"type\"] == \"classifier\"\n\n    def test_replace_none(self):\n        params = Params({\"a\": \"None\", \"b\": [1.0, \"None\", 2], \"c\": {\"d\": \"None\"}})\n        assert params[\"a\"] is None\n        assert params[\"b\"][1] is None\n        assert params[\"c\"][\"d\"] is None\n\n    def test_init_with_different_types(self):\n        assert Params({\"a\": 1, \"b\": 2}) == Params(Params({\"a\": 1, \"b\": 2}))\n\n    def test_bad_unicode_environment_variables(self):\n        filename = self.FIXTURES_ROOT / \"common\" / \"params_example.jsonnet\"\n        os.environ[\"BAD_ENVIRONMENT_VARIABLE\"] = \"\\udce2\"\n        Params.from_file(filename)\n        del os.environ[\"BAD_ENVIRONMENT_VARIABLE\"]\n\n    def test_with_overrides(self):\n        original = {\n            \"foo\": {\"bar\": {\"baz\": 3}, \"x\": 0},\n            \"bar\": [\"a\", \"b\", \"c\"],\n            \"baz\": {\"bar\": 2, \"y\": 3, \"x\": [0, 1, 2]},\n        }\n        overrides = {\n            \"foo.bar\": {\"z\": 2},\n            \"bar.0\": \"d\",\n            \"baz.bar\": 1,\n            \"baz.x\": [0, 0],\n            \"z\": 2,\n        }\n        assert with_overrides(original, overrides) == {\n            \"foo\": {\"bar\": {\"z\": 2}, \"x\": 0},\n            \"bar\": [\"d\", \"b\", \"c\"],\n            \"baz\": {\"bar\": 1, \"y\": 3, \"x\": 
[0, 0]},\n            \"z\": 2,\n        }\n\n    def test_bad_overrides(self):\n        with pytest.raises(ValueError, match=\"contains unused keys\"):\n            with_overrides({\"foo\": [0, 1, 2]}, {\"foo.3\": 4})\n        with pytest.raises(ValueError, match=\"expected list or dict\"):\n            with_overrides({\"foo\": 3}, {\"foo.x\": 2})\n\n    @pytest.mark.parametrize(\"input_type\", [dict, str])\n    def test_overrides(self, input_type):\n        filename = self.FIXTURES_ROOT / \"common\" / \"params_example.jsonnet\"\n        overrides = {\n            \"data_path\": \"train.txt\",\n            \"model.type\": \"new_classifier\",\n            \"model.layers.0.activation\": \"gelu\",\n            \"model.layers.1\": {\"type\": \"classifier\"},\n        }\n        params = Params.from_file(\n            filename, overrides if input_type == dict else json.dumps(overrides)\n        )\n\n        assert params[\"data_path\"] == \"train.txt\"\n        assert params[\"model\"][\"type\"] == \"new_classifier\"\n        assert len(params[\"model\"][\"layers\"]) == 2\n        assert params[\"model\"][\"layers\"][0][\"activation\"] == \"gelu\"\n        assert params[\"model\"][\"layers\"][1][\"type\"] == \"classifier\"\n\n    def test_as_flat_dict(self):\n        params = Params({\"a\": 10, \"b\": {\"c\": 20, \"d\": \"stuff\"}}).as_flat_dict()\n\n        assert params == {\"a\": 10, \"b.c\": 20, \"b.d\": \"stuff\"}\n\n    def test_jsonnet_features(self):\n        config_file = self.TEST_DIR / \"config.jsonnet\"\n        with open(config_file, \"w\") as f:\n            f.write(\n                \"\"\"{\n                            // This example is copied straight from the jsonnet docs\n                            person1: {\n                                name: \"Alice\",\n                                welcome: \"Hello \" + self.name + \"!\",\n                            },\n                            person2: self.person1 { name: \"Bob\" },\n                  
      }\"\"\"\n            )\n\n        params = Params.from_file(config_file)\n\n        alice = params.pop(\"person1\")\n        bob = params.pop(\"person2\")\n\n        assert alice.as_dict() == {\"name\": \"Alice\", \"welcome\": \"Hello Alice!\"}\n        assert bob.as_dict() == {\"name\": \"Bob\", \"welcome\": \"Hello Bob!\"}\n\n        params.assert_empty(\"TestParams\")\n\n    def test_regexes_with_backslashes(self):\n        bad_regex = self.TEST_DIR / \"bad_regex.jsonnet\"\n        good_regex = self.TEST_DIR / \"good_regex.jsonnet\"\n\n        with open(bad_regex, \"w\") as f:\n            f.write(r'{\"myRegex\": \"a\\.b\"}')\n\n        with open(good_regex, \"w\") as f:\n            f.write(r'{\"myRegex\": \"a\\\\.b\"}')\n\n        with pytest.raises(RuntimeError):\n            Params.from_file(bad_regex)\n\n        params = Params.from_file(good_regex)\n        regex = params[\"myRegex\"]\n\n        assert re.match(regex, \"a.b\")\n        assert not re.match(regex, \"a-b\")\n\n        # Check roundtripping\n        good_regex2 = self.TEST_DIR / \"good_regex2.jsonnet\"\n        with open(good_regex2, \"w\") as f:\n            f.write(json.dumps(params.as_dict()))\n        params2 = Params.from_file(good_regex2)\n\n        assert params.as_dict() == params2.as_dict()\n\n    def test_env_var_substitution(self):\n        substitutor = self.TEST_DIR / \"substitutor.jsonnet\"\n        key = \"TEST_ENV_VAR_SUBSTITUTION\"\n\n        assert os.environ.get(key) is None\n\n        with open(substitutor, \"w\") as f:\n            f.write(f'{{\"path\": std.extVar(\"{key}\")}}')\n\n        # raises without environment variable set\n        with pytest.raises(RuntimeError):\n            Params.from_file(substitutor)\n\n        os.environ[key] = \"PERFECT\"\n\n        params = Params.from_file(substitutor)\n        assert params[\"path\"] == \"PERFECT\"\n\n        del os.environ[key]\n\n    def test_as_ordered_dict(self):\n        # keyD > keyC > keyE; keyDA > keyDB; 
Next all other keys alphabetically\n        preference_orders = [[\"keyD\", \"keyC\", \"keyE\"], [\"keyDA\", \"keyDB\"]]\n        params = Params(\n            {\n                \"keyC\": \"valC\",\n                \"keyB\": \"valB\",\n                \"keyA\": \"valA\",\n                \"keyE\": \"valE\",\n                \"keyD\": {\"keyDB\": \"valDB\", \"keyDA\": \"valDA\"},\n            }\n        )\n        ordered_params_dict = params.as_ordered_dict(preference_orders)\n        expected_ordered_params_dict = OrderedDict(\n            {\n                \"keyD\": {\"keyDA\": \"valDA\", \"keyDB\": \"valDB\"},\n                \"keyC\": \"valC\",\n                \"keyE\": \"valE\",\n                \"keyA\": \"valA\",\n                \"keyB\": \"valB\",\n            }\n        )\n        assert json.dumps(ordered_params_dict) == json.dumps(expected_ordered_params_dict)\n\n    def test_to_file(self):\n        # Test to_file works with or without preference orders\n        params_dict = {\"keyA\": \"valA\", \"keyB\": \"valB\"}\n        expected_ordered_params_dict = OrderedDict({\"keyB\": \"valB\", \"keyA\": \"valA\"})\n        params = Params(params_dict)\n        file_path = self.TEST_DIR / \"config.jsonnet\"\n        # check with preference orders\n        params.to_file(file_path, [[\"keyB\", \"keyA\"]])\n        with open(file_path, \"r\") as handle:\n            ordered_params_dict = OrderedDict(json.load(handle))\n        assert json.dumps(expected_ordered_params_dict) == json.dumps(ordered_params_dict)\n        # check without preference orders doesn't give error\n        params.to_file(file_path)\n\n    def test_infer_and_cast(self):\n        lots_of_strings = {\n            \"a\": [\"10\", \"1.3\", \"true\"],\n            \"b\": {\"x\": 10, \"y\": \"20.1\", \"z\": \"other things\"},\n            \"c\": \"just a string\",\n        }\n\n        casted = {\n            \"a\": [10, 1.3, True],\n            \"b\": {\"x\": 10, \"y\": 20.1, \"z\": \"other 
things\"},\n            \"c\": \"just a string\",\n        }\n\n        assert infer_and_cast(lots_of_strings) == casted\n\n        contains_bad_data = {\"x\": 10, \"y\": int}\n        with pytest.raises(ValueError, match=\"cannot infer type\"):\n            infer_and_cast(contains_bad_data)\n\n        params = Params(lots_of_strings)\n\n        assert params.as_dict() == lots_of_strings\n        assert params.as_dict(infer_type_and_cast=True) == casted\n\n    def test_pop_choice(self):\n        choices = [\"my_model\", \"other_model\"]\n        params = Params({\"model\": \"my_model\"})\n        assert params.pop_choice(\"model\", choices) == \"my_model\"\n\n        params = Params({\"model\": \"non_existent_model\"})\n        with pytest.raises(ConfigurationError):\n            params.pop_choice(\"model\", choices)\n\n        params = Params({\"model\": \"module.submodule.ModelName\"})\n        assert params.pop_choice(\"model\", choices) == \"module.submodule.ModelName\"\n\n        params = Params({\"model\": \"module.submodule.ModelName\"})\n        with pytest.raises(ConfigurationError):\n            params.pop_choice(\"model\", choices, allow_class_names=False)\n\n    def test_remove_keys_from_params(self):\n        filename = self.FIXTURES_ROOT / \"common\" / \"params_example.jsonnet\"\n        params = Params.from_file(filename)\n\n        assert params[\"model\"][\"layers\"][0][\"activation\"] == \"relu\"\n        assert params[\"model\"][\"layers\"][1][\"activation\"] == \"softmax\"\n\n        remove_keys_from_params(params, keys=[\"activation\"])\n        assert \"activation\" not in params[\"model\"][\"layers\"][0]\n        assert \"activation\" not in params[\"model\"][\"layers\"][1]\n"
  },
  {
    "path": "tests/common/registrable_test.py",
    "content": "import pytest\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.registrable import Registrable\nfrom tango.common.testing import TangoTestCase\nfrom tango.step import Step\n\n\nclass TestRegistrable(TangoTestCase):\n    def test_basic_functionality(self):\n        class MockBaseClass(Registrable):\n            pass\n\n        assert \"mock-1\" not in MockBaseClass.list_available()\n\n        @MockBaseClass.register(\"mock-1\")\n        class MockSubclass1(MockBaseClass):\n            pass\n\n        assert MockBaseClass in Registrable._registry\n        assert MockBaseClass.by_name(\"mock-1\") == MockSubclass1\n\n        # Verify that registering under a name that already exists\n        # causes a ConfigurationError.\n        with pytest.raises(ConfigurationError):\n\n            @MockBaseClass.register(\"mock-1\")\n            class MockAlternate(MockBaseClass):\n                pass\n\n        # Registering under a name that already exists should overwrite\n        # if exist_ok=True.\n        @MockBaseClass.register(\"mock-1\", exist_ok=True)\n        class MockAlternate2(MockBaseClass):\n            pass\n\n        assert MockBaseClass.by_name(\"mock-1\") == MockAlternate2\n\n        # Test that we get a suggestion when the name is close.\n        with pytest.raises(ConfigurationError) as exc:\n            MockBaseClass.by_name(\"mock_1\")\n        # Check the message outside the `with` block; inside it, this line is never reached.\n        assert \"did you mean 'mock-1'?\" in str(exc.value)\n\n    def test_registering_step_by_reserved_name(self):\n        with pytest.raises(ConfigurationError, match=\"cannot use the name 'ref'\"):\n\n            @Step.register(\"ref\")\n            class BadStep(Step):\n                pass\n\n    def test_search_modules(self):\n        Step.search_modules(\"foo-bar-baz-non-existent\")\n"
  },
  {
    "path": "tests/common/sequences_test.py",
    "content": "import os\nfrom tempfile import TemporaryDirectory\n\nimport pytest\n\nfrom tango.common.sequences import (\n    ConcatenatedSequence,\n    MappedSequence,\n    ShuffledSequence,\n    SlicedSequence,\n    SqliteSparseSequence,\n)\n\n\ndef assert_equal_including_exceptions(expected_fn, actual_fn):\n    try:\n        expected = expected_fn()\n    except Exception as e:\n        with pytest.raises(e.__class__):\n            actual_fn()\n    else:\n        assert expected == actual_fn()\n\n\ndef test_shuffled_sequence():\n    seq = ShuffledSequence(list(range(10)))\n    assert 5 in seq\n    assert len(seq) == 10\n\n\ndef test_sliced_sequence():\n    seq = SlicedSequence(list(range(10)), slice(10))\n    assert len(seq) == 10\n    assert seq[0] == 0\n    assert seq[-1] == 9\n    seq2 = seq[-2:]\n    assert len(seq2) == 2\n\n\ndef test_concatenated_sequence():\n    l1 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n    l2 = ConcatenatedSequence([0, 1], [], [2, 3, 4], [5, 6, 7, 8, 9], [])\n\n    # __len__()\n    assert len(l1) == len(l2)\n\n    # index()\n    for item in l1 + [999]:\n        # no indices\n        assert_equal_including_exceptions(lambda: l1.index(item), lambda: l2.index(item))\n\n        # only start index\n        for index in range(-15, 15):\n            assert_equal_including_exceptions(\n                lambda: l1.index(item, index), lambda: l2.index(item, index)\n            )\n\n        # start and stop index\n        for start_index in range(-15, 15):\n            for end_index in range(-15, 15):\n                assert_equal_including_exceptions(\n                    lambda: l1.index(item, start_index, end_index),\n                    lambda: l2.index(item, start_index, end_index),\n                )\n\n    # __getitem__()\n    for index in range(-15, 15):\n        assert_equal_including_exceptions(lambda: l1[index], lambda: l2[index])\n\n    for start_index in range(-15, 15):\n        for end_index in range(-15, 15):\n            
assert_equal_including_exceptions(\n                lambda: l1[start_index:end_index], lambda: list(l2[start_index:end_index])\n            )\n\n    # count()\n    for item in l1 + [999]:\n        assert_equal_including_exceptions(lambda: l1.count(item), lambda: l2.count(item))\n\n    # __contains__()\n    for item in l1 + [999]:\n        assert_equal_including_exceptions(lambda: item in l1, lambda: item in l2)\n\n\ndef test_sqlite_sparse_sequence():\n    with TemporaryDirectory(prefix=\"test_sparse_sequence-\") as temp_dir:\n        s = SqliteSparseSequence(os.path.join(temp_dir, \"test.sqlite\"))\n        assert len(s) == 0\n        s.extend([])\n        assert len(s) == 0\n        s.append(\"one\")\n        assert len(s) == 1\n        s.extend([\"two\", \"three\"])\n        s.insert(1, \"two\")\n        assert s[1] == \"two\"\n        assert s.count(\"two\") == 2\n        ss = s[1:3]\n        assert list(ss) == [\"two\", \"two\"]\n        del s[1:3]\n        assert len(s) == 2\n        assert s[-1] == \"three\"\n        s.clear()\n        assert len(s) == 0\n\n\ndef test_mapped_sequence():\n    my_very_long_sequence = [\"John\", \"Paul\", \"George\", \"Ringo\"]\n    m = MappedSequence(lambda x: len(x), my_very_long_sequence)\n    assert m[0] == 4\n    assert len(m) == len(my_very_long_sequence)\n    for i in range(len(m)):\n        assert m[i] == m[i:][0]\n"
  },
  {
    "path": "tests/common/util_test.py",
    "content": "import os\nimport time\nfrom pathlib import Path\n\nimport pytest\nfrom flaky import flaky\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.util import (\n    could_be_class_name,\n    find_integrations,\n    find_submodules,\n    resolve_module_name,\n    threaded_generator,\n)\n\n\nclass TestResolveModuleName(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        self._work_dir_restore = os.getcwd()\n        os.chdir(self.TEST_DIR)\n\n    def teardown_method(self):\n        super().teardown_method()\n        os.chdir(self._work_dir_restore)\n\n    def test_with_package_init_file(self):\n        path = Path(\"fake_package/fake_module/__init__.py\")\n        (self.TEST_DIR / path.parent).mkdir(parents=True)\n        open(path, \"w\").close()\n        open(path.parent.parent / \"__init__.py\", \"w\").close()\n        assert resolve_module_name(str(path)) == (\"fake_package.fake_module\", Path(\".\"))\n\n    def test_with_submodule(self):\n        path = Path(\"fake_package/fake_module\")\n        (self.TEST_DIR / path).mkdir(parents=True)\n        open(path / \"__init__.py\", \"w\").close()\n        open(path.parent / \"__init__.py\", \"w\").close()\n        assert resolve_module_name(str(path)) == (\"fake_package.fake_module\", Path(\".\"))\n\n    def test_with_module_in_child_directory(self):\n        path = Path(\"some_dir/fake_module.py\")\n        (self.TEST_DIR / path.parent).mkdir(parents=True)\n        open(path, \"w\").close()\n        assert resolve_module_name(str(path)) == (\"fake_module\", Path(\"./some_dir\"))\n\n\ndef test_find_submodules():\n    assert \"tango.version\" in set(find_submodules())\n    assert \"tango.common.registrable\" in set(find_submodules())\n    assert \"tango.common\" in set(find_submodules(recursive=False))\n    assert \"tango.common.registrable\" not in set(find_submodules(recursive=False))\n    assert \"tango.integrations.torch\" in 
set(find_submodules(\"integrations\"))\n    assert \"tango.integrations.torch\" not in set(find_submodules(exclude={\"tango.integrations*\"}))\n\n\ndef test_find_integrations():\n    integrations = set(find_integrations())\n    assert \"tango.integrations.torch\" in integrations\n    assert \"tango.integrations.torch.format\" not in integrations\n\n\n@pytest.mark.parametrize(\n    \"name, result\",\n    [\n        (\"\", False),\n        (\"foo.Bar\", True),\n        (\"foo.Bar.\", False),\n        (\"1foo.Bar\", False),\n        (\"lib.my_package.MyClass\", True),\n    ],\n)\ndef test_could_be_class_name(name: str, result: bool):\n    assert could_be_class_name(name) is result\n\n\n@flaky(max_runs=3)\ndef test_threaded_generator():\n    def generate_slowly():\n        for i in range(10):\n            yield i\n            time.sleep(0.1)\n\n    start = time.time()\n    for i in threaded_generator(generate_slowly()):\n        time.sleep(0.1)\n    end = time.time()\n\n    assert end - start < 11\n"
  },
  {
    "path": "tests/end_to_end/test_dataset_dict_from_separate_steps.py",
    "content": "from typing import Any, Sequence\n\nfrom tango import Format, JsonFormat, Step\nfrom tango.common import DatasetDict\nfrom tango.common.testing import run_experiment\n\n\n@Step.register(\"train_data\")\nclass TrainData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self) -> Sequence[int]:  # type: ignore\n        return list(range(10))\n\n\n@Step.register(\"val_data\")\nclass ValData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self) -> Sequence[int]:  # type: ignore\n        return list(range(10, 20))\n\n\n@Step.register(\"save_data\")\nclass SaveData(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n\n    def run(self, dataset_dict: DatasetDict) -> Any:  # type: ignore\n        return dataset_dict.splits\n\n\ndef test_experiment():\n    with run_experiment(\n        {\n            \"steps\": {\n                \"train_data\": {\n                    \"type\": \"train_data\",\n                },\n                \"val_data\": {\n                    \"type\": \"val_data\",\n                },\n                \"saved_data\": {\n                    \"type\": \"save_data\",\n                    \"dataset_dict\": {\n                        \"splits\": {\n                            \"train\": {\"type\": \"ref\", \"ref\": \"train_data\"},\n                            \"val\": {\"type\": \"ref\", \"ref\": \"val_data\"},\n                        }\n                    },\n                },\n            }\n        }\n    ) as run_dir:\n        assert (run_dir / \"saved_data\").is_dir()\n        fmt = JsonFormat()\n        data = fmt.read(run_dir / \"saved_data\")\n        assert data[\"train\"] == list(range(10))\n        assert data[\"val\"] == list(range(10, 20))\n"
  },
  {
    "path": "tests/end_to_end/test_lazy_input_with_another_step.py",
    "content": "from dataclasses import dataclass\n\nfrom tango import Format, JsonFormat, Step\nfrom tango.common import FromParams, Lazy\nfrom tango.common.testing import run_experiment\n\n\n@dataclass\nclass Foo(FromParams):\n    number: float\n\n\n@Step.register(\"generate_number\")\nclass GenerateNumberStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n\n    def run(self) -> float:  # type: ignore[override]\n        return 1.0\n\n\n@Step.register(\"lazy_input\")\nclass StepWithLazyInput(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n\n    def run(self, foo: Lazy[Foo]) -> float:  # type: ignore[override]\n        foo = foo.construct()\n        assert isinstance(foo, Foo)\n        assert isinstance(foo.number, float)\n        return foo.number\n\n\ndef test_experiment():\n    with run_experiment(\n        {\n            \"steps\": {\n                \"gen_number\": {\n                    \"type\": \"generate_number\",\n                },\n                \"get_number\": {\n                    \"type\": \"lazy_input\",\n                    \"foo\": {\n                        \"number\": {\n                            \"type\": \"ref\",\n                            \"ref\": \"gen_number\",\n                        }\n                    },\n                },\n            }\n        }\n    ) as run_dir:\n        assert (run_dir / \"get_number\").is_dir()\n        fmt: Format = JsonFormat()\n        data = fmt.read(run_dir / \"get_number\")\n        assert data == 1.0\n"
  },
  {
    "path": "tests/end_to_end/test_multicore_cli.py",
    "content": "import pytest\n\nfrom tango.common.exceptions import CliRunError\nfrom tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.testing import TangoTestCase\n\n\nclass TestExperiment(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        initialize_logging()\n        self.config = {\n            \"steps\": {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 1,\n                    \"fail\": True,\n                },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 1,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This may or may not fail!\",\n                    \"seconds\": 3,\n                    \"fail\": False,\n                },\n            }\n        }\n\n    def teardown_method(self):\n        super().teardown_method()\n        teardown_logging()\n\n    def test_experiment(self, caplog):\n        with pytest.raises(CliRunError):\n            self.run(\n                self.config,\n                multicore=True,\n                parallelism=2,\n            )\n        latest_outputs = self.TEST_DIR / \"workspace\" / \"latest\"\n        num_executed = 0\n        for out in latest_outputs.iterdir():\n            if (out / \"cache-metadata.json\").exists():\n                num_executed += 1\n        assert num_executed == 1\n\n    def test_experiment_with_overrides(self, caplog):\n        import json\n\n        self.run(\n            self.config,\n            multicore=True,\n            parallelism=2,\n            overrides=json.dumps({\"steps.step1.fail\": False}),\n        )\n 
       latest_outputs = self.TEST_DIR / \"workspace\" / \"latest\"\n        num_executed = 0\n        for out in latest_outputs.iterdir():\n            if (out / \"cache-metadata.json\").exists():\n                num_executed += 1\n        assert num_executed == 3\n"
  },
  {
    "path": "tests/end_to_end/test_non_cacheable_into_cacheable_multiple_runs.py",
    "content": "import random\n\nfrom tango import Step\nfrom tango.common.testing import TangoTestCase\n\n\n@Step.register(\"give_me_a_number\")\nclass GiveMeANumber(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self, what_number: int) -> int:  # type: ignore\n        return what_number\n\n\n@Step.register(\"random_int\")\nclass RandomInt(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, lower_bound: int, upper_bound: int) -> int:  # type: ignore\n        return random.randint(lower_bound, upper_bound)\n\n\nclass TestExperiment(TangoTestCase):\n    def test_experiment(self, caplog):\n        config = {\n            \"steps\": {\n                \"a_number\": {\n                    \"type\": \"give_me_a_number\",\n                    \"what_number\": 3,\n                },\n                \"final_number\": {\n                    \"type\": \"random_int\",\n                    \"lower_bound\": 0,\n                    \"upper_bound\": {\"type\": \"ref\", \"ref\": \"a_number\"},\n                },\n            }\n        }\n\n        self.run(config)\n        self.run(config, overrides={\"steps.final_number.lower_bound\": 1})\n"
  },
  {
    "path": "tests/end_to_end/test_registered_runs.py",
    "content": "from tango import Format, JsonFormat, Step\nfrom tango.common.testing import TangoTestCase\n\n\n@Step.register(\"return_a_number\")\nclass ReturnANumber(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n    FORMAT: Format = JsonFormat()\n\n    def run(self, what_number: int) -> int:  # type: ignore\n        return what_number\n\n\nclass TestExperiment(TangoTestCase):\n    def test_experiment_updates_latest_run_output(self, caplog):\n        config = {\n            \"steps\": {\n                \"a_number\": {\n                    \"type\": \"return_a_number\",\n                    \"what_number\": 3,\n                },\n            }\n        }\n\n        self.run(config)\n        assert (self.TEST_DIR / \"workspace\" / \"latest\" / \"a_number\").exists()\n\n        fmt: Format = JsonFormat()\n        data = fmt.read(self.TEST_DIR / \"workspace\" / \"latest\" / \"a_number\")\n        assert data == 3\n\n        config = {\n            \"steps\": {\n                \"a_number\": {\n                    \"type\": \"return_a_number\",\n                    \"what_number\": 5,\n                },\n            }\n        }\n\n        self.run(config)\n        data = fmt.read(self.TEST_DIR / \"workspace\" / \"latest\" / \"a_number\")\n        assert data == 5\n"
  },
  {
    "path": "tests/end_to_end/test_run_single_step.py",
    "content": "from tango.common.testing import TangoTestCase\n\n\nclass TestRunSingleStep(TangoTestCase):\n    def test_run_single_step(self):\n        config = {\n            \"steps\": {\n                \"strA\": {\"type\": \"string\", \"result\": \"Hello, \"},\n                \"strB\": {\"type\": \"string\", \"result\": \"World\"},\n                \"concatenated\": {\n                    \"type\": \"concat_strings\",\n                    \"string1\": {\"type\": \"ref\", \"ref\": \"strA\"},\n                    \"string2\": {\"type\": \"ref\", \"ref\": \"strB\"},\n                },\n            }\n        }\n\n        num_other_files = 2  # out.log and stepinfo.json\n\n        # Regular run contains all step outputs.\n        self.run(config)\n        latest_outputs = self.TEST_DIR / \"workspace\" / \"latest\"\n        assert len(list(latest_outputs.iterdir())) == num_other_files + 3\n\n        # Running a single step with no dependencies should have a single output.\n        self.run(config, step_name=\"strB\")\n        latest_outputs = self.TEST_DIR / \"workspace\" / \"latest\"\n        assert len(list(latest_outputs.iterdir())) == num_other_files + 1\n\n        # Running a single step with one or more dependencies will also run the step's dependencies.\n        self.run(config, step_name=\"concatenated\")\n        latest_outputs = self.TEST_DIR / \"workspace\" / \"latest\"\n        assert len(list(latest_outputs.iterdir())) == num_other_files + 3\n"
  },
  {
    "path": "tests/end_to_end/test_step_indexing.py",
    "content": "from tango.common.testing import TangoTestCase\nfrom tango.workspaces import LocalWorkspace\n\n\nclass TestStepIndexing(TangoTestCase):\n    def test_step_indexing(self):\n        run_name = \"run1\"\n        config = {\n            \"steps\": {\n                \"list\": {\"type\": \"range_step\", \"start\": 0, \"end\": 3},\n                \"added\": {\n                    \"type\": \"add_numbers\",\n                    \"a_number\": 2,\n                    \"b_number\": {\"type\": \"ref\", \"ref\": \"list\", \"key\": 1},\n                },\n            }\n        }\n        self.run(config, name=run_name)\n        workspace = LocalWorkspace(self.TEST_DIR / \"workspace\")\n        result = workspace.step_result_for_run(run_name, \"added\")\n        assert result == 3\n"
  },
  {
    "path": "tests/end_to_end/test_steps_that_fail.py",
    "content": "from collections import Counter\nfrom typing import MutableMapping\n\nimport pytest\n\nfrom tango import Step\nfrom tango.common.exceptions import CliRunError\nfrom tango.common.testing import TangoTestCase\n\nstep_execution_count: MutableMapping[str, int] = Counter()\n\n\n@Step.register(\"step_a\")\nclass StepA(Step):\n    def run(self, what_number: int) -> int:  # type: ignore\n        global step_execution_count\n        step_execution_count[\"a\"] += 1\n        return what_number\n\n\n@Step.register(\"step_b\")\nclass StepB(Step):\n    def run(self, what_number: int) -> int:  # type: ignore\n        global step_execution_count\n        step_execution_count[\"b\"] += 1\n        return what_number\n\n\nstep_should_fail: bool = True\n\n\n@Step.register(\"step_fail\")\nclass StepFail(Step):\n    def run(self, what_number: int) -> int:  # type: ignore\n        global step_execution_count\n        step_execution_count[\"fail\"] += 1\n        global step_should_fail\n        if step_should_fail:\n            raise RuntimeError(\"Step should fail\")\n        else:\n            return what_number\n\n\nclass TestExperiment(TangoTestCase):\n    def test_experiment(self, caplog):\n        global step_should_fail\n        config = {\n            \"steps\": {\n                \"a_number\": {\n                    \"type\": \"step_a\",\n                    \"what_number\": 3,\n                },\n                \"fail_number\": {\n                    \"type\": \"step_fail\",\n                    \"what_number\": {\"type\": \"ref\", \"ref\": \"a_number\"},\n                },\n                \"b_number\": {\n                    \"type\": \"step_b\",\n                    \"what_number\": {\"type\": \"ref\", \"ref\": \"fail_number\"},\n                },\n            }\n        }\n\n        global step_should_fail\n        global step_execution_count\n\n        step_should_fail = True\n        with pytest.raises(CliRunError):\n            self.run(config)\n\n    
    assert step_execution_count[\"a\"] == 1\n        assert step_execution_count[\"fail\"] == 1\n        assert step_execution_count[\"b\"] == 0\n\n        step_should_fail = False\n        self.run(config)\n\n        assert step_execution_count[\"a\"] == 1\n        assert step_execution_count[\"fail\"] == 2\n        assert step_execution_count[\"b\"] == 1\n"
  },
  {
    "path": "tests/end_to_end/test_uncacheable_leaf_steps.py",
    "content": "from tango import Step\nfrom tango.common.testing import TangoTestCase, run_experiment\nfrom tango.common.testing.steps import MakeNumber  # noqa:F401\n\nstored_number = None\n\n\n@Step.register(\"store_number_globally\")\nclass StoreNumberGlobally(Step):\n    DETERMINISTIC = True\n    CACHEABLE = False\n\n    def run(self, number: int) -> None:  # type: ignore\n        global stored_number\n        stored_number = number\n\n\nclass TestExperiment(TangoTestCase):\n    def test_experiment(self, caplog):\n        config = {\n            \"steps\": {\n                \"a_number\": {\n                    \"type\": \"make_number\",\n                    \"what_number\": 3,\n                },\n                \"store_number\": {\n                    \"type\": \"store_number_globally\",\n                    \"number\": {\"type\": \"ref\", \"ref\": \"a_number\"},\n                },\n            }\n        }\n\n        global stored_number\n        assert stored_number is None\n        self.run(config)\n        assert stored_number == 3\n\n\nclass TestExperimentMulticore(TangoTestCase):\n    def test_experiment(self, caplog):\n        file_name = self.TEST_DIR / \"number_file.txt\"\n        assert not file_name.exists()\n        with run_experiment(\n            {\n                \"steps\": {\n                    \"a_number\": {\n                        \"type\": \"make_number\",\n                        \"what_number\": 3,\n                    },\n                    \"store_number\": {\n                        \"type\": \"store_number_in_file\",\n                        \"number\": {\"type\": \"ref\", \"ref\": \"a_number\"},\n                        \"file_name\": str(file_name),\n                    },\n                }\n            },\n            multicore=True,\n        ):\n            with open(file_name) as file_ref:\n                number = file_ref.read()\n\n            assert int(number) == 3\n"
  },
  {
    "path": "tests/executor_test.py",
    "content": "from tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import SleepPrintMaybeFail  # noqa:F401\nfrom tango.executor import Executor\nfrom tango.step import Step\nfrom tango.step_graph import StepGraph\nfrom tango.workspaces import LocalWorkspace\n\n\n@Step.register(\"sum_numbers\")\nclass AdditionStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self, a: int, b: int) -> int:  # type: ignore\n        return a + b\n\n\nclass TestExecutor(TangoTestCase):\n    def test_executor(self):\n        workspace = LocalWorkspace(self.TEST_DIR)\n        step = AdditionStep(a=1, b=2)\n        step_graph = StepGraph.from_params({\"sum\": {\"type\": \"sum_numbers\", \"a\": 1, \"b\": 2}})\n        executor = Executor(workspace)\n        assert len(executor.workspace.step_cache) == 0\n        output = executor.execute_step_graph(step_graph)\n        assert \"sum\" in output.successful\n        assert len(executor.workspace.step_cache) == 1\n        assert executor.workspace.step_cache[step] == 3\n\n    def test_executor_with_failing_steps(self):\n        workspace = LocalWorkspace(self.TEST_DIR)\n        step_graph = StepGraph.from_params(\n            {\n                \"successful_step\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This ran perfectly.\",\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"failing_step\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This should fail.\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                },\n                \"dependent_step\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"failing_step\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n       
     }\n        )\n        executor = Executor(workspace)\n        assert len(executor.workspace.step_cache) == 0\n        output = executor.execute_step_graph(step_graph)\n        assert \"successful_step\" in output.successful\n        assert \"failing_step\" in output.failed\n        assert \"dependent_step\" in output.not_run\n        assert len(executor.workspace.step_cache) == 1\n"
  },
  {
    "path": "tests/executors/__init__.py",
    "content": ""
  },
  {
    "path": "tests/executors/multicore_executor_test.py",
    "content": "import time\n\nimport pytest\n\nfrom tango.common.logging import initialize_logging\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import SleepPrintMaybeFail\nfrom tango.executors.multicore_executor import MulticoreExecutor\nfrom tango.step_graph import StepGraph\nfrom tango.workspaces import LocalWorkspace\n\n\nclass TestMulticoreExecutor(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        initialize_logging()\n\n    def test_simple_execution_in_parallel(self):\n        step_graph = StepGraph(\n            {\n                \"step1\": SleepPrintMaybeFail(string=\"hello\", seconds=5, fail=False),\n                \"step2\": SleepPrintMaybeFail(string=\"hi\", seconds=5, fail=False),\n            }\n        )\n\n        executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2)\n\n        start_time = time.time()\n        executor.execute_step_graph(step_graph)\n        end_time = time.time()\n        time_taken = end_time - start_time\n        assert time_taken < 10  # TODO: will this be flaky?\n\n        assert len(executor.workspace.step_cache) == 2\n\n    def test_more_processes_ready_than_parallelism(self):\n        step_graph = StepGraph(\n            {\n                \"step1\": SleepPrintMaybeFail(string=\"hello\", seconds=5, fail=False),\n                \"step2\": SleepPrintMaybeFail(string=\"hi\", seconds=5, fail=False),\n                \"step3\": SleepPrintMaybeFail(string=\"howdy\", seconds=5, fail=False),\n            }\n        )\n\n        executor = MulticoreExecutor(workspace=LocalWorkspace(self.TEST_DIR), parallelism=2)\n        start_time = time.time()\n        executor.execute_step_graph(step_graph)\n        end_time = time.time()\n        time_taken = end_time - start_time\n        assert 10 < time_taken < 20  # TODO: will this be flaky?\n\n        assert len(executor.workspace.step_cache) == 3\n\n    
@pytest.mark.parametrize(\"parallelism\", [1, 2, 3])\n    def test_failing_step_no_downstream_task(self, parallelism):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This is going to fail!\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                },\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=parallelism,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert len(executor.workspace.step_cache) == 2\n\n    @pytest.mark.parametrize(\"parallelism\", [1, 2, 3])\n    def test_failing_step_with_downstream_task(self, parallelism):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    
\"string\": \"This is going to fail!\",\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=parallelism,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert len(executor.workspace.step_cache) == 1\n\n    @pytest.mark.parametrize(\"parallelism\", [1, 2, 3])\n    def test_failing_step_with_further_downstream_task(self, parallelism):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step2\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=parallelism,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert len(executor.workspace.step_cache) == 0\n\n    def test_uncacheable_failing_step_no_downstream_task(self):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 0,\n                    \"fail\": False,\n        
        },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This is going to fail!\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                    \"cache_results\": False,\n                },\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=2,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert len(executor.workspace.step_cache) == 2\n\n    def test_uncacheable_failing_step_with_downstream_task(self):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"string_to_pass_down\",\n                    \"seconds\": 0,\n                    \"fail\": True,\n                    \"cache_results\": False,\n                },\n                \"step2\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": {\"type\": \"ref\", \"ref\": \"step1\"},\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n                \"step3\": {\n                    \"type\": \"sleep-print-maybe-fail\",\n                    \"string\": \"This is going to fail!\",\n                    \"seconds\": 0,\n                    \"fail\": False,\n                },\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=2,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert 
len(executor.workspace.step_cache) == 1\n\n    @pytest.mark.parametrize(\"parallelism\", [1, 2, 3])\n    def test_steps_with_their_own_multiprocessing(self, parallelism):\n        step_graph = StepGraph.from_params(\n            {\n                \"step1\": {\"type\": \"multiprocessing_step\", \"num_proc\": 2},\n                \"step2\": {\"type\": \"multiprocessing_step\", \"num_proc\": 3},\n                \"step3\": {\"type\": \"multiprocessing_step\", \"num_proc\": 1},\n            }\n        )\n\n        executor = MulticoreExecutor(\n            workspace=LocalWorkspace(self.TEST_DIR),\n            parallelism=parallelism,\n        )\n\n        executor.execute_step_graph(step_graph)\n        assert len(executor.workspace.step_cache) == 3\n"
  },
  {
    "path": "tests/format_test.py",
    "content": "from typing import Dict, Iterable, Optional\n\nimport pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.format import _OPEN_FUNCTIONS, DillFormat, JsonFormat, TextFormat\n\n\nclass TestFormat(TangoTestCase):\n    @pytest.mark.parametrize(\"compress\", _OPEN_FUNCTIONS.keys())\n    def test_dill_format(self, compress: Optional[str]):\n        artifact = \"Hello, World!\"\n        format = DillFormat[str](compress)\n        format.write(artifact, self.TEST_DIR)\n        assert format.read(self.TEST_DIR) == artifact\n        assert \"compress\" in format.to_params()\n\n    @pytest.mark.parametrize(\"compress\", _OPEN_FUNCTIONS.keys())\n    def test_iterable_dill_format(self, compress: Optional[str]):\n        r = (x + 1 for x in range(10))\n        format = DillFormat[Iterable[int]](compress)\n        format.write(r, self.TEST_DIR)\n        r2 = format.read(self.TEST_DIR)\n        assert [x + 1 for x in range(10)] == list(r2)\n        assert \"compress\" in format.to_params()\n\n    @pytest.mark.parametrize(\"compress\", _OPEN_FUNCTIONS.keys())\n    def test_json_format(self, compress: Optional[str]):\n        artifact = {\"Hello, World!\": \"Hi!\"}\n        format = JsonFormat[Dict[str, str]](compress)\n        format.write(artifact, self.TEST_DIR)\n        assert format.read(self.TEST_DIR) == artifact\n        assert \"compress\" in format.to_params()\n\n    @pytest.mark.parametrize(\"compress\", _OPEN_FUNCTIONS.keys())\n    def test_iterable_json_format(self, compress: Optional[str]):\n        r = (x + 1 for x in range(10))\n        format = JsonFormat[Iterable[int]](compress)\n        format.write(r, self.TEST_DIR)\n        r2 = format.read(self.TEST_DIR)\n        assert [x + 1 for x in range(10)] == list(r2)\n        assert \"compress\" in format.to_params()\n\n    def test_iterable_text_format(self):\n        numbers = [\"ichi\", \"ni\", \"san\"]\n        l1 = iter(numbers)\n        format = TextFormat()\n        
format.write(l1, self.TEST_DIR)\n        l2 = format.read(self.TEST_DIR)\n        assert list(l2) == numbers\n"
  },
  {
    "path": "tests/integrations/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/beaker/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/beaker/conftest.py",
    "content": "import os\nimport uuid\nfrom pathlib import Path\nfrom typing import Generator\n\nimport pytest\nfrom beaker import Beaker\n\nfrom tango.common import util\nfrom tango.integrations.beaker.common import Constants\nfrom tango.step import Step\n\n\n@pytest.fixture(autouse=True)\ndef patched_cache_dir(tmp_path, monkeypatch) -> Path:\n    monkeypatch.setattr(util, \"tango_cache_dir\", lambda: tmp_path)\n    return tmp_path\n\n\n@pytest.fixture(autouse=True)\ndef patched_unique_id_suffix(monkeypatch) -> str:\n    UNIQUE_ID_SUFFIX = os.environ.get(\"GITHUB_SHA\", \"\")[:6] + \"-\" + str(uuid.uuid1())[:6]\n    monkeypatch.setattr(Step, \"_UNIQUE_ID_SUFFIX\", UNIQUE_ID_SUFFIX)\n    return UNIQUE_ID_SUFFIX\n\n\n@pytest.fixture(autouse=True)\ndef patched_constants_prefix(monkeypatch) -> str:\n    PREFIX = os.environ.get(\"GITHUB_SHA\", \"A\")[:6] + \"-\" + str(uuid.uuid1())[:6] + \"-\"\n    monkeypatch.setattr(Constants, \"STEP_ARTIFACT_PREFIX\", \"tango-step-\" + PREFIX)\n    monkeypatch.setattr(Constants, \"RUN_ARTIFACT_PREFIX\", \"tango-run-\" + PREFIX)\n    monkeypatch.setattr(Constants, \"ENTRYPOINT_DATASET_PREFIX\", \"tango-entrypoint-\" + PREFIX)\n    monkeypatch.setattr(Constants, \"STEP_GRAPH_ARTIFACT_PREFIX\", \"tango-step-graph-\" + PREFIX)\n    monkeypatch.setattr(Constants, \"STEP_EXPERIMENT_PREFIX\", \"tango-step-\" + PREFIX)\n    return PREFIX\n\n\n@pytest.fixture\ndef beaker_workspace_name() -> str:\n    return \"ai2/tango-beaker-testing\"\n\n\n@pytest.fixture\ndef beaker_workspace(\n    beaker_workspace_name: str, patched_unique_id_suffix: str, patched_constants_prefix: str\n) -> Generator[str, None, None]:\n    beaker = Beaker.from_env(default_workspace=beaker_workspace_name)\n    yield beaker_workspace_name\n    # Remove experiments.\n    #  for experiment in beaker.workspace.experiments(match=patched_constants_prefix):\n    #      beaker.experiment.delete(experiment)\n    # Remove datasets.\n    for dataset in 
beaker.workspace.datasets(match=patched_unique_id_suffix):\n        beaker.dataset.delete(dataset)\n    for dataset in beaker.workspace.datasets(match=patched_constants_prefix):\n        beaker.dataset.delete(dataset)\n"
  },
  {
    "path": "tests/integrations/beaker/executor_test.py",
    "content": "import petname\nimport pytest\nfrom beaker import DataMount\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.testing import run_experiment\nfrom tango.executor import Executor\nfrom tango.integrations.beaker.executor import BeakerExecutor\nfrom tango.integrations.beaker.workspace import BeakerWorkspace\nfrom tango.settings import TangoGlobalSettings\nfrom tango.workspaces import default_workspace\n\n\ndef test_from_params(beaker_workspace_name: str):\n    executor = Executor.from_params(\n        dict(\n            type=\"beaker\",\n            beaker_workspace=beaker_workspace_name,\n            beaker_image=\"ai2/conda\",\n            github_token=\"FAKE_TOKEN\",\n            datasets=[{\"source\": {\"beaker\": \"some-dataset\"}, \"mount_path\": \"/input\"}],\n            budget=\"ai2/allennlp\",\n        ),\n        workspace=BeakerWorkspace(workspace=beaker_workspace_name),\n        clusters=[\"fake-cluster\"],\n    )\n    assert isinstance(executor, BeakerExecutor)\n    assert executor.datasets is not None\n    assert len(executor.datasets) == 1\n    assert isinstance(executor.datasets[0], DataMount)\n    assert executor.datasets[0].source.beaker == \"some-dataset\"\n\n\ndef test_init_with_mem_workspace(beaker_workspace_name: str):\n    with pytest.raises(ConfigurationError, match=\"MemoryWorkspace\"):\n        BeakerExecutor(\n            workspace=default_workspace,\n            beaker_workspace=beaker_workspace_name,\n            beaker_image=\"ai2/conda\",\n            github_token=\"FAKE_TOKEN\",\n            clusters=[\"fake-cluster\"],\n            budget=\"ai2/allennlp\",\n        )\n\n\n@pytest.fixture\ndef settings(beaker_workspace_name: str) -> TangoGlobalSettings:\n    return TangoGlobalSettings(\n        workspace={\"type\": \"beaker\", \"beaker_workspace\": beaker_workspace_name},\n        executor={\n            \"type\": \"beaker\",\n            \"beaker_workspace\": beaker_workspace_name,\n           
 \"install_cmd\": \"pip install .[beaker]\",\n            \"clusters\": [\"ai2/allennlp-cirrascale\", \"ai2/general-cirrascale\"],\n            \"budget\": \"ai2/allennlp\",\n        },\n    )\n\n\ndef test_beaker_executor(\n    settings: TangoGlobalSettings, beaker_workspace_name: str, patched_unique_id_suffix: str\n):\n    run_name = petname.generate()\n    with run_experiment(\n        {\"steps\": {\"hello\": {\"type\": \"string\", \"result\": \"Hello, World!\"}}},\n        settings=settings,\n        workspace_url=f\"beaker://{beaker_workspace_name}\",\n        name=run_name,\n        multicore=None,\n    ):\n        workspace = BeakerWorkspace(workspace=beaker_workspace_name)\n        assert \"hello\" in workspace.registered_run(run_name).steps\n"
  },
  {
    "path": "tests/integrations/beaker/step_cache_test.py",
    "content": "from tango.common.testing.steps import FloatStep\nfrom tango.integrations.beaker.step_cache import BeakerStepCache\n\n\ndef test_step_cache(beaker_workspace: str):\n    cache = BeakerStepCache(beaker_workspace=beaker_workspace)\n\n    step = FloatStep(result=1.0)\n    cache[step] = 1.0\n    assert step in cache\n    assert len(cache) == 1\n    assert FloatStep(result=2.0) not in cache\n    assert cache[step] == 1.0\n"
  },
  {
    "path": "tests/integrations/beaker/workspace_test.py",
    "content": "import pytest\nfrom beaker import DatasetNotFound\n\nfrom tango.common.testing.steps import FloatStep\nfrom tango.integrations.beaker.workspace import BeakerWorkspace\nfrom tango.step_info import StepState\nfrom tango.workspace import Workspace\n\n\ndef test_from_url(beaker_workspace: str):\n    workspace = Workspace.from_url(f\"beaker://{beaker_workspace}\")\n    assert isinstance(workspace, BeakerWorkspace)\n\n\ndef test_direct_usage(beaker_workspace: str):\n    workspace = BeakerWorkspace(beaker_workspace)\n\n    step = FloatStep(step_name=\"float\", result=1.0)\n    run = workspace.register_run([step])\n    assert run.name in workspace.registered_runs()\n\n    assert workspace.step_info(step).state == StepState.INCOMPLETE\n    workspace.step_starting(step)\n    assert workspace.step_info(step).state == StepState.RUNNING\n    workspace.step_finished(step, 1.0)\n    assert workspace.step_info(step).state == StepState.COMPLETED\n    assert workspace.step_result_for_run(run.name, \"float\") == 1.0\n\n\ndef test_remove_step(beaker_workspace: str):\n    beaker_workspace = \"ai2/tango_remove_cache_test\"\n    workspace = BeakerWorkspace(beaker_workspace)\n    step = FloatStep(step_name=\"float\", result=1.0)\n\n    workspace.step_starting(step)\n    workspace.step_finished(step, 1.0)\n\n    step_info = workspace.step_info(step)\n    dataset_name = workspace.Constants.step_artifact_name(step_info)\n    cache = workspace.step_cache\n\n    assert workspace.beaker.dataset.get(dataset_name) is not None\n    assert step in cache\n\n    workspace.remove_step(step.unique_id)\n    cache = workspace.step_cache\n    dataset_name = workspace.Constants.step_artifact_name(step_info)\n\n    with pytest.raises(DatasetNotFound):\n        workspace.beaker.dataset.get(dataset_name)\n    assert step not in cache\n"
  },
  {
    "path": "tests/integrations/datasets/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/datasets/dataset_test.py",
    "content": "import datasets\n\nfrom tango.common.sequences import MappedSequence\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.datasets import (\n    DatasetRemixStep,\n    DatasetsFormat,\n    LoadDataset,\n    convert_to_tango_dataset_dict,\n)\nfrom tango.step import Step\n\n\nclass TestDatasets(TangoTestCase):\n    def test_from_params_and_convert_to_tango_dataset_dict(self):\n        step: LoadDataset = Step.from_params(  # type: ignore[assignment]\n            {\n                \"type\": \"datasets::load\",\n                \"path\": \"lhoestq/test\",\n                \"cache_dir\": str(self.TEST_DIR / \"cache\"),\n            }\n        )\n        hf_dataset_dict = step.result()\n        assert \"train\" in hf_dataset_dict\n        dataset_dict = convert_to_tango_dataset_dict(hf_dataset_dict)\n        assert \"train\" in dataset_dict.splits\n\n    def test_convert_to_tango_iterable_dataset_dict(self):\n        def data_gen():\n            for x in range(100):\n                yield {\"x\": x}\n\n        hf_dataset_dict = datasets.IterableDatasetDict(\n            train=datasets.iterable_dataset.IterableDataset.from_generator(data_gen)\n        )\n        dataset_dict1 = convert_to_tango_dataset_dict(hf_dataset_dict)\n        assert \"train\" in dataset_dict1.splits\n\n    def test_load_concatenate_and_interleave(self):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"datasets\" / \"config.json\",\n            overrides={\n                \"steps.train_data.cache_dir\": str(self.TEST_DIR / \"cache\"),\n                \"steps.dev_data.cache_dir\": str(self.TEST_DIR / \"cache\"),\n            },\n        )\n        assert (result_dir / \"train_data\" / \"data\").is_dir()\n        dataset = DatasetsFormat().read(result_dir / \"train_data\")\n        assert len(dataset) == 2\n\n\ndef test_mapped_sequence_of_dataset():\n    ds = datasets.load_dataset(\"piqa\", split=\"validation\")\n    
mapped_ds = MappedSequence(lambda x: x[\"goal\"], ds)  # type: ignore[arg-type]\n    assert len(ds) == len(mapped_ds)  # type: ignore[arg-type]\n    assert ds[0][\"goal\"] == mapped_ds[0]  # type: ignore[index]\n    assert ds[0][\"goal\"] == mapped_ds[:10][0]  # type: ignore[index]\n\n\ndef test_datasets_dataset_remix():\n    dataset_dict = datasets.load_dataset(\"lhoestq/test\")\n    step = DatasetRemixStep()\n    result = step.run(\n        input=dataset_dict,  # type: ignore[arg-type]\n        new_splits={\n            \"all\": \"train + validation\",\n            \"crossval_train\": \"train[:1] + validation[1:]\",\n            \"crossval_test\": \"train[1:] + validation[:1]\",\n        },\n    )\n    assert len(result[\"all\"]) == len(dataset_dict[\"train\"]) + len(dataset_dict[\"validation\"])  # type: ignore\n    assert len(result[\"crossval_train\"]) == 3\n    assert len(result[\"crossval_test\"]) == 2\n"
  },
  {
    "path": "tests/integrations/fairscale/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/fairscale/train_test.py",
    "content": "from typing import Any, Dict\n\nimport pytest\nimport torch\n\nfrom tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.testing import TangoTestCase\n\n\nclass TestFairScaleTrain(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        initialize_logging(log_level=\"info\")\n\n    def teardown_method(self):\n        teardown_logging()\n\n    @pytest.mark.parametrize(\n        \"fsdp\",\n        (\n            pytest.param(\n                True,\n                id=\"fsdp=True\",\n                marks=[\n                    pytest.mark.gpu,\n                    pytest.mark.skipif(\n                        torch.cuda.device_count() < 2, reason=\"Requires CUDA devices\"\n                    ),\n                ],\n            ),\n            pytest.param(False, id=\"fsdp=False\"),\n        ),\n    )\n    @pytest.mark.parametrize(\n        \"activation_checkpoint\",\n        (\n            pytest.param(True, id=\"checkpointing=True\"),\n            pytest.param(False, id=\"checkpointing=False\"),\n        ),\n    )\n    @pytest.mark.parametrize(\n        \"amp\",\n        (\n            pytest.param(\n                True,\n                id=\"amp=True\",\n                marks=[\n                    pytest.mark.gpu,\n                    pytest.mark.skipif(\n                        torch.cuda.device_count() < 2, reason=\"Requires CUDA devices\"\n                    ),\n                ],\n            ),\n            pytest.param(False, id=\"amp=False\"),\n        ),\n    )\n    def test_train_tiny_gpt2(self, fsdp: bool, activation_checkpoint: bool, amp: bool):\n        overrides: Dict[str, Any] = {\n            \"steps.trained_model.model.activation_checkpointing\": activation_checkpoint,\n        }\n        training_engine: Dict[str, Any] = {\n            \"amp\": amp,\n            \"optimizer\": {\n                \"type\": \"torch::AdamW\",\n                \"lr\": 0.005,\n               
 \"betas\": [0.9, 0.95],\n                \"eps\": 1e-6,\n            },\n        }\n        if fsdp:\n            training_engine[\"type\"] = \"fairscale\"\n            fsdp_config = {\"reshard_after_forward\": True, \"mixed_precision\": amp}\n            training_engine[\"fsdp_config\"] = fsdp_config\n            overrides[\"steps.trained_model.model.fsdp_config\"] = fsdp_config\n        else:\n            training_engine[\"type\"] = \"torch\"\n            overrides[\"steps.trained_model.model.fsdp_config\"] = None\n        overrides[\"steps.trained_model.training_engine\"] = training_engine\n        run_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"fairscale\" / \"config.jsonnet\",\n            include_package=[\"test_fixtures.integrations.fairscale.components\"],\n            overrides=overrides,\n        )\n        assert (run_dir / \"trained_model\").is_dir()\n"
  },
  {
    "path": "tests/integrations/flax/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/flax/data_test.py",
    "content": "from typing import Dict\n\nfrom transformers import AutoTokenizer\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.flax import DataLoader, FlaxDataLoader\nfrom tango.integrations.flax.util import get_PRNGkey\nfrom tango.step import Step\n\n\nclass TestDataStep(TangoTestCase):\n    def test_dataloader(self) -> None:\n        assert \"flax::dataloader\" in DataLoader.list_available()\n\n    def test_sample_data(self) -> None:\n        step = Step.from_params(  # type: ignore[assignment]\n            {\n                \"type\": \"datasets::load\",\n                \"path\": \"lhoestq/demo1\",\n                \"split\": \"train\",\n                \"cache_dir\": str(self.TEST_DIR / \"cache\"),\n            }\n        )\n\n        dataset = step.result()\n        tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n        column_names = dataset.column_names\n        dataset = dataset.map(\n            lambda e: tokenizer(e[\"review\"], truncation=True, padding=\"max_length\")\n        )\n        dataset = dataset.remove_columns(column_names)\n        data = FlaxDataLoader(dataset, batch_size=16)\n        rng = get_PRNGkey()\n        for batch in data(rng, do_distributed=False):\n            assert isinstance(batch, Dict)\n"
  },
  {
    "path": "tests/integrations/flax/format_test.py",
    "content": "import os\n\nfrom tango import Format\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.flax.format import FlaxFormat\n\n\nclass TestFlaxFormat(TangoTestCase):\n    def test_read_write(self):\n        flax_format: FlaxFormat = Format.by_name(\"flax\")()  # type: ignore[assignment]\n        flax_format.write({\"a\": 1}, self.TEST_DIR)\n        assert os.path.exists(self.TEST_DIR / \"checkpoint_0\")\n        data = flax_format.read(self.TEST_DIR)\n        assert data == {\"a\": 1}\n"
  },
  {
    "path": "tests/integrations/flax/optim_test.py",
    "content": "from tango.integrations.flax.optim import LRScheduler, Optimizer\n\n\ndef test_all_optimizers_registered():\n    assert \"optax::adafactor\" in Optimizer.list_available()\n\n\ndef test_all_lr_schedulers_registered():\n    assert \"optax::constant_schedule\" in LRScheduler.list_available()\n"
  },
  {
    "path": "tests/integrations/flax/train_test.py",
    "content": "from tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.testing import TangoTestCase\n\n\nclass TestTrainStep(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        initialize_logging(enable_cli_logs=True)\n\n    def teardown_method(self):\n        super().teardown_method()\n        teardown_logging()\n\n    def test_trainer(self):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"flax\" / \"config.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.flax\",\n            ],\n        )\n        assert (\n            result_dir\n            / \"train\"\n            / \"work\"\n            / \"checkpoint_state_latest\"\n            / \"checkpoint_0\"\n            / \"checkpoint\"\n        ).is_file()\n"
  },
  {
    "path": "tests/integrations/gs/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/gs/step_cache_test.py",
    "content": "import os\n\nimport pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import FloatStep\nfrom tango.integrations.gs.common import empty_bucket_folder\nfrom tango.integrations.gs.step_cache import GSStepCache\n\nGS_BUCKET_NAME = os.environ.get(\"GS_BUCKET_NAME\", \"allennlp-tango-bucket\")\nGS_SUBFOLDER = f\"{GS_BUCKET_NAME}/my-workspaces/workspace1\"\n\n\nclass TestGSStepCache(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        empty_bucket_folder(GS_BUCKET_NAME)\n        empty_bucket_folder(GS_SUBFOLDER)\n\n    def teardown_method(self):\n        super().teardown_method()\n\n    @pytest.mark.parametrize(\"gs_path\", [GS_BUCKET_NAME, GS_SUBFOLDER])\n    def test_step_cache(self, gs_path):\n        cache = GSStepCache(folder_name=gs_path)\n        step = FloatStep(result=1.0)\n        cache[step] = 1.0\n        assert step in cache\n        assert len(cache) == 1\n        assert FloatStep(result=2.0) not in cache\n        assert cache[step] == 1.0\n"
  },
  {
    "path": "tests/integrations/gs/workspace_test.py",
    "content": "import os\n\nimport pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import FloatStep\nfrom tango.integrations.gs.common import empty_bucket_folder, empty_datastore\nfrom tango.integrations.gs.workspace import GSWorkspace\nfrom tango.step_info import StepState\nfrom tango.workspace import Workspace\n\nGS_BUCKET_NAME = os.environ.get(\"GS_BUCKET_NAME\", \"allennlp-tango-bucket\")\nGS_SUBFOLDER = f\"{GS_BUCKET_NAME}/my-workspaces/workspace1\"\n\n\nclass TestGSWorkspace(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        empty_bucket_folder(GS_BUCKET_NAME)\n        empty_bucket_folder(GS_SUBFOLDER)\n        empty_datastore(GS_BUCKET_NAME)\n        empty_datastore(GS_SUBFOLDER)\n\n    def teardown_method(self):\n        super().teardown_method()\n\n    @pytest.mark.parametrize(\"gs_path\", [GS_BUCKET_NAME, GS_SUBFOLDER])\n    def test_from_url(self, gs_path: str):\n        workspace = Workspace.from_url(f\"gs://{gs_path}\")\n        assert isinstance(workspace, GSWorkspace)\n\n    @pytest.mark.parametrize(\"gs_path\", [GS_BUCKET_NAME, GS_SUBFOLDER])\n    def test_from_params(self, gs_path: str):\n        workspace = Workspace.from_params({\"type\": \"gs\", \"workspace\": gs_path})\n        assert isinstance(workspace, GSWorkspace)\n\n    @pytest.mark.parametrize(\"gs_path\", [GS_BUCKET_NAME, GS_SUBFOLDER])\n    def test_direct_usage(self, gs_path: str):\n        workspace = GSWorkspace(gs_path)\n\n        step = FloatStep(step_name=\"float\", result=1.0)\n        run = workspace.register_run([step])\n        assert run.name in workspace.registered_runs()\n\n        assert workspace.step_info(step).state == StepState.INCOMPLETE\n        workspace.step_starting(step)\n        assert workspace.step_info(step).state == StepState.RUNNING\n        workspace.step_finished(step, 1.0)\n        assert workspace.step_info(step).state == StepState.COMPLETED\n        assert 
workspace.step_result_for_run(run.name, \"float\") == 1.0\n\n    def test_remove_step(self):\n        workspace = GSWorkspace(GS_BUCKET_NAME)\n        step = FloatStep(step_name=\"float\", result=1.0)\n        step_info = workspace.step_info(step)\n\n        workspace.step_starting(step)\n        workspace.step_finished(step, 1.0)\n        bucket_artifact = workspace.Constants.step_artifact_name(step_info)\n        ds_entity = workspace._ds.get(key=workspace._ds.key(\"stepinfo\", step_info.unique_id))\n        cache = workspace.step_cache\n\n        assert workspace.client.artifacts(prefix=bucket_artifact) is not None\n        assert ds_entity is not None\n        assert step in cache\n\n        workspace.remove_step(step.unique_id)\n        cache = workspace.step_cache\n\n        ds_entity = workspace._ds.get(key=workspace._ds.key(\"stepinfo\", step_info.unique_id))\n\n        with pytest.raises(Exception) as excinfo:\n            workspace.client.artifacts(prefix=bucket_artifact)\n\n        assert \"KeyError\" in str(excinfo)\n        assert ds_entity is None\n        assert step not in cache\n"
  },
  {
    "path": "tests/integrations/torch/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/torch/data_test.py",
    "content": "import torch\n\nfrom tango.integrations.torch.data import DataLoader, Sampler\n\n\ndef test_dataloader_from_params():\n    DataLoader.from_params(\n        {\n            \"dataset\": list(range(10)),\n            \"batch_size\": 2,\n            \"shuffle\": True,\n        }\n    )\n\n\ndef test_samplers_registered():\n    assert \"torch::SequentialSampler\" in Sampler.list_available()\n\n\ndef test_dataloader_from_params_with_sampler():\n    dataloader = DataLoader.from_params(\n        {\n            \"dataset\": list(range(10)),\n            \"sampler\": {\n                \"type\": \"torch::RandomSampler\",\n                \"replacement\": True,\n            },\n        }\n    )\n    assert isinstance(dataloader.sampler, torch.utils.data.RandomSampler)\n    assert dataloader.sampler.replacement\n\n\ndef test_dataloader_from_params_with_batch_sampler():\n    dataloader = DataLoader.from_params(\n        {\n            \"dataset\": list(range(10)),\n            \"sampler\": {\n                \"type\": \"torch::BatchSampler\",\n                \"sampler\": {\n                    \"type\": \"torch::RandomSampler\",\n                },\n                \"batch_size\": 2,\n                \"drop_last\": True,\n            },\n        }\n    )\n    assert isinstance(dataloader.sampler, torch.utils.data.BatchSampler)\n"
  },
  {
    "path": "tests/integrations/torch/det_hash_test.py",
    "content": "import numpy\nimport torch\n\nfrom tango.common import det_hash\n\n\ndef test_numpy_det_hash():\n    a1 = numpy.array([[1, 2], [3, 4]], order=\"C\")\n    a2 = numpy.array([[1, 2], [3, 4]], order=\"K\")\n    assert det_hash(a1) == det_hash(a2)\n\n\ndef test_torch_det_hash():\n    a1 = numpy.array([[1, 2], [3, 4]], order=\"C\")\n    a2 = numpy.array([[1, 2], [3, 4]], order=\"K\")\n    a1 = torch.tensor(a1)\n    a2 = torch.tensor(a2)\n    assert det_hash(a1) == det_hash(a2)\n"
  },
  {
    "path": "tests/integrations/torch/eval_test.py",
    "content": "from tango.common.testing import TangoTestCase\n\n\nclass TestEvalStep(TangoTestCase):\n    def test_basic_eval(self):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations/torch/eval.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n        )\n        assert (result_dir / \"eval\" / \"data.json\").is_file()\n"
  },
  {
    "path": "tests/integrations/torch/format_test.py",
    "content": "import os\n\nfrom tango import Format\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.torch.format import TorchFormat\n\n\nclass TestTorchFormat(TangoTestCase):\n    def test_read_write(self):\n        torch_format: TorchFormat = Format.by_name(\"torch\")()  # type: ignore[assignment]\n        torch_format.write({\"a\": 1}, self.TEST_DIR)\n        assert os.path.exists(self.TEST_DIR / \"data.pt\")\n        data = torch_format.read(self.TEST_DIR)\n        assert data == {\"a\": 1}\n"
  },
  {
    "path": "tests/integrations/torch/optim_test.py",
    "content": "from tango.integrations.torch.optim import LRScheduler, Optimizer\n\n\ndef test_all_optimizers_registered():\n    assert \"torch::Adagrad\" in Optimizer.list_available()\n\n\ndef test_all_lr_schedulers_registered():\n    assert \"torch::ExponentialLR\" in LRScheduler.list_available()\n"
  },
  {
    "path": "tests/integrations/torch/train_callback_test.py",
    "content": "from pathlib import Path\n\nimport pytest\nfrom torch.optim import SGD\n\nfrom tango.common import DatasetDict, Lazy\nfrom tango.integrations.torch import (\n    DataLoader,\n    StopEarly,\n    StopEarlyCallback,\n    TorchTrainingEngine,\n    TrainConfig,\n)\nfrom tango.workspaces import MemoryWorkspace\n\nfrom .training_engine_test import DummyModel\n\n\ndef test_stop_early_callback():\n    workspace = MemoryWorkspace()\n    train_config = TrainConfig(step_id=\"FakeStep-abc123\", work_dir=Path(\"/tmp\"))\n    training_engine = TorchTrainingEngine(\n        train_config=train_config, model=DummyModel(), optimizer=Lazy(SGD, lr=0.001)  # type: ignore\n    )\n    dataset_dict = DatasetDict(splits={\"train\": [1, 2, 3]})\n    train_dataloader = Lazy(DataLoader)\n\n    callback = StopEarlyCallback(\n        patience=10,\n        workspace=workspace,\n        train_config=train_config,\n        training_engine=training_engine,\n        dataset_dict=dataset_dict,\n        train_dataloader=train_dataloader,\n    )\n    callback.post_val_loop(1, 1, 0.5, 0.5)\n    callback.post_val_loop(2, 1, 0.5, 0.5)\n    callback.post_val_loop(20, 1, 0.6, 0.6)\n    with pytest.raises(StopEarly):\n        callback.post_val_loop(31, 1, 0.6, 0.6)\n"
  },
  {
    "path": "tests/integrations/torch/train_test.py",
    "content": "import json\n\nimport pytest\nimport torch.distributed as dist\n\nfrom tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.testing import TangoTestCase\n\n\nclass TestTrainStep(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        initialize_logging(enable_cli_logs=True)\n\n    def teardown_method(self):\n        super().teardown_method()\n        if dist.is_initialized():\n            dist.destroy_process_group()\n        teardown_logging()\n\n    @pytest.mark.parametrize(\"with_validation\", [True, False])\n    def test_basic_train(self, with_validation: bool):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"torch\" / \"train.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n            overrides=\"\"\n            if with_validation\n            else json.dumps(\n                {\"steps.train.validation_split\": None, \"steps.train.validate_every\": None}\n            ),\n        )\n        assert (result_dir / \"train\" / \"data.pt\").is_file()\n        assert (result_dir / \"train\" / \"work\" / \"weights.pt\").is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_latest\" / \"worker0_model.pt\"\n        ).is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_best\" / \"worker0_optimizer.pt\"\n        ).is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_best\" / \"worker0_trainer.pt\"\n        ).is_file()\n\n    @pytest.mark.parametrize(\"grad_acc\", [1, 2])\n    def test_basic_train_with_epochs(self, grad_acc: int):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"torch\" / \"train.jsonnet\",\n            include_package=[\n                
\"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n            overrides=json.dumps(\n                {\n                    \"steps.train.train_steps\": None,\n                    \"steps.train.train_epochs\": 2,\n                    \"steps.train.validate_every\": None,\n                    \"steps.train.grad_accum\": grad_acc,\n                }\n            ),\n        )\n        assert (result_dir / \"train\" / \"data.pt\").is_file()\n\n        # Make sure we trained for the right number of steps.\n        expected_steps = 16 // grad_acc\n        latest = result_dir / \"train\" / \"work\" / \"checkpoint_state_latest\"\n        assert latest.is_symlink()\n        last_step = result_dir / \"train\" / \"work\" / f\"checkpoint_state_step{expected_steps}\"\n        assert last_step.is_dir()\n        assert latest.samefile(last_step)\n\n    def test_basic_train_with_streaming_data(self):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"torch\" / \"train.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n        )\n        assert (result_dir / \"train\" / \"data.pt\").is_file()\n\n    def test_train_distributed(self):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"torch\" / \"train_dist.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n        )\n        assert (result_dir / \"train\" / \"data.pt\").is_file()\n        assert (result_dir / \"train\" / \"work\" / \"weights.pt\").is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_latest\" / \"worker0_model.pt\"\n        ).is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / 
\"checkpoint_state_best\" / \"worker0_model.pt\"\n        ).is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_latest\" / \"worker1_model.pt\"\n        ).is_file()\n        assert (\n            result_dir / \"train\" / \"work\" / \"checkpoint_state_best\" / \"worker1_model.pt\"\n        ).is_file()\n\n    @pytest.mark.parametrize(\"grad_acc\", [1, 2])\n    def test_train_distributed_with_epochs(self, grad_acc: int):\n        result_dir = self.run(\n            self.FIXTURES_ROOT / \"integrations\" / \"torch\" / \"train_dist.jsonnet\",\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n            overrides=json.dumps(\n                {\n                    \"steps.train.train_steps\": None,\n                    \"steps.train.train_epochs\": 2,\n                    \"steps.train.validate_every\": None,\n                    \"steps.train.grad_accum\": grad_acc,\n                }\n            ),\n        )\n\n        assert (result_dir / \"train\" / \"data.pt\").is_file()\n\n        # Make sure we trained for the right number of steps.\n        expected_steps = 8 // grad_acc\n        latest = result_dir / \"train\" / \"work\" / \"checkpoint_state_latest\"\n        assert latest.is_symlink()\n        last_step = result_dir / \"train\" / \"work\" / f\"checkpoint_state_step{expected_steps}\"\n        assert last_step.is_dir()\n        assert latest.samefile(last_step)\n"
  },
  {
    "path": "tests/integrations/torch/training_engine_test.py",
    "content": "import time\n\nimport pytest\nimport torch\nimport torch.nn as nn\nfrom torch.nn import MSELoss\n\nfrom tango.common import DatasetDict, Lazy\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.torch import (\n    DataLoader,\n    StopEarly,\n    TorchTrainStep,\n    TrainCallback,\n)\nfrom tango.integrations.torch.model import Model\nfrom tango.integrations.torch.training_engine import TorchTrainingEngine\n\n\n@Model.register(\"dummy_model\")\nclass DummyModel(Model):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(10, 1)\n\n    def forward(self, x, y=None):\n        return self.linear(x)\n\n\n@pytest.mark.gpu\n@pytest.mark.skipif(torch.cuda.device_count() < 1, reason=\"Requires CUDA devices\")\nclass TestTorchTrainingEngine(TangoTestCase):\n    def test_grad_scaler(self):\n        training_engine = TorchTrainingEngine.from_params(\n            {\n                \"train_config\": {\"step_id\": \"001\", \"work_dir\": self.TEST_DIR},\n                \"model\": {\n                    \"type\": \"dummy_model\",\n                },\n                \"optimizer\": {\"type\": \"torch::Adam\"},\n                \"amp\": True,\n            }\n        )\n\n        state_dict = {\"training_steps\": None}\n        training_engine.save_checkpoint(self.TEST_DIR, state_dict)\n        saved_grad_scaler = training_engine.grad_scaler\n        training_engine.load_checkpoint(self.TEST_DIR)\n\n        assert (self.TEST_DIR / \"worker0_grad_scaler.pt\").is_file()\n        assert training_engine.grad_scaler == saved_grad_scaler\n\n\nclass WorseningModel(Model):\n    def __init__(self):\n        super().__init__()\n        self.linear = nn.Linear(7, 1)\n        self.loss = MSELoss()\n        self.start_time = time.time()\n\n    def forward(self, x, y):\n        y_hat = self.linear(x)\n        time.sleep(0.01)\n        return {\"loss\": self.loss(y_hat, y) + (time.time() - self.start_time)}\n\n\nclass 
StopOnStepCallback(TrainCallback):\n    def __init__(self, stop_on_step: int, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.stop_on_step = stop_on_step\n\n    def post_val_loop(\n        self, step: int, epoch: int, val_metric: float, best_val_metric: float\n    ) -> None:\n        if step == self.stop_on_step:\n            raise StopEarly\n\n\ndef test_with_increasing_loss():\n    model = WorseningModel()\n\n    xs = [torch.randn(7) for _ in range(100)]\n    train_set = [{\"x\": x, \"y\": x + 0.1} for x in xs]\n    dataset = DatasetDict(splits={\"train\": train_set, \"validation\": train_set}, metadata={})\n\n    step = TorchTrainStep(\n        model=model,\n        training_engine=Lazy(TorchTrainingEngine, optimizer=Lazy(torch.optim.AdamW, lr=1e-5)),\n        dataset_dict=dataset,\n        train_dataloader=Lazy(DataLoader),\n        train_steps=10,\n        validation_steps=10,\n        train_split=\"train\",\n        validation_split=\"validation\",\n        callbacks=[Lazy(StopOnStepCallback, stop_on_step=9)],\n    )\n    step.result()\n"
  },
  {
    "path": "tests/integrations/transformers/data_test.py",
    "content": "from transformers.data.data_collator import DataCollatorWithPadding, DefaultDataCollator\n\nfrom tango.integrations.torch import DataCollator\nfrom tango.integrations.transformers.data import *  # noqa: F403,F401\n\n\ndef test_init_collator_no_tokenizer():\n    collator = DataCollator.from_params({\"type\": \"transformers::DefaultDataCollator\"})\n    assert isinstance(collator, DefaultDataCollator)\n\n\ndef test_init_collator_with_tokenizer():\n    collator = DataCollator.from_params(\n        {\n            \"type\": \"transformers::DataCollatorWithPadding\",\n            \"tokenizer\": {\n                \"pretrained_model_name_or_path\": \"epwalsh/bert-xsmall-dummy\",\n            },\n        }\n    )\n    assert isinstance(collator, DataCollatorWithPadding)\n"
  },
  {
    "path": "tests/integrations/transformers/finetune_test.py",
    "content": "from datasets import Dataset, DatasetDict\nfrom transformers import AutoTokenizer\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.transformers import TokenizeText2TextData\n\n\nclass TestTokenizeText2TextData(TangoTestCase):\n    def test_tokenize_seq2seq(self):\n        dataset = Dataset.from_dict(\n            {\"field1\": [\"hello\", \"hi\"], \"field2\": [\"world\", \"me\"], \"meta_field\": [1, 0]}\n        )\n        data_dict = DatasetDict({\"train\": dataset})\n        tokenizer = AutoTokenizer.from_pretrained(\"patrickvonplaten/t5-tiny-random\")\n        step = TokenizeText2TextData()\n\n        tokenized = step.run(\n            data=data_dict, tokenizer=tokenizer, source_field=\"field1\", target_field=\"field2\"\n        )\n        assert isinstance(tokenized, DatasetDict)\n        assert len(tokenized[\"train\"]) == 2\n        assert \"input_ids\" in tokenized[\"train\"].column_names\n        assert \"labels\" in tokenized[\"train\"].column_names\n        assert tokenized[\"train\"][0][\"input_ids\"] == [21820, 1]\n\n    def test_tokenize_concat(self):\n        dataset = Dataset.from_dict(\n            {\"field1\": [\"hello\", \"hi\"], \"field2\": [\"world\", \"me\"], \"meta_field\": [1, 0]}\n        )\n        data_dict = DatasetDict({\"train\": dataset})\n        tokenizer = AutoTokenizer.from_pretrained(\"sshleifer/tiny-gpt2\")\n        step = TokenizeText2TextData()\n\n        tokenized = step.run(\n            data=data_dict,\n            tokenizer=tokenizer,\n            source_field=\"field1\",\n            target_field=\"field2\",\n            concat_source_target=True,\n        )\n        assert isinstance(tokenized, DatasetDict)\n        assert len(tokenized[\"train\"]) == 2\n        assert \"input_ids\" in tokenized[\"train\"].column_names\n        assert \"labels\" in tokenized[\"train\"].column_names\n        assert tokenized[\"train\"][0][\"input_ids\"] == [31373, 50257, 6894, 50256]\n        assert 
tokenized[\"train\"][0][\"labels\"] == [-100, -100, 6894, 50256]\n"
  },
  {
    "path": "tests/integrations/transformers/ia3_test.py",
    "content": "import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom tango.integrations.transformers.ia3 import GPT_2_IA3_CONFIG, modify_with_ia3\n\n\ndef test_ia3():\n    config = GPT_2_IA3_CONFIG\n    model_name = \"sshleifer/tiny-gpt2\"\n\n    tokenizer = AutoTokenizer.from_pretrained(model_name)\n\n    input_seq = tokenizer([\"A tiny test on a tiny model.\"], return_tensors=\"pt\")\n\n    model = AutoModelForCausalLM.from_pretrained(model_name).eval()\n\n    with torch.inference_mode():\n        old_outputs = model(\n            input_ids=input_seq.input_ids,\n            attention_mask=input_seq.attention_mask,\n            labels=input_seq.input_ids,\n        )\n\n    model = modify_with_ia3(model, config=config)\n\n    with torch.inference_mode():\n        new_outputs = model(\n            input_ids=input_seq.input_ids,\n            attention_mask=input_seq.attention_mask,\n            labels=input_seq.input_ids,\n        )\n\n    logits_diff = torch.abs(old_outputs.logits - new_outputs.logits).mean()\n    assert logits_diff < 1e-10\n\n    loss_diff = torch.abs(old_outputs.loss - new_outputs.loss)\n    assert loss_diff < 1e-10\n"
  },
  {
    "path": "tests/integrations/transformers/run_generation_test.py",
    "content": "from tango import Step\nfrom tango.common import DatasetDict\nfrom tango.common.testing import TangoTestCase\nfrom tango.integrations.transformers import RunGenerationDataset\n\n\nclass TestRunGeneration(TangoTestCase):\n    def test_run_generation(self):\n        step = Step.from_params(  # type: ignore[assignment]\n            {\n                \"type\": \"transformers::run_generation\",\n                \"prompts\": [\"Tango is the future of\", \"Everybody should be using Tango to\"],\n                \"model\": \"sshleifer/tiny-gpt2\",\n            },\n        )\n        result = list(step.result())\n        assert len(result) == 2\n\n    def test_run_generation_with_model(self):\n        step = Step.from_params(  # type: ignore[assignment]\n            {\n                \"type\": \"transformers::run_generation\",\n                \"prompts\": [\"Tango is the future of\", \"Everybody should be using Tango to\"],\n                \"model\": {\n                    \"type\": \"transformers::AutoModelForCausalLM::from_pretrained\",\n                    \"pretrained_model_name_or_path\": \"sshleifer/tiny-gpt2\",\n                },\n            },\n        )\n        result = list(step.result())\n        assert len(result) == 2\n\n    def test_run_generation_dataset(self):\n        dataset = DatasetDict(\n            {\n                \"train\": [\n                    {\"prompt\": \"Tango is the future of\"},\n                    {\"prompt\": \"Everybody should be using Tango to\"},\n                ]\n            },\n            {},\n        )\n\n        step = RunGenerationDataset(\n            model=\"sshleifer/tiny-gpt2\", input=dataset, prompt_field=\"prompt\"\n        )\n\n        result = step.result()\n        assert len(result) == 1\n        train_split = result[\"train\"]\n        assert len(train_split) == 2\n        assert len(train_split[1]) == 2\n        assert train_split[1][\"prompt\"] == \"Everybody should be using Tango to\"\n    
    assert all(\n            g.startswith(\"Everybody should be using Tango to\")\n            for g in train_split[1][\"prompt_generated\"]\n        )\n"
  },
  {
    "path": "tests/integrations/transformers/soft_prompt_test.py",
    "content": "import transformers\n\nfrom tango.integrations.transformers import add_soft_prompt\n\n\ndef test_soft_prompt():\n    model = transformers.AutoModelForSeq2SeqLM.from_pretrained(\"t5-small\")\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\"t5-small\")\n    prompt = \"translate English to German: That is good.\"\n    model.eval()\n    generated = model.generate(\n        tokenizer.encode(prompt, return_tensors=\"pt\"), num_beams=10, num_return_sequences=5\n    )\n    original_output = [tokenizer.decode(g) for g in generated]\n\n    add_soft_prompt(model, prompt_length=3)\n    model.eval()\n    generated = model.generate(\n        tokenizer.encode(prompt, return_tensors=\"pt\"), num_beams=10, num_return_sequences=5\n    )\n    prompted_output = [tokenizer.decode(g) for g in generated]\n\n    assert original_output != prompted_output\n\n\ndef test_soft_prompt_twice():\n    tokenizer = transformers.AutoTokenizer.from_pretrained(\"gpt2\")\n\n    model = transformers.AutoModelForCausalLM.from_pretrained(\"gpt2\")\n    add_soft_prompt(model, prompt_length=2)\n    model.eval()\n    generated = model.generate(tokenizer.encode(\"It was the best of times.\", return_tensors=\"pt\"))\n    prompted_output1 = tokenizer.decode(generated[0])\n\n    add_soft_prompt(model, prompt_length=5)\n    model.eval()\n    generated = model.generate(tokenizer.encode(\"It was the best of times.\", return_tensors=\"pt\"))\n    prompted_output2 = tokenizer.decode(generated[0])\n\n    assert prompted_output1 != prompted_output2\n"
  },
  {
    "path": "tests/integrations/wandb/__init__.py",
    "content": ""
  },
  {
    "path": "tests/integrations/wandb/step_cache_test.py",
    "content": "import os\nimport pickle\nimport sys\n\nimport pytest\n\nfrom tango import Step\nfrom tango.integrations.wandb import WandbStepCache\n\nWANDB_ENTITY = os.environ.get(\"WANDB_ENTITY\", \"allennlp\")\nWANDB_PROJECT = \"tango-workspace-testing\"\n\n\nclass SomeFakeStep(Step):\n    DETERMINISTIC = True\n    CACHEABLE = True\n\n    def run(self) -> int:  # type: ignore\n        return 1\n\n\ndef test_step_cache_artifact_not_found():\n    step = SomeFakeStep(step_name=\"hi there\")\n    step_cache = WandbStepCache(project=WANDB_PROJECT, entity=WANDB_ENTITY)\n    assert step not in step_cache\n\n\n@pytest.mark.parametrize(\n    \"protocol\",\n    [pytest.param(protocol, id=f\"protocol={protocol}\") for protocol in range(4)]\n    + [\n        pytest.param(\n            5,\n            id=\"protocol=5\",\n            marks=pytest.mark.skipif(\n                sys.version_info < (3, 8), reason=\"Protocol 5 requires Python 3.8 or newer\"\n            ),\n        ),\n    ],\n)\ndef test_pickling(protocol: int):\n    step_cache = WandbStepCache(project=WANDB_PROJECT, entity=WANDB_ENTITY)\n    pickle.loads(pickle.dumps(step_cache, protocol=protocol))\n"
  },
  {
    "path": "tests/integrations/wandb/workspace_test.py",
    "content": "import json\nimport os\nimport pickle\nimport shutil\nimport sys\nimport uuid\n\nimport pytest\nimport wandb\n\nfrom tango import Step, StepGraph, Workspace\nfrom tango.common import Params, util\nfrom tango.common.logging import initialize_logging, teardown_logging\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import *  # noqa: F403,F401\nfrom tango.integrations.wandb import WandbWorkspace\nfrom tango.step_info import StepState\n\nWANDB_ENTITY = os.environ.get(\"WANDB_ENTITY\", \"allennlp\")\nWANDB_PROJECT = \"tango-workspace-testing\"\n\n\nclass TestWandbWorkspace(TangoTestCase):\n    # Need to define the `setup_method()` as fixture so we can use other fixtures within it.\n    @pytest.fixture(autouse=True)\n    def setup_method(self, monkeypatch):\n        super().setup_method()\n        # Patch tango_cache_dir()\n        monkeypatch.setattr(util, \"tango_cache_dir\", lambda: self.TEST_DIR)\n\n    @pytest.mark.parametrize(\n        \"protocol\",\n        [pytest.param(protocol, id=f\"protocol={protocol}\") for protocol in range(4)]\n        + [\n            pytest.param(\n                5,\n                id=\"protocol=5\",\n                marks=pytest.mark.skipif(\n                    sys.version_info < (3, 8), reason=\"Protocol 5 requires Python 3.8 or newer\"\n                ),\n            ),\n        ],\n    )\n    def test_pickle_workspace(self, protocol):\n        workspace = WandbWorkspace(project=WANDB_PROJECT, entity=WANDB_ENTITY)\n        unpickled_workspace = pickle.loads(pickle.dumps(workspace, protocol=protocol))\n        assert unpickled_workspace.wandb_client is not None\n        assert unpickled_workspace.project == workspace.project\n        assert unpickled_workspace.entity == workspace.entity\n        assert unpickled_workspace.steps_dir == workspace.steps_dir\n\n    def test_from_url(self):\n        workspace = Workspace.from_url(f\"wandb://{WANDB_ENTITY}/{WANDB_PROJECT}\")\n        
assert isinstance(workspace, WandbWorkspace)\n        assert workspace.entity == WANDB_ENTITY\n        assert workspace.project == WANDB_PROJECT\n\n\nclass TestWandbWorkspaceUsage(TangoTestCase):\n    # Need to define the `setup_method()` as fixture so we can use other fixtures within it.\n    @pytest.fixture(autouse=True)\n    def setup_method(self, monkeypatch):\n        super().setup_method()\n        self.UNIQUE_ID_SUFFIX = os.environ.get(\"GITHUB_SHA\", \"\")[:6] + \"-\" + str(uuid.uuid1())[:6]\n        # Patch tango_cache_dir()\n        monkeypatch.setattr(util, \"tango_cache_dir\", lambda: self.TEST_DIR)\n        # Patch Step unique IDs and W&B run IDs.\n        monkeypatch.setattr(Step, \"_UNIQUE_ID_SUFFIX\", self.UNIQUE_ID_SUFFIX)\n        monkeypatch.setattr(\n            WandbWorkspace,\n            \"_generate_run_suite_id\",\n            lambda workspace: wandb.util.generate_id() + \"-\" + self.UNIQUE_ID_SUFFIX,\n        )\n\n        self.workspace = WandbWorkspace(project=WANDB_PROJECT, entity=WANDB_ENTITY)\n\n        initialize_logging(enable_cli_logs=True)\n\n    def teardown_method(self):\n        super().teardown_method()\n\n        # Delete W&B runs and their artifacts produced by the test.\n        for wandb_run in self.workspace.wandb_client.runs(\n            f\"{WANDB_ENTITY}/{WANDB_PROJECT}\",\n        ):\n            if (\n                self.UNIQUE_ID_SUFFIX in wandb_run.id\n                or self.UNIQUE_ID_SUFFIX in wandb_run.config.get(\"_run_suite_id\", \"\")\n            ):\n                wandb_run.delete(delete_artifacts=True)\n\n        teardown_logging()\n\n    def test_direct_usage(self):\n        params = Params.from_file(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\")\n        step_graph = StepGraph.from_params(params.pop(\"steps\", keep_as_dict=True))\n        tango_run = self.workspace.register_run(step for step in step_graph.values())\n\n        # Test 'registered_run()' and 'registered_runs()' methods.\n    
    assert self.workspace.registered_run(tango_run.name) == tango_run\n        assert self.workspace.registered_runs()[tango_run.name] == tango_run\n\n        hello_step = step_graph[\"hello\"]\n        hello_world_step = step_graph[\"hello_world\"]\n\n        # Test getting step info.\n        step_info = self.workspace.step_info(hello_step)\n        assert step_info.unique_id.endswith(self.UNIQUE_ID_SUFFIX)\n        assert step_info.step_name == \"hello\"\n        assert step_info.state == StepState.INCOMPLETE\n\n        # Mark the \"hello\" step as starting.\n        self.workspace.step_starting(hello_step)\n        assert self.workspace.step_info(hello_step).state == StepState.RUNNING\n\n        # Mark the \"hello\" step as finished.\n        self.workspace.step_finished(hello_step, \"hello\")\n        assert self.workspace.step_info(hello_step).state == StepState.COMPLETED\n\n        # Make sure the result is in the cache, exists locally, and on W&B.\n        cache = self.workspace.cache\n        assert hello_step in cache\n        assert cache.step_dir(hello_step).is_dir()\n        assert cache.get_step_result_artifact(hello_step) is not None\n\n        # Now make sure we can fetch the item from the cache, even if it's not in memory\n        # or in the cache directory.\n        if hello_step.unique_id in cache.weak_cache:\n            del cache.weak_cache[hello_step.unique_id]\n        if hello_step.unique_id in cache.strong_cache:\n            del cache.strong_cache[hello_step.unique_id]\n        shutil.rmtree(cache.step_dir(hello_step))\n        assert hello_step in cache\n        assert cache[hello_step] == \"hello\"\n\n        # Now start the \"hello_world\" step and then mark it as failed.\n        self.workspace.step_starting(hello_world_step)\n        self.workspace.step_failed(hello_world_step, ValueError(\"oh no!\"))\n        assert self.workspace.step_info(hello_world_step).state == StepState.FAILED\n\n    @pytest.mark.parametrize(\n        
\"multicore\", [pytest.param(True, id=\"multicore\"), pytest.param(False, id=\"single-core\")]\n    )\n    @pytest.mark.parametrize(\n        \"distributed\",\n        [\n            pytest.param(True, id=\"distributed\"),\n            pytest.param(False, id=\"single-device\"),\n        ],\n    )\n    def test_with_wandb_train_callback(self, multicore: bool, distributed: bool):\n        self.run(\n            self.FIXTURES_ROOT\n            / \"integrations\"\n            / \"torch\"\n            / (\"train.jsonnet\" if not distributed else \"train_dist.jsonnet\"),\n            include_package=[\n                \"test_fixtures.integrations.common\",\n                \"test_fixtures.integrations.torch\",\n            ],\n            overrides=json.dumps({\"steps.train.callbacks\": [{\"type\": \"wandb::log\"}]}),\n            workspace_url=f\"wandb://{WANDB_ENTITY}/{WANDB_PROJECT}\",\n            multicore=multicore,\n        )\n"
  },
  {
    "path": "tests/main_test.py",
    "content": "import json\nimport os\nimport re\nimport subprocess\nfrom pathlib import Path\nfrom typing import List, Tuple\n\nimport click\nimport pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.settings import TangoGlobalSettings\nfrom tango.version import VERSION\n\n\nclass TestRun(TangoTestCase):\n    def clean_log_lines(\n        self, log_lines: List[str], file_friendly_logging: bool = False\n    ) -> List[str]:\n        out = []\n        for line in log_lines:\n            unstyled_line = click.unstyle(line)\n            if file_friendly_logging:\n                assert line == unstyled_line\n            line = unstyled_line\n            parts = re.split(r\"(DEBUG|INFO|WARNING|ERROR|CRITICAL)\\s+\", line)\n            if len(parts) >= 3:\n                line = \"\".join(parts[2:])\n                line = re.sub(r\"\\s+[^ ]+$\", \"\", line)\n            elif len(parts) == 1:\n                line = parts[0]\n            else:\n                raise ValueError(str(parts))\n            if line:\n                out.append(line.strip())\n        return out\n\n    def check_logs(\n        self,\n        run_dir: Path,\n        process_result: subprocess.CompletedProcess,\n        file_friendly_logging: bool = False,\n    ) -> Tuple[List[str], List[str]]:\n        stdout_lines = process_result.stdout.decode().replace(\"\\r\", \"\\n\").split(\"\\n\")\n        cleaned_stdout_lines = self.clean_log_lines(stdout_lines, file_friendly_logging)\n\n        log_file = run_dir / \"out.log\"\n        assert log_file.is_file()\n\n        log_lines = open(log_file).read().split(\"\\n\")\n        cleaned_log_lines = self.clean_log_lines(log_lines)\n\n        for line in cleaned_stdout_lines[\n            next(i for i, line in enumerate(stdout_lines) if \"Starting new run\" in line) :\n        ]:\n            assert line in cleaned_log_lines\n\n        return log_lines, cleaned_log_lines\n\n    def test_version(self):\n        result = 
subprocess.run([\"tango\", \"--version\"], capture_output=True, text=True)\n        assert result.returncode == 0\n        assert VERSION in result.stdout\n\n    @pytest.mark.parametrize(\"log_level\", [\"debug\", \"info\", \"warning\", \"error\"])\n    @pytest.mark.parametrize(\"raise_error\", (True, False))\n    def test_logging_all_levels(self, log_level: str, raise_error):\n        cmd = [\n            \"tango\",\n            \"--log-level\",\n            log_level,\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"noisy.jsonnet\"),\n            \"-w\",\n            str(self.TEST_DIR),\n            \"-o\",\n            json.dumps({\"steps.noisy_step.raise_error\": raise_error}),\n        ]\n        result = subprocess.run(cmd, capture_output=True)\n        run_dir = next((self.TEST_DIR / \"runs\").iterdir())\n        if raise_error:\n            assert result.returncode == 1\n        else:\n            assert result.returncode == 0\n        _, cleaned_log_lines = self.check_logs(run_dir, result)\n\n        # Debug messages.\n        assert cleaned_log_lines.count(\"debug message from cli_logger\") == 1\n        assert cleaned_log_lines.count(\"debug message\") == (1 if log_level == \"debug\" else 0)\n\n        # Info messages.\n        assert cleaned_log_lines.count(\"info message from cli_logger\") == 1\n        assert cleaned_log_lines.count(\"info message\") == (\n            1 if log_level in {\"debug\", \"info\"} else 0\n        )\n\n        # Warning messages.\n        assert cleaned_log_lines.count(\"warning message from cli_logger\") == 1\n        assert cleaned_log_lines.count(\"warning message\") == (\n            1 if log_level in {\"debug\", \"info\", \"warning\"} else 0\n        )\n\n        # Error messages.\n        assert cleaned_log_lines.count(\"error message from cli_logger\") == 1\n        assert cleaned_log_lines.count(\"error message\") == (\n            1 if log_level in {\"debug\", \"info\", \"warning\", 
\"error\"} else 0\n        )\n\n        # Traceback.\n        if raise_error:\n            assert \"Traceback (most recent call last):\" in cleaned_log_lines\n            assert \"ValueError: Oh no!\" in cleaned_log_lines\n\n    def test_deterministic_experiment(self):\n        cmd = [\n            \"tango\",\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\"),\n            \"-w\",\n            str(self.TEST_DIR),\n        ]\n        result = subprocess.run(cmd, capture_output=True)\n        assert result.returncode == 0\n        assert len(os.listdir(self.TEST_DIR / \"cache\")) == 2\n        run_dir = next((self.TEST_DIR / \"runs\").iterdir())\n        assert (run_dir / \"hello\").is_dir()\n        assert (run_dir / \"hello\" / \"cache-metadata.json\").is_file()\n        assert (run_dir / \"hello_world\").is_dir()\n\n        # Check logs.\n        self.check_logs(run_dir, result)\n\n        # Running again shouldn't create any more directories in the cache.\n        result = subprocess.run(cmd)\n        assert result.returncode == 0\n        assert len(os.listdir(self.TEST_DIR / \"cache\")) == 2\n        # We should see two runs now.\n        assert len(os.listdir(self.TEST_DIR / \"runs\")) == 2\n\n    def test_experiment_with_memory_workspace(self):\n        cmd = [\n            \"tango\",\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\"),\n            \"-w\",\n            \"memory://\",\n        ]\n        result = subprocess.run(cmd, capture_output=True)\n        assert result.returncode == 0\n\n    def test_experiment_with_default_workspace(self):\n        cmd = [\n            \"tango\",\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\"),\n        ]\n        result = subprocess.run(cmd, capture_output=True)\n        assert result.returncode == 0\n\n    def test_random_experiment(self):\n        cmd = [\n       
     \"tango\",\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"random.jsonnet\"),\n            \"-w\",\n            str(self.TEST_DIR),\n        ]\n        result = subprocess.run(cmd)\n        assert result.returncode == 0\n\n    def test_run_name(self):\n        name = \"unique-tango-run-name\"\n        cmd = [\n            \"tango\",\n            \"run\",\n            str(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\"),\n            \"-w\",\n            str(self.TEST_DIR),\n            \"--name\",\n            name,\n        ]\n        result = subprocess.run(cmd, capture_output=True)\n        run_dir = next((self.TEST_DIR / \"runs\").iterdir())\n        _, clean_log_lines = self.check_logs(run_dir, result)\n        assert result.returncode == 0\n        assert f\"Starting new run {name}\" == clean_log_lines[0]\n\n    @pytest.mark.parametrize(\"parallelism\", [1, 2])\n    @pytest.mark.parametrize(\"start_method\", [\"fork\", \"spawn\"])\n    @pytest.mark.parametrize(\"file_friendly_logging\", [True, False])\n    def test_experiment_with_logging_and_multiprocessing(\n        self, parallelism, start_method, file_friendly_logging\n    ):\n        cmd = (\n            [\n                \"tango\",\n                \"--log-level\",\n                \"info\",\n                \"--start-method\",\n                start_method,\n            ]\n            + ([] if not file_friendly_logging else [\"--file-friendly-logging\"])\n            + [\n                \"run\",\n                str(self.FIXTURES_ROOT / \"experiment\" / \"logging_check.jsonnet\"),\n                \"-w\",\n                str(self.TEST_DIR),\n                \"-j\",\n                str(parallelism),\n            ]\n        )\n        result = subprocess.run(cmd, capture_output=True)\n        run_dir = next((self.TEST_DIR / \"runs\").iterdir())\n        _, clean_log_lines = self.check_logs(run_dir, result, file_friendly_logging)\n        all_logs = 
\"\\n\".join(clean_log_lines)\n        assert \"[step stringA] 0 - This is a logging test.\" in clean_log_lines\n        assert \"[step stringC] 0 - This is also a logging test.\" in clean_log_lines\n        assert (\n            \"[step final_string] 0 - This is a logging test. This is being logged.\"\n            in clean_log_lines\n        )\n        # Make sure tqdm output makes it into the log file.\n        assert \"[step stringA] log progress: 100%\" in all_logs\n        assert \"[step stringC] log progress: 100%\" in all_logs\n        assert \"[step final_string] log progress: 100%\" in all_logs\n\n        # And logs from steps that contain multiprocessing themselves.\n        assert \"[step multiprocessing_result rank 0] Hello from worker 0!\" in all_logs\n        assert \"[step multiprocessing_result rank 1] Hello from worker 1!\" in all_logs\n        assert (\n            \"[step multiprocessing_result rank 0] Hello from the cli logger in worker 0!\"\n            in all_logs\n        )\n        assert (\n            \"[step multiprocessing_result rank 1] Hello from the cli logger in worker 1!\"\n            in all_logs\n        )\n\n        assert \"[step multiprocessing_result] progress from main process: 100%\" in all_logs\n\n\nclass TestSettings(TangoTestCase):\n    def setup_method(self):\n        super().setup_method()\n        self._wd_backup = os.getcwd()\n        os.chdir(self.TEST_DIR)\n        cmd = \"tango settings init -p ./tango.yml\".split(\" \")\n        subprocess.run(cmd, check=True)\n\n    def teardown_method(self):\n        os.chdir(self._wd_backup)\n        super().teardown_method()\n\n    @property\n    def settings(self) -> TangoGlobalSettings:\n        return TangoGlobalSettings.from_file(self.TEST_DIR / \"tango.yml\")\n\n    def test_settings_set_workspace(self):\n        cmd = \"tango settings set workspace ./workspace\".split(\" \")\n        subprocess.run(cmd, check=True)\n        assert self.settings.workspace == {\n           
 \"type\": \"local\",\n            \"dir\": str((self.TEST_DIR / \"workspace\").resolve()),\n        }\n\n    def test_settings_set_include_package(self):\n        cmd = \"tango settings set include-package tango.steps\".split(\" \")\n        subprocess.run(cmd, check=True)\n        assert self.settings.include_package == [\"tango.steps\"]\n\n    def test_settings_set_include_package_invalid(self):\n        cmd = \"tango settings set include-package foo\".split(\" \")\n        with pytest.raises(subprocess.CalledProcessError):\n            subprocess.run(cmd, check=True)\n\n    def test_settings_set_environment(self):\n        cmd = \"tango settings set env FOO BAR\".split(\" \")\n        subprocess.run(cmd, check=True)\n        assert self.settings.environment == {\"FOO\": \"BAR\"}\n\n    def test_settings_set_environment_blocked_var(self):\n        cmd = \"tango settings set env TANGO_LOG_LEVEL info\".split(\" \")\n        with pytest.raises(subprocess.CalledProcessError):\n            subprocess.run(cmd, check=True)\n"
  },
  {
    "path": "tests/step_caches/__init__.py",
    "content": ""
  },
  {
    "path": "tests/step_caches/local_step_cache_test.py",
    "content": "import pickle\nimport sys\n\nimport pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.step import Step\nfrom tango.step_caches.local_step_cache import LocalStepCache\n\n\nclass DummyStep(Step):\n    def run(self, x: int) -> int:  # type: ignore[override]\n        return x\n\n\nclass TestLocalStepCache(TangoTestCase):\n    @pytest.mark.parametrize(\n        \"protocol\",\n        [pytest.param(protocol, id=f\"protocol={protocol}\") for protocol in range(4)]\n        + [\n            pytest.param(\n                5,\n                id=\"protocol=5\",\n                marks=pytest.mark.skipif(\n                    sys.version_info < (3, 8), reason=\"Protocol 5 requires Python 3.8 or newer\"\n                ),\n            ),\n        ],\n    )\n    def test_pickling(self, protocol: int):\n        step = DummyStep(step_name=\"dummy\", x=1)\n        step_cache = LocalStepCache(self.TEST_DIR)\n        step_cache[step] = 1\n        assert step in step_cache\n        assert step.unique_id in step_cache.strong_cache\n        pickled_step_cache = pickle.dumps(step_cache, protocol=protocol)\n        unpickled_step_cache = pickle.loads(pickled_step_cache)\n        assert step.unique_id not in unpickled_step_cache.strong_cache\n        assert step in unpickled_step_cache\n"
  },
  {
    "path": "tests/step_graph_test.py",
    "content": "import re\nfrom copy import deepcopy\nfrom tempfile import NamedTemporaryFile\n\nimport pytest\n\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.testing import TangoTestCase\nfrom tango.common.testing.steps import (  # noqa: F401\n    AddNumbersStep,\n    ConcatStringsStep,\n    StringStep,\n)\nfrom tango.step_graph import StepGraph\n\n\nclass TestStepGraph(TangoTestCase):\n    def test_ordered_steps(self):\n        step_graph = StepGraph.from_params(\n            {\n                \"stepB\": {\n                    \"type\": \"add_numbers\",\n                    \"a_number\": 2,\n                    \"b_number\": 3,\n                },\n                \"stepC\": {\n                    \"type\": \"add_numbers\",\n                    \"a_number\": {\"type\": \"ref\", \"ref\": \"stepB\"},\n                    \"b_number\": 5,\n                },\n                \"stepA\": {\n                    \"type\": \"add_numbers\",\n                    \"a_number\": 3,\n                    \"b_number\": 1,\n                },\n            }\n        )\n\n        result = StepGraph.ordered_steps(step_graph.parsed_steps)\n        assert [res.name for res in result] == [\"stepB\", \"stepC\", \"stepA\"]\n\n    def test_from_file(self):\n        step_graph = StepGraph.from_file(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\")\n        assert \"hello\" in step_graph\n        assert \"hello_world\" in step_graph\n\n    def test_missing_type(self):\n        with pytest.raises(ConfigurationError, match=re.escape('key \"type\" is required')):\n            StepGraph.from_params(\n                {\n                    \"step3\": {\n                        \"a_number\": 3,\n                        \"b_number\": 1,\n                    },\n                }\n            )\n\n    def test_direct_construction(self):\n        step_a = AddNumbersStep(a_number=3, b_number=2, step_name=\"stepA\")\n        step_b = 
AddNumbersStep(a_number=step_a, b_number=2, step_name=\"stepB\")\n        step_graph = StepGraph({\"stepA\": step_a, \"stepB\": step_b})\n        assert list(step_graph.parsed_steps.keys()) == [\"stepA\", \"stepB\"]\n\n    def test_direct_construction_missing_dependency(self):\n        step_a = AddNumbersStep(a_number=3, b_number=2, step_name=\"stepA\")\n        step_b = AddNumbersStep(a_number=step_a, b_number=2, step_name=\"stepB\")\n        with pytest.raises(ConfigurationError, match=\"Or a missing dependency\"):\n            StepGraph({\"stepB\": step_b})\n\n    def test_to_file(self):\n        step_graph = StepGraph.from_file(self.FIXTURES_ROOT / \"experiment\" / \"hello_world.jsonnet\")\n\n        with NamedTemporaryFile(\n            prefix=\"test-step-graph-to-file-\", suffix=\".jsonnet\", dir=self.TEST_DIR\n        ) as file_ref:\n            step_graph.to_file(file_ref.name)\n\n            new_step_graph = StepGraph.from_file(file_ref.name)\n            assert step_graph == new_step_graph\n\n    def test_to_file_without_config(self):\n        from tango.format import JsonFormat\n\n        step_a = AddNumbersStep(a_number=3, b_number=2, step_name=\"stepA\", cache_results=False)\n        step_b = AddNumbersStep(\n            a_number=step_a, b_number=2, step_name=\"stepB\", step_format=JsonFormat(\"gz\")\n        )\n        step_graph = StepGraph({\"stepA\": step_a, \"stepB\": step_b})\n\n        with NamedTemporaryFile(\n            prefix=\"test-step-graph-to-file-without-config\", suffix=\".jsonnet\", dir=self.TEST_DIR\n        ) as file_ref:\n            step_graph.to_file(file_ref.name)\n            new_step_graph = StepGraph.from_file(file_ref.name)\n            assert step_graph == new_step_graph\n\n    def test_with_step_indexer(self):\n        config = {\n            \"list\": {\"type\": \"range_step\", \"start\": 0, \"end\": 3},\n            \"added\": {\n                \"type\": \"add_numbers\",\n                \"a_number\": 2,\n               
 \"b_number\": {\"type\": \"ref\", \"ref\": \"list\", \"key\": 1},\n            },\n        }\n        step_graph = StepGraph.from_params(deepcopy(config))  # type: ignore[arg-type]\n        assert [s.name for s in step_graph[\"added\"].dependencies] == [\"list\"]\n        assert step_graph.to_config() == config\n\n    def test_with_forced_dependencies(self):\n        config = {\n            \"some_string\": {\n                \"type\": \"string\",\n                \"result\": \"I should run second\",\n                \"step_extra_dependencies\": [{\"type\": \"ref\", \"ref\": \"other_string\"}],\n            },\n            \"other_string\": {\"type\": \"string\", \"result\": \"I should run first\"},\n            \"added\": {\n                \"type\": \"concat_strings\",\n                \"string1\": \"Some string:\",\n                \"string2\": {\"type\": \"ref\", \"ref\": \"some_string\"},\n            },\n        }\n        step_graph = StepGraph.from_params(deepcopy(config))  # type: ignore[arg-type]\n        assert step_graph[\"some_string\"].dependencies == {step_graph[\"other_string\"]}\n        assert step_graph[\"added\"].recursive_dependencies == {\n            step_graph[\"other_string\"],\n            step_graph[\"some_string\"],\n        }\n"
  },
  {
    "path": "tests/step_info_test.py",
    "content": "import json\nfrom pathlib import Path\nfrom typing import Any\n\nfrom tango.common.testing.steps import FloatStep\nfrom tango.step import Step\nfrom tango.step_graph import StepGraph\nfrom tango.step_info import StepInfo\n\n\ndef test_step_info():\n    step = FloatStep(step_name=\"float\", result=1.0)\n    step_info = StepInfo.new_from_step(step)\n\n    # Check Git metadata.\n    if (Path.cwd() / \".git\").exists():\n        assert step_info.environment.git is not None\n        assert step_info.environment.git.commit is not None\n        assert step_info.environment.git.remote is not None\n        assert \"allenai/tango\" in step_info.environment.git.remote\n\n    # Check pip requirements.\n    assert step_info.environment.packages is not None\n\n    # Test serialization / deserialization.\n    serialized = json.dumps(step_info.to_json_dict())\n    deserialized = StepInfo.from_json_dict(json.loads(serialized))\n    assert deserialized == step_info\n\n\ndef test_step_info_with_step_dependency():\n    \"\"\"\n    Checks that the StepInfo config stays a dict (is not parsed into a Step)\n    when the step depends on upstream steps.\n    \"\"\"\n\n    @Step.register(\"foo\", exist_ok=True)\n    class FooStep(Step):\n        def run(self, bar: Any) -> str:  # type: ignore\n            return \"foo\" + bar\n\n    @Step.register(\"bar\", exist_ok=True)\n    class BarStep(Step):\n        def run(self) -> str:  # type: ignore\n            return \"Hey!\"\n\n    graph = StepGraph.from_params(\n        {\n            \"foo\": {\n                \"type\": \"foo\",\n                \"bar\": {\"type\": \"ref\", \"ref\": \"bar\"},\n            },\n            \"bar\": {\n                \"type\": \"bar\",\n            },\n        }\n    )\n    step = graph[\"foo\"]\n    step_info = StepInfo.new_from_step(step)\n\n    step_info_json = json.dumps(step_info.to_json_dict())\n    step_info = StepInfo.from_json_dict(json.loads(step_info_json))\n    assert isinstance(step_info.config, dict)\n"
  },
  {
    "path": "tests/step_test.py",
    "content": "import collections\nfrom typing import Any, Dict, Mapping, MutableMapping\n\nimport pytest\n\nfrom tango import StepGraph\nfrom tango.common import Params, Registrable\nfrom tango.common.exceptions import ConfigurationError\nfrom tango.common.from_params import FromParams\nfrom tango.common.testing import TangoTestCase\nfrom tango.step import FunctionalStep, Step, step\nfrom tango.workspaces import MemoryWorkspace\n\n\nclass TestStep(TangoTestCase):\n    def test_from_params(self):\n        step = Step.from_params({\"type\": \"float\", \"result\": 3})\n        result = step.result()\n        assert result == 3\n\n    def test_from_params_wrong_type(self):\n        with pytest.raises(TypeError):\n            Step.from_params({\"type\": \"float\", \"result\": \"not a float\"})\n\n    def test_step_with_from_params_input(self):\n        class Bar(FromParams):\n            def __init__(self, x: int):\n                self.x = x\n\n        @Step.register(\"foo\", exist_ok=True)\n        class FooStep(Step):\n            def run(self, bar: Bar) -> Bar:  # type: ignore\n                return bar\n\n        step = Step.from_params({\"type\": \"foo\", \"bar\": {\"x\": 1}})\n        assert step.result().x == 1\n\n    def test_no_hash_arguments(self):\n        @Step.register(\"no_hash_step\")\n        class SkipArgStep(Step):\n            SKIP_ID_ARGUMENTS = {\"arg\"}\n\n            def run(self, arg: str) -> int:  # type: ignore\n                return 5\n\n        step1 = SkipArgStep(arg=\"foo\")\n        step2 = SkipArgStep(arg=\"bar\")\n        assert step1.unique_id == step2.unique_id\n\n    def test_skip_default_arguments(self):\n        class SkipArgStep(Step):\n            def run(self) -> int:  # type: ignore\n                return 5\n\n        old_hash = SkipArgStep().unique_id\n\n        class SkipArgStep(Step):\n            SKIP_DEFAULT_ARGUMENTS = {\"arg\": 5}\n\n            def run(self, arg: int = 5) -> int:  # type: ignore\n                
return arg\n\n        assert SkipArgStep().unique_id == old_hash\n        assert SkipArgStep(arg=5).unique_id == old_hash\n        assert SkipArgStep(arg=6).unique_id != old_hash\n\n    def test_massage_kwargs(self):\n        class CountLettersStep(Step):\n            @classmethod\n            def massage_kwargs(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]:\n                kwargs = kwargs.copy()\n                kwargs[\"text\"] = kwargs[\"text\"].lower()\n                return kwargs\n\n            def run(self, text: str) -> Mapping[str, int]:  # type: ignore\n                text = text.lower()\n                counter: MutableMapping[str, int] = collections.Counter()\n                for c in text:\n                    counter[c] += 1\n                return counter\n\n        upper = CountLettersStep(text=\"FOO\")\n        lower = CountLettersStep(text=\"foo\")\n        assert upper.unique_id == lower.unique_id\n        assert upper.result() == lower.result()\n\n    def test_default_args(self):\n        class DefaultArgStep(Step[int]):\n            def run(self, left: int, right: int = 0) -> int:  # type: ignore\n                return left + right\n\n        explicit = DefaultArgStep(left=1, right=0)\n        implicit = DefaultArgStep(left=1)\n\n        assert explicit.unique_id == implicit.unique_id\n        assert explicit.result() == implicit.result()\n\n    def test_steps_in_params(self):\n        class Widget(Registrable):\n            def __init__(self, x: int):\n                self.x = x\n\n        @Widget.register(\"gizmo\")\n        class GizmoWidget(Widget):\n            def __init__(self, x: int):\n                super().__init__(x * x)\n\n        @Step.register(\"consumer\")\n        class WidgetConsumerStep(Step):\n            def run(self, widget: Widget):  # type: ignore\n                return widget.x\n\n        @Step.register(\"producer\")\n        class WidgetProducerStep(Step):\n            def run(self, x: int) -> Widget:  # type: 
ignore\n                return GizmoWidget(x)\n\n        config = {\n            \"widget_producer\": Params({\"type\": \"producer\", \"x\": 4}),\n            \"widget_consumer\": Params(\n                {\"type\": \"consumer\", \"widget\": {\"type\": \"ref\", \"ref\": \"widget_producer\"}}\n            ),\n        }\n\n        sg = StepGraph.from_params(config)\n        assert len(sg[\"widget_consumer\"].dependencies) > 0\n\n        class WidgetHolder(Registrable):\n            def __init__(self, widget: Widget):\n                self.widget = widget\n\n        @WidgetHolder.register(\"gizmo\")\n        class GizmoWidgetHolder(WidgetHolder):\n            def __init__(self, gizmo: GizmoWidget):\n                super().__init__(gizmo)\n\n        @Step.register(\"holder_consumer\")\n        class WidgetHolderConsumerStep(Step):\n            def run(self, widget_holder: WidgetHolder) -> int:  # type: ignore\n                return widget_holder.widget.x\n\n        config = {\n            \"widget_producer\": Params({\"type\": \"producer\", \"x\": 4}),\n            \"holder_consumer\": Params(\n                {\n                    \"type\": \"holder_consumer\",\n                    \"widget_holder\": {\n                        \"type\": \"gizmo\",\n                        \"gizmo\": {\"type\": \"ref\", \"ref\": \"widget_producer\"},\n                    },\n                }\n            ),\n        }\n        sg = StepGraph.from_params(config)\n        assert len(sg[\"holder_consumer\"].dependencies) > 0\n\n    def test_functional_step(self):\n        class Bar(FromParams):\n            def __init__(self, x: int):\n                self.x = x\n\n        @step(exist_ok=True)\n        def foo(bar: Bar) -> int:\n            return bar.x\n\n        assert issubclass(foo, FunctionalStep)\n        assert foo().run(Bar(x=1)) == 1\n\n        foo_step = Step.from_params({\"type\": \"foo\", \"bar\": {\"x\": 1}})\n        assert isinstance(foo_step, FunctionalStep)\n        
assert isinstance(foo_step.kwargs[\"bar\"], Bar)\n\n    def test_bound_functional_step(self):\n        class Bar(FromParams):\n            def __init__(self, x: int):\n                self.x = x\n\n        @step(exist_ok=True, bind=True)\n        def foo(self, bar: Bar) -> int:\n            assert self.work_dir.is_dir()\n            return bar.x\n\n        foo_step = Step.from_params({\"type\": \"foo\", \"bar\": {\"x\": 1}})\n        assert isinstance(foo_step, FunctionalStep)\n        assert foo_step.result(MemoryWorkspace()) == 1\n\n    def test_bound_functional_step_missing_self(self):\n        @step(exist_ok=True, bind=True)\n        def foo(x: int) -> int:\n            return x\n\n        with pytest.raises(ConfigurationError):\n            Step.from_params({\"type\": \"foo\", \"x\": 1})\n\n        @step(exist_ok=True, bind=True)\n        def bar(s, x: int) -> int:\n            return x\n\n        with pytest.raises(ConfigurationError):\n            Step.from_params({\"type\": \"bar\", \"x\": 1})\n"
  },
  {
    "path": "tests/steps/__init__.py",
    "content": ""
  },
  {
    "path": "tests/steps/dataset_remix_test.py",
    "content": "from tango.common.dataset_dict import DatasetDict\nfrom tango.steps.dataset_remix import DatasetRemixStep\n\n\ndef test_dataset_remix_step():\n    step = DatasetRemixStep(\"remix\")\n    dataset_dict = DatasetDict(\n        {\n            \"train\": list(range(10)),\n            \"dev\": list(range(10, 15)),\n            \"test\": list(range(15, 20)),\n        }\n    )\n    result = step.run(\n        input=dataset_dict,\n        new_splits={\n            \"all_train\": \"train + dev\",\n            \"cross_val_train\": \"train[:8]\",\n            \"cross_val_dev\": \"train[-2:]\",\n        },\n    )\n    assert len(result[\"all_train\"]) == len(dataset_dict[\"train\"]) + len(dataset_dict[\"dev\"])\n"
  },
  {
    "path": "tests/steps/shell_step_test.py",
    "content": "import pytest\n\nfrom tango.common.testing import TangoTestCase\nfrom tango.steps.shell_step import ShellStep, make_registrable\n\n\nclass TestShellStep(TangoTestCase):\n    def test_shell_step(self):\n        step = ShellStep()\n        result = step.run(\"echo hello\")\n        assert isinstance(result, str)\n        assert result == \"hello\\n\"\n\n    def test_shell_step_failure(self):\n        step = ShellStep()\n        with pytest.raises(RuntimeError):\n            step.run(\"ls -l non_existent_path\")\n\n    def test_shell_step_with_output_path(self, caplog):\n        output_path = self.TEST_DIR / \"test-folder\"\n        step = ShellStep()\n        step.run(f\"mkdir {output_path}\", output_path=output_path)\n        assert f\"Output found at: {output_path}\" in caplog.text\n\n    def test_shell_step_different_validation(self, caplog):\n        @make_registrable(exist_ok=True)\n        def validate_func(path):\n            \"\"\"\n            Validates that the file contents of the `path` are a json string.\n            \"\"\"\n            import json\n\n            with open(path) as f:\n                json.load(f)\n\n        output_path = self.TEST_DIR / \"hello.json\"\n        command = f\"python3 -c \\\"import json; print(json.dumps({{'a': 23}}))\\\" > {output_path}\"\n        step = ShellStep()\n        step.run(command, output_path=output_path, validate_output=validate_func, shell=True)\n        assert f\"Output found at: {output_path}\" in caplog.text\n\n    def test_shell_step_in_config(self, caplog):\n        output_path = str(self.TEST_DIR / \"test-folder\")\n        config = {\n            \"steps\": {\n                \"create_dir\": {\n                    \"type\": \"shell_step\",\n                    \"shell_command\": f\"mkdir {output_path}\",\n                    \"output_path\": output_path,\n                    \"validate_output\": {\"type\": \"check_path_existence\"},\n                },\n            }\n        }\n\n      
  # Regular run contains all step outputs.\n        self.run(config)\n        assert f\"Output found at: {output_path}\" in caplog.text\n"
  },
  {
    "path": "tests/workspaces/__init__.py",
    "content": ""
  },
  {
    "path": "tests/workspaces/local_workspace_test.py",
    "content": "from shutil import copytree\n\nimport pytest\nfrom sqlitedict import SqliteDict\n\nfrom tango import Step\nfrom tango.common.testing import TangoTestCase\nfrom tango.step_info import StepState\nfrom tango.workspaces import LocalWorkspace\n\n\nclass AdditionStep(Step):\n    def run(self, a: int, b: int) -> int:  # type: ignore\n        return a + b\n\n\nclass TestLocalWorkspace(TangoTestCase):\n    def test_local_workspace_one_step(self):\n        workspace = LocalWorkspace(self.TEST_DIR)\n        step = AdditionStep(a=1, b=2)\n\n        with pytest.raises(KeyError):\n            # This can't possibly work because the workspace has never seen that step before.\n            step_info = workspace.step_info(step.unique_id)\n            assert step_info.state == StepState.INCOMPLETE\n        step_info = workspace.step_info(step)\n        assert step_info.state == StepState.INCOMPLETE\n\n        result = step.result(workspace)\n        assert result == 3\n\n        step_info = workspace.step_info(step.unique_id)\n        assert step_info.state == StepState.COMPLETED\n        step_info = workspace.step_info(step)\n        assert step_info.state == StepState.COMPLETED\n\n    def test_local_workspace_two_steps(self):\n        workspace = LocalWorkspace(self.TEST_DIR)\n        step1 = AdditionStep(a=1, b=2)\n        step2 = AdditionStep(a=step1, b=3)\n\n        step_info = workspace.step_info(step2)\n        assert step_info.state == StepState.INCOMPLETE\n        step_info = workspace.step_info(step2.unique_id)\n        assert step_info.state == StepState.INCOMPLETE\n        assert step1.unique_id in step_info.dependencies\n        step_info = workspace.step_info(step1.unique_id)\n        assert step_info.state == StepState.INCOMPLETE\n        step_info = workspace.step_info(step1)\n        assert step_info.state == StepState.INCOMPLETE\n\n        result = step2.result(workspace)\n        assert result == 6\n\n        for step in [step1, step2]:\n            
step_info = workspace.step_info(step.unique_id)\n            assert step_info.state == StepState.COMPLETED\n            step_info = workspace.step_info(step)\n            assert step_info.state == StepState.COMPLETED\n\n    def test_local_workspace_upgrade_v1_to_v2(self):\n        workspace_dir = self.TEST_DIR / \"workspace\"\n        copytree(\n            self.FIXTURES_ROOT / \"v1_local_workspace\",\n            workspace_dir,\n            symlinks=True,\n        )\n        workspace = LocalWorkspace(workspace_dir)\n        step_info = workspace.step_info(\"SubtractionStep-YCdedqjmmd9GUFi96VzPXD5tAVho3CTz\")\n        assert step_info.state == StepState.COMPLETED\n        dependencies = list(step_info.dependencies)\n\n        # Make sure all the dependencies are there.\n        while len(dependencies) > 0:\n            step_info = workspace.step_info(dependencies.pop())\n            dependencies.extend(step_info.dependencies)\n\n    def test_remove_step(self):\n        workspace = LocalWorkspace(self.TEST_DIR)\n        step = AdditionStep(a=1, b=2)\n        workspace.step_starting(step)\n        workspace.step_finished(step, 1.0)\n\n        with SqliteDict(workspace.step_info_file) as d:\n            assert step.unique_id in d\n\n        cache = workspace.step_cache\n        assert step in cache\n\n        workspace.remove_step(step.unique_id)\n\n        with SqliteDict(workspace.step_info_file) as d:\n            assert step.unique_id not in d\n\n        cache = workspace.step_cache\n        assert step not in cache\n"
  },
  {
    "path": "tests/workspaces/memory_workspace_test.py",
    "content": "from tango.common.testing.steps import FloatStep\nfrom tango.workspaces import MemoryWorkspace\n\n\ndef test_remove_step():\n    workspace = MemoryWorkspace()\n    step = FloatStep(step_name=\"float\", result=1.0)\n\n    workspace.step_starting(step)\n    workspace.step_finished(step, 1.0)\n    cache = workspace.step_cache\n\n    assert step.unique_id in workspace.unique_id_to_info\n    assert step in cache\n\n    workspace.remove_step(step.unique_id)\n    cache = workspace.step_cache\n\n    assert step.unique_id not in workspace.unique_id_to_info\n    assert step not in cache\n"
  }
]