[
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Create a report to help us improve Alpa\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**Please describe the bug**\n\n**Please describe the expected behavior**\n\n**System information and environment**\n- OS Platform and Distribution (e.g., Linux Ubuntu 16.04, docker):\n- Python version:\n- CUDA version:\n- NCCL version:\n- cupy version:\n- GPU model and memory:\n- Alpa version:\n- TensorFlow version:\n- JAX version:\n\n**To Reproduce**\nSteps to reproduce the behavior:\n1.\n2.\n3.\n4. See error\n\n**Screenshots**\nIf applicable, add screenshots to help explain your problem.\n\n**Code snippet to reproduce the problem**\n\n**Additional information**\nAdd any other context about the problem here or include any logs that would be helpful to diagnose the problem."
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Suggest a new feature for Alpa\ntitle: ''\nlabels: ''\nassignees: ''\n\n---\n\n**System information**\n- Alpa version:\n- Are you willing to contribute it (Yes/No):\n\n**Describe the new feature and the current behavior/state**\n\n**Will this change the current API? How?**\n\n**Describe alternatives you've considered**\n\n**Additional context**"
  },
  {
    "path": ".github/workflows/build_jaxlib.yml",
    "content": "name: Build Jaxlib\n\non:\n  workflow_dispatch:\n    inputs:\n      tensorflow:\n        description: 'TensorFlow-alpa branch to build'\n        required: true\n        default: 'master'\n\n\nenv:\n  TF_BRANCH: ${{ github.event.inputs.tensorflow }}\n\n\njobs:\n  build_jaxlib:\n    name: Build JaxLib wheels\n    runs-on: [self-hosted]\n    # change the following to build with\n    #   Python: 3.7, 3.8, 3.9\n    #   CUDA 11.1, 11.2, 11.3\n    # Using github matrix\n\n    steps:\n      - name: Cancel previous\n        uses: styfle/cancel-workflow-action@0.9.1\n        with:\n          access_token: ${{ secrets.PAT_TOKEN }}\n        if: ${{github.ref != 'refs/head/main'}}\n\n      # checkout repo\n      - uses: actions/checkout@v3\n\n      - name: clean up images\n        run: |\n          docker image prune -f\n\n      - name: build image\n        run: |\n          docker build -t build-jaxlib-image -f docker/build_jaxlib.Dockerfile docker/\n\n      - name: Compile Jaxlib\n        run: |\n          mkdir -p dist\n          docker run --gpus all --tmpfs /build:exec \\\n          --rm -v $(pwd)/dist:/dist build-jaxlib-image \\\n          3.8 cuda 11.1 main ${TF_BRANCH##*/}\n\n      # change this to publishing to pypi\n      - name: Publish to local\n        run: |\n          echo \"Move the Jaxlib binary\"\n          mv dist/*.whl /data/alpa-dist/jaxlib-alpa-ci/\n"
  },
  {
    "path": ".github/workflows/ci.yml",
    "content": "name: CI\n\non:\n  workflow_run:\n    workflows: [Build Jaxlib and Jax]\n    types:\n      - completed\n  workflow_dispatch:\n  push:\n    branches: [main]\n  pull_request:\n    branches: [main]\n\njobs:\n  yapf:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.7\"]\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v2\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install yapf==0.32.0\n    - name: Running yapf\n      run: |\n        yapf --diff --style .style.yapf --recursive alpa && yapf --diff --style .style.yapf --recursive tests\n\n  pylint:\n    runs-on: ubuntu-latest\n    strategy:\n      matrix:\n        python-version: [\"3.7\"]\n    steps:\n    - uses: actions/checkout@v2\n    - name: Set up Python ${{ matrix.python-version }}\n      uses: actions/setup-python@v2\n      with:\n        python-version: ${{ matrix.python-version }}\n    - name: Install dependencies\n      run: |\n        python -m pip install --upgrade pip\n        pip install pylint==2.14.0\n    - name: Analysing the code with pylint\n      run: |\n        pylint alpa\n\n  Unittest:\n    runs-on: [self-hosted, gpu]\n    needs: [yapf, pylint]\n    steps:\n      - name: Cancel previous\n        uses: styfle/cancel-workflow-action@0.9.1\n        with:\n          access_token: ${{ secrets.PAT_TOKEN }}\n        if: |\n          github.event_name =='pull_request' &&\n          github.event.pull_request.head.repo.full_name == github.repository\n\n      - uses: actions/checkout@v3\n\n      - name: clean up images\n        run: |\n          docker image prune -f\n\n      - name: build test image\n        run: |\n          docker build -t test-alpa-image -f docker/unittest.Dockerfile docker/\n\n      - name: Test\n        run: |\n          
ALPA_BRANCH=${{ github.ref }}\n          echo \"${ALPA_BRANCH}\"\n          \n          docker run --gpus all --tmpfs /build:exec --rm \\\n          -v /data/alpa-dist:/alpa-dist \\\n          --shm-size=10.24gb test-alpa-image 3.8 ${ALPA_BRANCH}\n"
  },
  {
    "path": ".github/workflows/docs.yml",
    "content": "# This workflow will generate docs for alpa.\n\nname: Docs\n\non:\n  workflow_dispatch:\n\njobs:\n  build_docs:\n    runs-on: [self-hosted, alpa]\n\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Set up Python 3.8\n        uses: actions/setup-python@v2\n        with:\n          python-version: 3.8\n\n      - name: build doc-building image\n        run: |\n          docker build -t build-alpa-doc -f docker/build_doc.Dockerfile docker/\n\n      - name: Build docs\n        run: |          \n          docker run --gpus all --tmpfs /build:exec --rm \\\n          -v /data/alpa-dist:/alpa-dist \\\n          --shm-size=10.24gb \\\n          build-alpa-doc\n\n      - name: Deploy\n        uses: peaceiris/actions-gh-pages@v3\n        with:\n          personal_token: ${{ secrets.PAT_TOKEN }}\n          external_repository: alpa-projects/alpa-projects.github.io\n          publish_branch: master\n          publish_dir: /data/alpa-dist/docs\n          keep_files: true\n"
  },
  {
    "path": ".github/workflows/release_alpa.yml",
    "content": "name: Release Alpa\n\non:\n  release:\n    types: [created]\n  workflow_dispatch:\n\nenv:\n  TWINE_USERNAME: \"__token__\"\n  TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}\n\njobs:\n\n  build-image:\n    runs-on: [self-hosted]\n\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: clean up images\n        run: |\n          docker image prune -f\n\n      - name: build docker image\n        run: |\n          docker build -t build-alpa-image -f docker/build_alpa.Dockerfile docker/\n\n  release-alpa:\n    runs-on: [self-hosted]\n    needs: [build-image]\n\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: Build Alpa wheels\n        run: |\n          mkdir -p dist\n          docker run --gpus all --tmpfs /build:exec \\\n          --rm -v $(pwd)/dist:/dist --entrypoint /build_alpa.sh \\\n          build-alpa-image 3.8 ${ALPA_BRANCH}\n        env:\n          ALPA_BRANCH: ${{ github.ref }}\n\n      - name: Set up Python 3.8\n        uses: actions/setup-python@v3\n        with:\n          python-version: 3.8\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          pip install twine\n\n      - name: Publish to Pypi\n        run: |\n          echo \"Publish to PyPI\"\n          ls -ltr dist/\n          python -m twine upload --verbose dist/*\n"
  },
  {
    "path": ".github/workflows/release_jaxlib.yml",
    "content": "name: Release Jaxlib\n\non:\n  release:\n    types: [created]\n  workflow_dispatch:\n    inputs:\n      tensorflow:\n        description: 'TensorFlow-alpa branch to build'\n        required: true\n        default: 'master'\n\njobs:\n\n  clean-up:\n    runs-on: [self-hosted]\n\n    steps:\n      - name: clean up images\n        run: |\n          docker image prune -f\n\n  build-jaxlib:\n    runs-on: [self-hosted]\n    needs: [clean-up]\n    strategy:\n      matrix:\n        cuda: [\"11.1\", \"11.2\", \"11.3\"]\n        python: [\"3.7\", \"3.8\", \"3.9\"]\n\n    steps:\n      - uses: actions/checkout@v3\n\n      - name: build image\n        run: |\n          docker build -t build-jaxlib-image-cuda${CUDA_VERSION} \\\n            -f docker/build_jaxlib.Dockerfile docker/ \\\n            --build-arg JAX_CUDA_VERSION=${CUDA_VERSION}\n        env:\n          CUDA_VERSION: ${{ matrix.cuda }}\n\n      - name: Compile Jaxlib\n        run: |\n          mkdir -p /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}\n          echo \"Compile Python ${PYTHON_VERSION}, CUDA ${CUDA_VERSION}, ALPA BRANCH: ${ALPA_BRANCH}, TF_BRANCH: ${TF_BRANCH}\"\n          if [[ ${{ github.event_name }} == \"release\" ]]; then\n            docker run --gpus all --tmpfs /build:exec \\\n              --rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \\\n              build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \\\n              cuda ${CUDA_VERSION} ${ALPA_BRANCH}\n          else\n            docker run --gpus all --tmpfs /build:exec \\\n              --rm -v /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}:/dist \\\n              build-jaxlib-image-cuda${CUDA_VERSION} ${PYTHON_VERSION} \\\n              cuda ${CUDA_VERSION} ${ALPA_BRANCH} ${TF_BRANCH}\n          fi\n        env:\n          CUDA_VERSION: ${{ matrix.cuda }}\n          PYTHON_VERSION: ${{ matrix.python }}\n          ALPA_BRANCH: ${{ github.ref }}\n          TF_BRANCH: ${{ 
github.event.inputs.tensorflow }}\n\n      - name: Move CUDA${{ matrix.cuda }}\n        run: |\n          echo \"Move to one single folder\"\n          ls /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}\n          mv /data/alpa-dist/jaxlib-alpa/cuda${CUDA_VERSION//.}/*.whl /data/alpa-pypi/packages/\n        env:\n          CUDA_VERSION: ${{ matrix.cuda }}\n\n  publish:\n    runs-on: [self-hosted]\n    needs: [build-jaxlib]\n    steps:\n      - name: Set up Python 3.8\n        uses: actions/setup-python@v2\n        with:\n          python-version: 3.8\n\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip\n          python -m pip install github3.py requests\n\n      - uses: actions/checkout@v3\n        with:\n          fetch-depth: 0\n\n      - name: Get latest tag\n        id: latesttag\n        uses: \"WyriHaximus/github-action-get-previous-tag@v1\"\n\n      - name: Upload wheels\n        run: |\n          echo \"Upload wheels to tag ${TAG}\"\n          ls /data/alpa-pypi/packages/\n          python build_jaxlib/release/wheel_upload.py --tag ${TAG} --path /data/alpa-pypi/packages/\n        env:\n          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}\n          TAG: ${{ steps.latesttag.outputs.tag }}\n\n      - name: \"Generate and update PyPI index\"\n        env:\n          GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}\n          TAG: ${{ steps.latesttag.outputs.tag }}\n        run: |\n          git clone https://$GITHUB_TOKEN@github.com/alpa-projects/alpa-projects.github.io\n          cd alpa-projects.github.io\n          git config user.name github-actions\n          git config user.email github-actions@github.com\n          cd ..\n          python build_jaxlib/release/generate_pypi_index.py --tag ${TAG}\n"
  },
  {
    "path": ".gitignore",
    "content": "# Python cache\n__pycache__\n*.pyc\ndist\n*.egg-info\n.cache\n*env\n\n# NFS temp files\n.nfs*\n\n# Vim\n*.swp\n\n# pycharm\n.idea\n\n# vscode\n*vscode*\n\n# Build files\nalpa/pipeline_parallel/xla_custom_call_marker/build\nbuild/lib\nbuild/bdist*\nbuild_jaxlib/build/bazel*\nbuild_jaxlib/bazel-*\nbuild_jaxlib/.jax_configure.bazelrc\nbuild_jaxlib/dist\n\n# Examples build and tmp files\nexamples/build/\nexamples/imagenet/imagenet\nexamples/llm_serving/dataset/*.so\nexamples/llm_serving/dataset/*.c\nexamples/llm_serving/dataset/*.cpp\nexamples/llm_serving/weblogs\nexamples/llm_serving/keys_file.json\nexamples/llm_serving/benchmark/tmp*\nexamples/llm_serving/tmp*\nexamples/opt_finetune/output/\nexamples/gpt2/norwegian-gpt2/\nalpa_debug_info\n\n# Analysis temp files\n*.nvprof\n*.prof\n*.tsv\n*.hlo\n*.pkl\nbenchmark/alpa/tmp*\nbenchmark/alpa/chrome_trace\n*.log\n\n# Tests temp files\ntests/tmp\ntests/*/tmp\n\n# Dataset\nbenchmark/deepspeed/data\n\n# plots\nbenchmark/*.pdf\n\n# Numpy cache\n*.npy\n\n# Documentation website build\ndocs/_build\ndocs/tutorials\n\n# macOS temp files\n.DS_Store\n"
  },
  {
    "path": ".gitmodules",
    "content": "[submodule \"third_party/jax\"]\n\tpath = third_party/jax\n\turl = https://github.com/google/jax.git\n[submodule \"third_party/tensorflow-alpa\"]\n\tpath = third_party/tensorflow-alpa\n\turl = https://github.com/alpa-projects/tensorflow-alpa.git\n"
  },
  {
    "path": ".pylintrc",
    "content": "# This Pylint rcfile contains a best-effort configuration to uphold the\n# best-practices and style described in the Google Python style guide:\n#   https://google.github.io/styleguide/pyguide.html\n#\n# Its canonical open-source location is:\n#   https://google.github.io/styleguide/pylintrc\n\n[MASTER]\n\n# Files or directories to be skipped. They should be base names, not paths.\nignore=benchmark,docs,examples,playground,third_party,model\n\n# Files or directories matching the regex patterns are skipped. The regex\n# matches against base names, not paths.\nignore-patterns=\n\n# Pickle collected data for later comparisons.\npersistent=no\n\n# List of plugins (as comma separated values of python modules names) to load,\n# usually to register additional checkers.\nload-plugins=\n\n# Use multiple processes to speed up Pylint.\njobs=4\n\n# Allow loading of arbitrary C extensions. Extensions are imported into the\n# active Python interpreter and may run arbitrary code.\nunsafe-load-any-extension=no\n\n\n[MESSAGES CONTROL]\n\n# Only show warnings with the listed confidence levels. Leave empty to show\n# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED\nconfidence=\n\n# Enable the message, report, category or checker with the given id(s). You can\n# either give multiple identifiers separated by comma (,) or put this option\n# multiple times (only on the command line, not in the configuration file where\n# it should appear only once). See also the \"--disable\" option for examples.\n#enable=\n\n# Disable the message, report, category or checker with the given id(s). You\n# can either give multiple identifiers separated by comma (,) or put this\n# option multiple times (only on the command line, not in the configuration\n# file where it should appear only once). You can also use \"--disable=all\" to\n# disable everything first and then reenable specific checks. 
For example, if\n# you want to run only the similarities checker, you can use \"--disable=all\n# --enable=similarities\". If you want to run only the classes checker, but have\n# no Warning level messages displayed, use\"--disable=all --enable=classes\n# --disable=W\"\ndisable=abstract-method,\n        apply-builtin,\n        arguments-differ,\n        attribute-defined-outside-init,\n        backtick,\n        bad-option-value,\n        basestring-builtin,\n        buffer-builtin,\n        c-extension-no-member,\n        consider-using-enumerate,\n        cmp-builtin,\n        cmp-method,\n        coerce-builtin,\n        coerce-method,\n        delslice-method,\n        div-method,\n        duplicate-code,\n        eq-without-hash,\n        execfile-builtin,\n        file-builtin,\n        filter-builtin-not-iterating,\n        fixme,\n        getslice-method,\n        global-statement,\n        hex-method,\n        idiv-method,\n        implicit-str-concat-in-sequence,\n        import-error,\n        import-self,\n        import-star-module-level,\n        inconsistent-return-statements,\n        input-builtin,\n        intern-builtin,\n        invalid-str-codec,\n        locally-disabled,\n        logging-format-interpolation,  # FIXME(alpa): make pass.\n        logging-fstring-interpolation,  # FIXME(alpa): make pass.\n        long-builtin,\n        long-suffix,\n        map-builtin-not-iterating,\n        misplaced-comparison-constant,\n        missing-function-docstring,\n        metaclass-assignment,\n        next-method-called,\n        next-method-defined,\n        no-absolute-import,\n        no-else-break,\n        no-else-continue,\n        no-else-raise,\n        no-else-return,\n        no-init,  # added\n        no-member,\n        no-name-in-module,\n        no-self-use,\n        nonzero-method,\n        oct-method,\n        old-division,\n        old-ne-operator,\n        old-octal-literal,\n        old-raise-syntax,\n        
parameter-unpacking,\n        print-statement,\n        raising-string,\n        range-builtin-not-iterating,\n        raw_input-builtin,\n        rdiv-method,\n        reduce-builtin,\n        relative-import,\n        reload-builtin,\n        round-builtin,\n        setslice-method,\n        signature-differs,\n        standarderror-builtin,\n        suppressed-message,\n        sys-max-int,\n        too-few-public-methods,\n        too-many-ancestors,\n        too-many-arguments,\n        too-many-boolean-expressions,\n        too-many-branches,\n        too-many-instance-attributes,\n        too-many-locals,\n        too-many-nested-blocks,\n        too-many-public-methods,\n        too-many-return-statements,\n        too-many-statements,\n        trailing-newlines,\n        unichr-builtin,\n        unicode-builtin,\n        unnecessary-pass,\n        unpacking-in-except,\n        unspecified-encoding,\n        useless-else-on-loop,\n        useless-object-inheritance,\n        useless-suppression,\n        using-cmp-argument,\n        wrong-import-order,\n        xrange-builtin,\n        zip-builtin-not-iterating,\n\n\n[REPORTS]\n\n# Set the output format. Available formats are text, parseable, colorized, msvs\n# (visual studio) and html. You can also give a reporter class, eg\n# mypackage.mymodule.MyReporterClass.\noutput-format=text\n\n# Tells whether to display a full report or only the messages\nreports=no\n\n# Python expression which should return a note less than 10 (10 is the highest\n# note). You have access to the variables errors warning, statement which\n# respectively contain the number of errors / warnings messages and the total\n# number of statements analyzed. This is used by the global evaluation report\n# (RP0004).\nevaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)\n\n# Template used to display messages. This is a python new-style format string\n# used to format the message information. 
See doc for all details\n#msg-template=\n\n\n[BASIC]\n\n# Good variable names which should always be accepted, separated by a comma\ngood-names=main,_\n\n# Bad variable names which should always be refused, separated by a comma\nbad-names=\n\n# Colon-delimited sets of names that determine each other's naming style when\n# the name regexes allow several styles.\nname-group=\n\n# Include a hint for the correct naming format with invalid-name\ninclude-naming-hint=no\n\n# List of decorators that produce properties, such as abc.abstractproperty. Add\n# to this list to register other decorators that produce valid properties.\nproperty-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl\n\n# Regular expression matching correct function names\nfunction-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$\n\n# Regular expression matching correct variable names\nvariable-rgx=^[a-z][a-z0-9_]*$\n\n# Regular expression matching correct constant names\nconst-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$\n\n# Regular expression matching correct attribute names\nattr-rgx=^_{0,2}[a-z][a-z0-9_]*$\n\n# Regular expression matching correct argument names\nargument-rgx=^[a-z][a-z0-9_]*$\n\n# Regular expression matching correct class attribute names\nclass-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$\n\n# Regular expression matching correct inline iteration names\ninlinevar-rgx=^[a-z][a-z0-9_]*$\n\n# Regular expression matching correct class names\nclass-rgx=^_?[A-Z][a-zA-Z0-9]*$\n\n# Regular expression matching correct module names\nmodule-rgx=^(_?[a-z][a-z0-9_]*|__init__)$\n\n# Regular expression matching correct method 
names\nmethod-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$\n\n# Regular expression which should only match function or class names that do\n# not require a docstring.\nno-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$\n\n# Minimum line length for functions/classes that require docstrings, shorter\n# ones are exempt.\ndocstring-min-length=10\n\n\n[TYPECHECK]\n\n# List of decorators that produce context managers, such as\n# contextlib.contextmanager. Add to this list to register other decorators that\n# produce valid context managers.\ncontextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager\n\n# Tells whether missing members accessed in mixin class should be ignored. A\n# mixin class is detected if its name ends with \"mixin\" (case insensitive).\nignore-mixin-members=yes\n\n# List of module names for which member attributes should not be checked\n# (useful for modules/projects where namespaces are manipulated during runtime\n# and thus existing member attributes cannot be deduced by static analysis. It\n# supports qualified module names, as well as Unix pattern matching.\nignored-modules=\n\n# List of class names for which member attributes should not be checked (useful\n# for classes with dynamically set attributes). This supports the use of\n# qualified names.\nignored-classes=optparse.Values,thread._local,_thread._local\n\n# List of members which are set dynamically and missed by pylint inference\n# system, and so shouldn't trigger E1101 when accessed. 
Python regular\n# expressions are accepted.\ngenerated-members=\n\n\n[FORMAT]\n\n# Maximum number of characters on a single line.\nmax-line-length=80\n\n# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt\n# lines made too long by directives to pytype.\n\n# Regexp for a line that is allowed to be longer than the limit.\nignore-long-lines=(?x)(\n  ^\\s*(\\#\\ )?<?https?://\\S+>?$|\n  ^\\s*(from\\s+\\S+\\s+)?import\\s+.+$)\n\n# Allow the body of an if to be on the same line as the test if there is no\n# else.\nsingle-line-if-stmt=yes\n\n# Maximum number of lines in a module\nmax-module-lines=99999\n\n# String used as indentation unit.  The internal Google style guide mandates 2\n# spaces.  Google's externally-published style guide says 4, consistent with\n# PEP 8.  Here, we use 4 spaces, matching the externally-published guide and\n# PEP 8.\nindent-string='    '\n\n# Number of spaces of indent required inside a hanging or continued line.\nindent-after-paren=4\n\n# Expected format of line ending, e.g. empty (any line ending), LF or CRLF.\nexpected-line-ending-format=\n\n\n[MISCELLANEOUS]\n\n# List of note tags to take into consideration, separated by a comma.\nnotes=TODO\n\n\n[STRING]\n\n# This flag controls whether inconsistent-quotes generates a warning when the\n# character used as a quote delimiter is used inconsistently within a module.\ncheck-quote-consistency=yes\n\n\n[VARIABLES]\n\n# Tells whether we should check for unused import in __init__ files.\ninit-import=no\n\n# A regular expression matching the name of dummy variables (i.e. expectedly\n# not used).\ndummy-variables-rgx=^\\*{0,2}(_$|unused_|dummy_)\n\n# List of additional names supposed to be defined in builtins. Remember that\n# you should avoid defining new builtins when possible.\nadditional-builtins=\n\n# List of strings which can identify a callback function by name. 
A callback\n# name must start or end with one of those strings.\ncallbacks=cb_,_cb\n\n# List of qualified module names which can have objects that can redefine\n# builtins.\nredefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools\n\n\n[LOGGING]\n\n# Logging modules to check that the string format arguments are in logging\n# function parameter format\nlogging-modules=logging,absl.logging,tensorflow.io.logging\n\n\n[SIMILARITIES]\n\n# Minimum lines number of a similarity.\nmin-similarity-lines=4\n\n# Ignore comments when computing similarities.\nignore-comments=yes\n\n# Ignore docstrings when computing similarities.\nignore-docstrings=yes\n\n# Ignore imports when computing similarities.\nignore-imports=no\n\n\n[SPELLING]\n\n# Spelling dictionary name. Available dictionaries: none. To make it working\n# install python-enchant package.\nspelling-dict=\n\n# List of comma separated words that should not be checked.\nspelling-ignore-words=\n\n# A path to a file that contains private dictionary; one word per line.\nspelling-private-dict-file=\n\n# Tells whether to store unknown words to indicated private dictionary in\n# --spelling-private-dict-file option instead of raising a message.\nspelling-store-unknown-words=no\n\n\n[IMPORTS]\n\n# Deprecated modules which should not be used, separated by a comma\ndeprecated-modules=regsub,\n                   TERMIOS,\n                   Bastion,\n                   rexec,\n                   sets\n\n# Create a graph of every (i.e. 
internal and external) dependencies in the\n# given file (report RP0402 must not be disabled)\nimport-graph=\n\n# Create a graph of external dependencies in the given file (report RP0402 must\n# not be disabled)\next-import-graph=\n\n# Create a graph of internal dependencies in the given file (report RP0402 must\n# not be disabled)\nint-import-graph=\n\n# Force import order to recognize a module as part of the standard\n# compatibility libraries.\nknown-standard-library=\n\n# Force import order to recognize a module as part of a third party library.\nknown-third-party=enchant, absl\n\n# Analyse import fallback blocks. This can be used to support both Python 2 and\n# 3 compatible code, which means that the block might have code that exists\n# only in one or another interpreter, leading to false positives when analysed.\nanalyse-fallback-blocks=no\n\n\n[CLASSES]\n\n# List of method names used to declare (i.e. assign) instance attributes.\ndefining-attr-methods=__init__,\n                      __new__,\n                      setUp\n\n# List of member names, which should be excluded from the protected access\n# warning.\nexclude-protected=_asdict,\n                  _fields,\n                  _replace,\n                  _source,\n                  _make\n\n# List of valid names for the first argument in a class method.\nvalid-classmethod-first-arg=cls,\n                            class_\n\n# List of valid names for the first argument in a metaclass class method.\nvalid-metaclass-classmethod-first-arg=mcs\n\n\n[EXCEPTIONS]\n\n# Exceptions that will emit a warning when being caught. Defaults to\n# \"Exception\"\novergeneral-exceptions=StandardError,\n                       Exception,\n                       BaseException\n"
  },
  {
    "path": ".style.yapf",
    "content": "[style]\nbased_on_style = google\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2021- The Alpa team. All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall 
mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "**Note: Alpa is not actively maintained currently. It is available as a research artifact. The core algorithm in Alpa has been merged into XLA, which is still being maintained. https://github.com/openxla/xla/tree/main/xla/hlo/experimental/auto_sharding**\n\n\n<div align=\"center\">\n<img src=\"https://github.com/alpa-projects/alpa/blob/main/docs/logo/alpa-logo-cropped.png\" alt=\"logo\" width=\"250\"></img>\n<br></br>\n</div>\n\n[![CI](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/ci.yml)\n[![Build Jaxlib](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml/badge.svg)](https://github.com/alpa-projects/alpa/actions/workflows/build_jaxlib.yml)\n\n[**Documentation**](https://alpa-projects.github.io) | [**Slack**](https://forms.gle/YEZTCrtZD6EAVNBQ7)\n\nAlpa is a system for training and serving large-scale neural networks.\n\nScaling neural networks to hundreds of billions of parameters has enabled dramatic breakthroughs such as GPT-3, but training and serving these large-scale neural networks require complicated distributed system techniques.\nAlpa aims to automate large-scale distributed training and serving with just a few lines of code.\n\nThe key features of Alpa include:  \n\n💻 **Automatic Parallelization**. Alpa automatically parallelizes users' single-device code on distributed clusters with data, operator, and pipeline parallelism. \n\n🚀 **Excellent Performance**. Alpa achieves linear scaling on training models with billions of parameters on distributed clusters.\n\n✨ **Tight Integration with Machine Learning Ecosystem**. 
Alpa is backed by open-source, high-performance, and production-ready libraries such as [Jax](https://github.com/google/jax), [XLA](https://www.tensorflow.org/xla), and [Ray](https://github.com/ray-project/ray).\n\n## Serving\nThe code below shows how to use huggingface/transformers interface and Alpa distributed backend for large model inference.\nDetailed documentation is in [Serving OPT-175B using Alpa](https://alpa-projects.github.io/tutorials/opt_serving.html).\n\n```python\nfrom transformers import AutoTokenizer\nfrom llm_serving.model.wrapper import get_model\n\n# Load the tokenizer\ntokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-2.7b\")\ntokenizer.add_bos_token = False\n\n# Load the model. Alpa automatically downloads the weights to the specified path\nmodel = get_model(model_name=\"alpa/opt-2.7b\", path=\"~/opt_weights/\")\n\n# Generate\nprompt = \"Paris is the capital city of\"\n\ninput_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\noutput = model.generate(input_ids=input_ids, max_length=256, do_sample=True)\ngenerated_string = tokenizer.batch_decode(output, skip_special_tokens=True)\n\nprint(generated_string)\n```\n\n## Training\nUse Alpa's decorator ``@parallelize`` to scale your single-device training code to distributed clusters.\nCheck out the [documentation](https://alpa-projects.github.io) site and\n[examples](https://github.com/alpa-projects/alpa/tree/main/examples) folder\nfor installation instructions, tutorials, examples, and more.\n\n```python\nimport alpa\n\n# Parallelize the training step in Jax by simply using a decorator\n@alpa.parallelize\ndef train_step(model_state, batch):\n    def loss_func(params):\n        out = model_state.forward(params, batch[\"x\"])\n        return jnp.mean((out - batch[\"y\"]) ** 2)\n\n    grads = grad(loss_func)(model_state.params)\n    new_model_state = model_state.apply_gradient(grads)\n    return new_model_state\n\n# The training loop now automatically runs on your designated 
cluster\nmodel_state = create_train_state()\nfor batch in data_loader:\n    model_state = train_step(model_state, batch)\n```\n\n## Learning more\n- [Papers](docs/publications/publications.rst)\n- [Google AI blog](https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html)\n- [OSDI 2022 talk slides](https://docs.google.com/presentation/d/1CQ4S1ff8yURk9XmL5lpQOoMMlsjw4m0zPS6zYDcyp7Y/edit?usp=sharing)\n- [ICML 2022 big model tutorial](https://sites.google.com/view/icml-2022-big-model/home)\n- [GTC 2023 talk video](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51337/)\n\n## Getting Involved\n- Connect to Alpa developers via the [Alpa slack](https://forms.gle/YEZTCrtZD6EAVNBQ7).\n- Please read the [contributor guide](https://alpa-projects.github.io/developer/developer_guide.html) if you are interested in contributing code.\n\n## License\nAlpa is licensed under the [Apache-2.0 license](https://github.com/alpa-projects/alpa/blob/main/LICENSE).\n"
  },
  {
    "path": "alpa/__init__.py",
    "content": "\"\"\"Alpa is a system for training large-scale neural networks.\"\"\"\n# Import all public packages\nfrom . import api\nfrom . import collective\nfrom . import create_state_parallel\nfrom . import data_loader\nfrom . import device_mesh\nfrom . import follow_parallel\nfrom . import global_env\nfrom . import mesh_executable\nfrom . import mesh_profiling\nfrom . import monkey_patch\nfrom . import parallel_method\nfrom . import parallel_plan\nfrom . import pipeline_parallel\nfrom . import shard_parallel\nfrom . import timer\nfrom . import util\nfrom . import version\nfrom . import wrapped_hlo\n\n# Short cuts\nfrom alpa.api import (init, shutdown, parallelize, grad, value_and_grad,\n                      clear_executable_cache)\nfrom alpa.data_loader import DataLoader, MeshDriverDataLoader\nfrom alpa.device_mesh import (\n    DeviceCluster, PhysicalDeviceMesh, LocalPhysicalDeviceMesh,\n    DistributedPhysicalDeviceMesh, DistributedArray, prefetch,\n    get_global_cluster, get_global_physical_mesh,\n    get_global_virtual_physical_mesh, set_global_virtual_physical_mesh,\n    set_seed, get_global_num_devices)\nfrom alpa.global_env import global_config\nfrom alpa.mesh_profiling import ProfilingResultDatabase\nfrom alpa.parallel_method import (ShardParallel, DataParallel, Zero2Parallel,\n                                  Zero3Parallel, PipeshardParallel,\n                                  CreateStateParallel, FollowParallel,\n                                  get_3d_parallel_method)\nfrom alpa.parallel_plan import plan_to_method\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nfrom alpa.pipeline_parallel.layer_construction import (manual_remat,\n                                                       automatic_remat,\n                                                       ManualLayerOption,\n                                                       AutoLayerOption)\nfrom alpa.pipeline_parallel.stage_construction import 
(ManualStageOption,\n                                                       AutoStageOption,\n                                                       UniformStageOption)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.shard_parallel.manual_sharding import ManualShardingOption\nfrom alpa.serialization import save_checkpoint, restore_checkpoint\nfrom alpa.timer import timers\nfrom alpa.version import __version__\n"
  },
  {
    "path": "alpa/api.py",
    "content": "\"\"\"Top-level user API.\"\"\"\nfrom typing import Callable, Optional, Sequence, Union\n\nfrom jax import linear_util as lu\nfrom jax._src import api, traceback_util\nfrom jax._src.util import HashableFunction\nfrom jax.api_util import (argnums_partial, donation_vector,\n                          flatten_fun_nokwargs, rebase_donate_argnums)\nfrom jax.core import AbstractValue\nfrom jax.experimental.maps import FrozenDict\nfrom jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef\n\nfrom alpa.device_mesh import init_global_cluster, shutdown_global_cluster\nfrom alpa.parallel_method import ParallelMethod, ShardParallel\nfrom alpa.pipeline_parallel.primitive_def import mark_gradient\nfrom alpa.util import (auto_donate_argnums, auto_static_argnums,\n                       abstractify_with_aval, GradFuncTransformContext)\nfrom alpa.version import check_alpa_jaxlib_version\n\ntraceback_util.register_exclusion(__file__)\n\nis_initialized = False\n\n\ndef init(cluster: str = \"ray\",\n         cluster_address: Optional[str] = None,\n         num_nodes: Optional[int] = None,\n         num_devices_per_node: Optional[int] = None,\n         namespace: Optional[str] = \"alpa_default_space\"):\n    \"\"\"Initialize the global environment.\n\n    `devices_per_node, num_nodes` are used to specify the number of devices.\n    If not specified, the number of devices is determined automatically and\n    the whole cluster is used.\n\n    For simplicity, the resource specification is only supported for\n    ray cluster.\n\n    Args:\n      cluster: The distributed cluster.\n        Possible choices: {\"local\", \"ray\"}.\n        \"local\" means using all local devices on a single node.\n        \"ray\" means using all devices in a ray cluster.\n      cluster_address: Address of the distributed cluster.\n        If cluster is \"ray\", this parameter can be used to specify a different\n          address that will be used to initialize the ray cluster.\n          
E.g., \"ray://123.45.67.89:10001\". If not specified, \"auto\" will be\n          used instead.\n        Ignored if cluster is \"local\".\n      num_nodes: The number of nodes.\n      num_devices_per_node: The number of devices per node.\n    \"\"\"\n    global is_initialized\n\n    if is_initialized:\n        return\n    is_initialized = True\n\n    init_global_cluster(cluster, cluster_address, num_nodes,\n                        num_devices_per_node, namespace)\n\n\ndef shutdown():\n    \"\"\"Shutdown the global environment.\"\"\"\n    global is_initialized\n    assert is_initialized is True\n    is_initialized = False\n    shutdown_global_cluster()\n\n\ndef parallelize(fun: Optional[Callable] = None,\n                *,\n                static_argnums: Union[Sequence[int], str] = \"auto\",\n                donate_argnums: Union[Sequence[int], str] = \"auto\",\n                batch_argnums: Union[Sequence[int], str] = (1,),\n                method: Optional[ParallelMethod] = None):\n    \"\"\"\n    Parallelize a jax function.\n\n    Args:\n        fun: The function to be parallelized.\n        static_argnums: The same as the static_argnums argument of jax.jit.\n          If it is \"auto\", alpa uses heuristic rules to infer this.\n        donate_argnums: The same as the donate_argnums argument of jax.jit.\n          If it is \"auto\", alpa uses heuristic rules to infer this.\n        batch_argnums: The indices of arguments that are the data batch.\n          This information is used to split the original data batch into micro\n          batches to perform gradient accumulation or pipeline parallelism.\n          Alpa assumes the 0-th dimension of the tensor is the batch dimension.\n        method: The parallelization method.\n    \"\"\"\n    check_alpa_jaxlib_version()\n\n    def decorate_fun(fun):\n        api._check_callable(fun)  # pylint: disable=protected-access\n        nonlocal method\n        method = method or ShardParallel()\n        return 
ParallelizedFunc(fun, static_argnums, donate_argnums,\n                                batch_argnums, method)\n\n    if fun is None:\n        return decorate_fun\n    return decorate_fun(fun)\n\n\nclass ParallelizedFunc:\n    \"\"\"The function after being transformed by alpa.parallelize.\"\"\"\n\n    def __init__(\n        self,\n        fun: Callable,\n        static_argnums: Union[Sequence[int], str],\n        donate_argnums: Union[Sequence[int], str],\n        batch_argnums: Union[Sequence[int], str],\n        method: ParallelMethod,\n    ):\n        self.fun = fun\n        self.static_argnums = static_argnums\n        self.donate_argnums = donate_argnums\n        self.batch_argnums = batch_argnums\n        self.method = method\n\n        self.last_executable = None\n\n    @traceback_util.api_boundary\n    def __call__(self, *args):\n        \"\"\"Launch the computation on the driver.\"\"\"\n        executable, _, out_tree, args_flat = (\n            self._decode_args_and_get_executable(*args))\n        out = executable.launch_on_driver(*args_flat)\n        return tree_unflatten(out_tree(), out)\n\n    def get_executable(self, *args):\n        \"\"\"Get the compiled executable.\"\"\"\n        executable, _, _, _ = self._decode_args_and_get_executable(*args)\n        return executable\n\n    def preshard_dynamic_args(self, *args):\n        \"\"\"Shard the dynamic arguments.\"\"\"\n        executable, in_tree, _, args_flat = (\n            self._decode_args_and_get_executable(*args))\n        sharded_args = executable.preshard_dynamic_args(*args_flat)\n        return tree_unflatten(in_tree, sharded_args)\n\n    def get_last_executable(self):\n        \"\"\"Return the last compiled executable for this function.\"\"\"\n        return self.last_executable\n\n    def _decode_args_and_get_executable(self, *args):\n        \"\"\"Flatten PyTree arguments and get the executable.\"\"\"\n        static_argnums, donate_argnums, batch_argnums = (self.static_argnums,\n        
                                                 self.donate_argnums,\n                                                         self.batch_argnums)\n        kwargs = {}\n\n        f = lu.wrap_init(self.fun)\n\n        # Deal with static arguments and extract dynamic arguments\n        if static_argnums == \"auto\":\n            static_argnums = auto_static_argnums(args)\n\n        if static_argnums:\n            dyn_argnums = [\n                i for i in range(len(args)) if i not in static_argnums\n            ]\n            # Freeze static dict to make it hashable\n            frozen_args = []\n            for i, arg in enumerate(args):\n                if i in static_argnums and isinstance(arg, dict):\n                    frozen_args.append(FrozenDict(arg))\n                else:\n                    frozen_args.append(arg)\n            f, dyn_args = argnums_partial(f, dyn_argnums, frozen_args)\n        else:\n            dyn_args = args\n\n        # Flatten pytree arguments\n        args_flat, in_tree = tree_flatten(dyn_args)\n        f, out_tree = flatten_fun_nokwargs(f, in_tree)\n        # pylint: disable=unnecessary-lambda\n        out_tree_hashable = HashableFunction(lambda: out_tree(), closure=None)\n\n        # Deal with donate argnums\n        if donate_argnums == \"auto\":\n            donate_argnums = auto_donate_argnums(args)\n\n        donate_tuple = rebase_donate_argnums(donate_argnums, static_argnums)\n        if donate_tuple:\n            donated_invars = donation_vector(donate_tuple, dyn_args, kwargs)\n        else:\n            donated_invars = (False,) * len(args_flat)\n\n        # Deal with batch argnums\n        batch_tuple = rebase_donate_argnums(batch_argnums, static_argnums)\n        batch_invars = donation_vector(batch_tuple, dyn_args, kwargs)\n\n        # Compile\n        abstract_args = map(abstractify_with_aval, args_flat)\n        executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,\n                                
                  static_argnums,\n                                                  donated_invars, batch_invars,\n                                                  self.method, *abstract_args)\n\n        self.last_executable = executable\n        return executable, in_tree, out_tree, args_flat\n\n\n@lu.cache\ndef _compile_parallel_executable(\n    fun: lu.WrappedFun,\n    in_tree: PyTreeDef,\n    out_tree_thunk: Callable[[], PyTreeDef],\n    static_argnums: Sequence[int],\n    donated_invars: Sequence[bool],\n    batch_invars: Sequence[bool],\n    method: ParallelMethod,\n    *avals: Sequence[AbstractValue],\n):\n    \"\"\"Cached parallelized callable.\"\"\"\n    # Clean stores for the next call\n    for store in fun.stores:\n        if store:\n            store.reset()\n    batch_invars = list(batch_invars)\n    for idx, aval in enumerate(avals):\n        if len(aval.shape) == 0:\n            batch_invars[idx] = False\n    batch_invars = tuple(batch_invars)\n\n    # Compile a callable\n    return method.compile_executable(fun, in_tree, out_tree_thunk,\n                                     static_argnums, donated_invars,\n                                     batch_invars, *avals)\n\n\ndef clear_executable_cache():\n    \"\"\"Clear all cached executables.\"\"\"\n    _compile_parallel_executable.cache_clear()\n\n\ndef grad(*args, **kwargs):\n    \"\"\"This is the same as jax.grad, except that alpa inserts a\n    gradient marker after the gradient computation.\n\n    This function annotates all gradient tensors. 
This information is used to\n    perform gradient accumulation transformation.\n    If any auxiliary tensors are returned, they are averaged over mini batches\n    in the same way as how the gradients are averaged.\n    \"\"\"\n\n    def ret(*call_args, **call_kwargs):\n        # Apply transformations (e.g., layer construction, rematerialization)\n        # to the forward func\n        arg_list = list(args)\n        for transform in GradFuncTransformContext.transforms:\n            arg_list[0] = transform(arg_list[0])\n        grad_func = api.grad(*arg_list, **kwargs)\n\n        grads = grad_func(*call_args, **call_kwargs)\n        return mark_gradient(grads)\n\n    return ret\n\n\ndef value_and_grad(*args, **kwargs):\n    \"\"\"This is the same as jax.value_and_grad, except that alpa inserts a\n    gradient marker after the gradient computation.\n\n\n    This function annotates all gradient tensors. This information is used to\n    perform gradient accumulation transformation.\n    If any auxiliary tensors are returned, they are averaged over mini batches\n    in the same way as how the gradients are averaged.\n    \"\"\"\n\n    def ret(*call_args, **call_kwargs):\n        # Apply transformations (e.g., layer construction, rematerialization)\n        # to the forward func\n        arg_list = list(args)\n        for transform in GradFuncTransformContext.transforms:\n            arg_list[0] = transform(arg_list[0])\n        grad_func = api.value_and_grad(*arg_list, **kwargs)\n\n        val, grads = grad_func(*call_args, **call_kwargs)\n        return mark_gradient((val, grads))\n\n    return ret\n"
  },
  {
    "path": "alpa/collective/__init__.py",
    "content": "\"\"\"Alpa's wrapper for NCCL collective operations.\"\"\"\n\nfrom alpa.collective.collective import (\n    nccl_available, gloo_available, is_group_initialized, init_collective_group,\n    destroy_collective_group, create_collective_group, get_rank,\n    get_collective_group_size, allreduce, allreduce_multigpu, barrier, reduce,\n    reduce_multigpu, broadcast, broadcast_partialgpu, broadcast_multigpu,\n    allgather, allgather_multigpu, reducescatter, reducescatter_multigpu, send,\n    send_multigpu, recv, recv_multigpu, check_and_get_group, record_events,\n    wait_events, comm_wait_compute, compute_wait_comm)\n\n__all__ = [\n    \"nccl_available\", \"gloo_available\", \"is_group_initialized\",\n    \"init_collective_group\", \"destroy_collective_group\",\n    \"create_collective_group\", \"get_rank\", \"get_collective_group_size\",\n    \"allreduce\", \"allreduce_multigpu\", \"barrier\", \"reduce\", \"reduce_multigpu\",\n    \"broadcast\", \"broadcast_partialgpu\", \"broadcast_multigpu\", \"allgather\",\n    \"allgather_multigpu\", \"reducescatter\", \"reducescatter_multigpu\", \"send\",\n    \"send_multigpu\", \"recv\", \"recv_multigpu\", \"check_and_get_group\",\n    \"record_events\", \"wait_events\", \"comm_wait_compute\", \"compute_wait_comm\"\n]\n"
  },
  {
    "path": "alpa/collective/collective.py",
    "content": "\"\"\"APIs exposed under the namespace ray.util.collective.\"\"\"\nimport logging\nimport os\nfrom typing import List\n\nimport numpy as np\nimport ray\nfrom jax._src.lib import xla_extension as xe\n\nfrom alpa.collective import types\nfrom alpa.global_env import global_config\nfrom alpa.util import try_import_ray_worker\n\nray_worker = try_import_ray_worker()\n\n_CUPY_NCCL_AVAILABLE = True\n_XLA_NCCL_AVAILABLE = True\n_GLOO_AVAILABLE = True\n\nlogger = logging.getLogger(__name__)\n\ntry:\n    from alpa.collective.collective_group.nccl_collective_group import (\n        NCCLGroup as CupyNcclGroup)\nexcept ImportError:\n    _CUPY_NCCL_AVAILABLE = False\n\ntry:\n    from alpa.collective.collective_group.xla_nccl_collective_group import (\n        XLANCCLGroup as XlaNcclGroup)\nexcept AttributeError:\n    _XLA_NCCL_AVAILABLE = False\n\ntry:\n    from alpa.collective.collective_group.gloo_collective_group import (\n        GLOOGroup)\nexcept ImportError:\n    _GLOO_AVAILABLE = False\n\n\ndef nccl_available():\n    if global_config.nccl_mode == \"cupy\":\n        if not _CUPY_NCCL_AVAILABLE:\n            logger.warning(\"Cupy's NCCL seems unavailable. Please install Cupy \"\n                           \"following the guide at: \"\n                           \"https://docs.cupy.dev/en/stable/install.html.\")\n        return _CUPY_NCCL_AVAILABLE\n    elif global_config.nccl_mode == \"xla_extension\":\n        if not _XLA_NCCL_AVAILABLE:\n            logger.warning(\"NCCL from xla_extention seems unavailable! \"\n                           \"Please check whether your local tensorflow-alpa \"\n                           \"has already been up-to-date. You could also set \"\n                           \"global_config.nccl_mode == \\\"cupy\\\" to \"\n                           \"use another set of nccl apis from cupy. 
\")\n        return _XLA_NCCL_AVAILABLE\n    else:\n        raise ValueError(f\"nccl mode {global_config.nccl_mode} is illegal\")\n\n\ndef get_nccl_group(world_size, rank, group_name):\n    assert nccl_available()\n    if global_config.nccl_mode == \"cupy\":\n        return CupyNcclGroup(world_size, rank, group_name)\n    elif global_config.nccl_mode == \"xla_extension\":\n        return XlaNcclGroup(world_size, rank, group_name)\n    else:\n        raise ValueError(f\"nccl mode {global_config.nccl_mode} is illegal\")\n\n\ndef gloo_available():\n    return _GLOO_AVAILABLE\n\n\nclass GroupManager:\n    \"\"\"Use this class to manage the collective groups we created so far.\n\n    Each process will have an instance of `GroupManager`. Each process\n    could belong to multiple collective groups. The membership information\n    and other metadata are stored in the global `_group_mgr` object.\n    \"\"\"\n\n    def __init__(self):\n        self._name_group_map = {}\n        self._group_name_map = {}\n\n    def create_collective_group(self, backend, world_size, rank, group_name):\n        \"\"\"The entry to create new collective groups in the manager.\n\n        Put the registration and the group information into the manager\n        metadata as well.\n        \"\"\"\n        backend = types.Backend(backend)\n        if backend == types.Backend.MPI:\n            raise RuntimeError(\"Ray does not support MPI.\")\n        if backend == types.Backend.GLOO:\n            logger.debug(f\"Creating GLOO group: '{group_name}'...\")\n            g = GLOOGroup(world_size,\n                          rank,\n                          group_name,\n                          store_type=\"redis\",\n                          device_type=\"tcp\")\n            self._name_group_map[group_name] = g\n            self._group_name_map[g] = group_name\n        if backend == types.Backend.NCCL:\n            logger.debug(f\"Creating NCCL group: '{group_name}'...\")\n            g = 
get_nccl_group(world_size, rank, group_name)\n            self._name_group_map[group_name] = g\n            self._group_name_map[g] = group_name\n        return self._name_group_map[group_name]\n\n    def is_group_exist(self, group_name):\n        return group_name in self._name_group_map\n\n    def get_group_by_name(self, group_name):\n        \"\"\"Get the collective group handle by its name.\"\"\"\n        if not self.is_group_exist(group_name):\n            logger.warning(f\"The group '{group_name}' is not initialized.\")\n            return None\n        return self._name_group_map[group_name]\n\n    def destroy_collective_group(self, group_name):\n        \"\"\"Group destructor.\"\"\"\n        if not self.is_group_exist(group_name):\n            logger.warning(f\"The group '{group_name}' does not exist.\")\n            return\n\n        # release the collective group resource\n        g = self._name_group_map[group_name]\n        # clean up the dicts\n        del self._group_name_map[g]\n        del self._name_group_map[group_name]\n        # Release the communicator resources\n        g.destroy_group()\n\n        # Release the detached actors spawned by `create_collective_group()`\n        name = \"info_\" + group_name\n        try:\n            store = ray.get_actor(name)\n            ray.kill(store)\n        except ValueError:\n            pass\n\n\n_group_mgr = GroupManager()\n\n\ndef is_group_initialized(group_name):\n    \"\"\"Check if the group is initialized in this process by the group name.\"\"\"\n    return _group_mgr.is_group_exist(group_name)\n\n\ndef init_collective_group(world_size: int,\n                          rank: int,\n                          backend=types.Backend.NCCL,\n                          group_name: str = \"default\"):\n    \"\"\"Initialize a collective group inside an actor process.\n\n    Args:\n        world_size (int): the total number of processes in the group.\n        rank (int): the rank of the current process.\n       
 backend: the CCL backend to use, NCCL or GLOO.\n        group_name (str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    _check_inside_actor()\n    backend = types.Backend(backend)\n    _check_backend_availability(backend)\n    # TODO(Hao): implement a group auto-counter.\n    if not group_name:\n        raise ValueError(f\"group_name '{group_name}' needs to be a string.\")\n\n    if _group_mgr.is_group_exist(group_name):\n        raise RuntimeError(\"Trying to initialize a group twice.\")\n\n    assert world_size > 0\n    assert rank >= 0\n    assert rank < world_size\n    _group_mgr.create_collective_group(backend, world_size, rank, group_name)\n\n\ndef create_collective_group(actors,\n                            world_size: int,\n                            ranks: List[int],\n                            backend=types.Backend.NCCL,\n                            group_name: str = \"default\"):\n    \"\"\"Declare a list of actors as a collective group.\n\n    Note: This function should be called in a driver process.\n\n    Args:\n        actors (list): a list of actors to be set in a collective group.\n        world_size (int): the total number of processes in the group.\n        ranks (List[int]): the rank of each actor.\n        backend: the CCL backend to use, NCCL or GLOO.\n        group_name (str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    backend = types.Backend(backend)\n    _check_backend_availability(backend)\n\n    name = \"info_\" + group_name\n    try:\n        ray.get_actor(name)\n        raise RuntimeError(\"Trying to initialize a group twice.\")\n    except ValueError:\n        pass\n\n    if len(ranks) != len(actors):\n        raise RuntimeError(\n            f\"Each actor should correspond to one rank. 
Got '{len(ranks)}' \"\n            f\"ranks but '{len(actors)}' actors\")\n\n    if set(ranks) != set(range(len(ranks))):\n        got_ranks = \"\".join([str(r) for r in ranks])\n        raise RuntimeError(\n            f\"Ranks must be a permutation from 0 to '{len(ranks)}'. \"\n            f\"Got '{got_ranks}'.\")\n\n    if world_size <= 0:\n        raise RuntimeError(\"World size must be greater than zero. \"\n                           f\"Got '{world_size}'.\")\n    if not all(ranks) >= 0:\n        raise RuntimeError(\"Ranks must be non-negative.\")\n    if not all(ranks) < world_size:\n        raise RuntimeError(\"Ranks cannot be greater than world_size.\")\n\n    # avoid a circular dependency\n    from alpa.collective.util import Info  # pylint: disable=import-outside-toplevel\n    # store the information into a NamedActor that can be accessed later.\n    name = \"info_\" + group_name\n    actors_id = [a._ray_actor_id for a in actors]  # pylint: disable=protected-access\n    # TODO (Dacheng): how do we recycle this name actor?\n    info = Info.options(name=name, lifetime=\"detached\").remote()\n    ray.get([info.set_info.remote(actors_id, world_size, ranks, backend)])\n\n\n# TODO (we need a declarative destroy() API here.)\ndef destroy_collective_group(group_name: str = \"default\") -> None:\n    \"\"\"Destroy a collective group given its group name.\"\"\"\n    _check_inside_actor()\n    _group_mgr.destroy_collective_group(group_name)\n\n\ndef get_rank(group_name: str = \"default\") -> int:\n    \"\"\"Return the rank of this process in the given group.\n\n    Args:\n        group_name (str): the name of the group to query\n\n    Returns:\n        the rank of this process in the named group,\n        -1 if the group does not exist or the process does\n        not belong to the group.\n    \"\"\"\n    _check_inside_actor()\n    if not is_group_initialized(group_name):\n        return -1\n    g = _group_mgr.get_group_by_name(group_name)\n    return 
g.rank\n\n\ndef get_collective_group_size(group_name: str = \"default\") -> int:\n    \"\"\"Return the size of the collective group with the given name.\n\n    Args:\n        group_name: the name of the group to query\n\n    Returns:\n        The world size of the collective group, -1 if the group does\n            not exist or the process does not belong to the group.\n    \"\"\"\n    _check_inside_actor()\n    if not is_group_initialized(group_name):\n        return -1\n    g = _group_mgr.get_group_by_name(group_name)\n    return g.world_size\n\n\ndef allreduce(tensor, group_name: str = \"default\", op=types.ReduceOp.SUM):\n    \"\"\"Collective allreduce the tensor across the group.\n\n    Args:\n        tensor: the tensor to be all-reduced on this process.\n        group_name (str): the collective group name to perform allreduce.\n        op: The reduce operation.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n    opts = types.AllReduceOptions\n    opts.reduce_op = op\n    g.allreduce([tensor], opts)\n\n\ndef allreduce_multigpu(tensor_list: list,\n                       group_name: str = \"default\",\n                       op=types.ReduceOp.SUM):\n    \"\"\"Collective allreduce a list of tensors across the group.\n\n    Args:\n        tensor_list (List[tensor]): list of tensors to be allreduced,\n            each on a GPU.\n        group_name (str): the collective group name to perform allreduce.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n    opts = types.AllReduceOptions\n    opts.reduce_op = op\n    g.allreduce(tensor_list, opts)\n\n\ndef barrier(group_name: str = \"default\"):\n    \"\"\"Barrier all processes in the collective group.\n\n    Args:\n        group_name (str): the name of the 
group to barrier.\n\n    Returns:\n        None\n    \"\"\"\n    g = _check_and_get_group(group_name)\n    g.barrier()\n\n\ndef reduce(tensor,\n           dst_rank: int = 0,\n           group_name: str = \"default\",\n           op=types.ReduceOp.SUM):\n    \"\"\"Reduce the tensor across the group to the destination rank.\n\n    Args:\n        tensor: the tensor to be reduced on this process.\n        dst_rank (int): the rank of the destination process.\n        group_name (str): the collective group name to perform reduce.\n        op: The reduce operation.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n\n    # check dst rank\n    _check_rank_valid(g, dst_rank)\n    opts = types.ReduceOptions()\n    opts.reduce_op = op\n    opts.root_rank = dst_rank\n    opts.root_tensor = 0\n    g.reduce([tensor], opts)\n\n\ndef reduce_multigpu(tensor_list: list,\n                    dst_rank: int = 0,\n                    dst_tensor: int = 0,\n                    group_name: str = \"default\",\n                    op=types.ReduceOp.SUM):\n    \"\"\"Reduce the tensor across the group to the destination rank\n    and destination tensor.\n\n    Args:\n        tensor_list: the list of tensors to be reduced on this process;\n            each tensor located on a GPU.\n        dst_rank (int): the rank of the destination process.\n        dst_tensor: the index of GPU at the destination.\n        group_name (str): the collective group name to perform reduce.\n        op: The reduce operation.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n\n    # check dst rank\n    _check_rank_valid(g, dst_rank)\n    _check_root_tensor_valid(len(tensor_list), dst_tensor)\n    opts = types.ReduceOptions()\n    opts.reduce_op = op\n    
opts.root_rank = dst_rank\n    opts.root_tensor = dst_tensor\n    g.reduce(tensor_list, opts)\n\n\ndef broadcast(tensor, src_rank: int = 0, group_name: str = \"default\"):\n    \"\"\"Broadcast the tensor from a source process to all others.\n\n    Args:\n        tensor: the tensor to be broadcasted (src) or received (destination).\n        src_rank (int): the rank of the source process.\n        group_name (str): the collective group name to perform broadcast.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n\n    # check src rank\n    _check_rank_valid(g, src_rank)\n    opts = types.BroadcastOptions()\n    opts.root_rank = src_rank\n    opts.root_tensor = 0\n    g.broadcast([tensor], opts)\n\n\ndef broadcast_partialgpu(tensor_list,\n                         n_elements,\n                         comm_key,\n                         world_size,\n                         devices_ids,\n                         devices_global_rank,\n                         group_name: str = \"default\",\n                         local_start_pos_list=None):\n    \"\"\"Broadcast the tensor from a source GPU to some other GPUs.\n    This function is different from broadcast_multigpu that it only\n    uses a subset of gpus in one host.\n\n    Args:\n        tensor_list: the tensors to broadcast (src) or receive (dst).\n        n_elements: total number of elements involved in this broadcast.\n        comm_key: an unique identifier for this cross-host collective group.\n        world_size: total number of devices in this cross-host collective group.\n        devices_ids: local devices in this cross-host collective group.\n        devices_global_rank: the corresponding global rank for local devices.\n        group_name (str): the collective group name to perform broadcast.\n        local_start_pos_list (list[int]): the list contains starting positions\n        of the contiguous data to be sent in every tensor.\n\n    
Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n\n    opts = types.BroadcastOptions()\n    opts.n_elements = n_elements\n    opts.comm_key = comm_key\n    opts.world_size = world_size\n    opts.devices_ids = devices_ids\n    opts.devices_global_rank = devices_global_rank\n    opts.local_start_pos_list = (local_start_pos_list\n                                 if local_start_pos_list is not None else [])\n    g.broadcast_partialgpu(tensor_list, opts)\n\n\ndef broadcast_multigpu(tensor_list,\n                       src_rank: int = 0,\n                       src_tensor: int = 0,\n                       group_name: str = \"default\"):\n    \"\"\"Broadcast the tensor from a source GPU to all other GPUs.\n\n    Args:\n        tensor_list: the tensors to broadcast (src) or receive (dst).\n        src_rank (int): the rank of the source process.\n        src_tensor (int): the index of the source GPU on the source process.\n        group_name (str): the collective group name to perform broadcast.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n\n    # check src rank\n    _check_rank_valid(g, src_rank)\n    _check_root_tensor_valid(len(tensor_list), src_tensor)\n    opts = types.BroadcastOptions()\n    opts.root_rank = src_rank\n    opts.root_tensor = src_tensor\n    g.broadcast(tensor_list, opts)\n\n\ndef allgather(tensor_list: list, tensor, group_name: str = \"default\"):\n    \"\"\"Allgather tensors from each process of the group into a list.\n\n    Args:\n        tensor_list (list): the results, stored as a list of tensors.\n        tensor: the tensor (to be gathered) in the current process\n        group_name 
(str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n    if len(tensor_list) != g.world_size:\n        # Typically CLL lib requires len(tensor_list) >= world_size;\n        # Here we make it more strict: len(tensor_list) == world_size.\n        raise RuntimeError(\n            \"The length of the tensor list operands to allgather \"\n            \"must be equal to world_size.\")\n    opts = types.AllGatherOptions()\n    g.allgather([tensor_list], [tensor], opts)\n\n\ndef allgather_multigpu(output_tensor_lists: list,\n                       input_tensor_list: list,\n                       group_name: str = \"default\"):\n    \"\"\"Allgather tensors from each gpus of the group into lists.\n\n    Args:\n        output_tensor_lists (List[List[tensor]]): gathered results, with shape\n            must be num_gpus * world_size * shape(tensor).\n        input_tensor_list: (List[tensor]): a list of tensors, with shape\n            num_gpus * shape(tensor).\n        group_name (str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_lists_input(output_tensor_lists)\n    _check_tensor_list_input(input_tensor_list)\n    g = _check_and_get_group(group_name)\n    opts = types.AllGatherOptions()\n    g.allgather(output_tensor_lists, input_tensor_list, opts)\n\n\ndef reducescatter(tensor,\n                  tensor_list: list,\n                  group_name: str = \"default\",\n                  op=types.ReduceOp.SUM):\n    \"\"\"Reducescatter a list of tensors across the group.\n\n    Reduce the list of the tensors across each process in the group, then\n    scatter the reduced list of tensors -- one tensor for each process.\n\n    Args:\n        tensor: the resulted tensor on 
this process.\n        tensor_list (list): The list of tensors to be reduced and scattered.\n        group_name (str): the name of the collective group.\n        op: The reduce operation.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    _check_tensor_list_input(tensor_list)\n    g = _check_and_get_group(group_name)\n    if len(tensor_list) != g.world_size:\n        raise RuntimeError(\n            \"The length of the tensor list operands to reducescatter \"\n            \"must be equal to world_size.\")\n    opts = types.ReduceScatterOptions()\n    opts.reduce_op = op\n    g.reducescatter([tensor], [tensor_list], opts)\n\n\ndef reducescatter_multigpu(output_tensor_list,\n                           input_tensor_lists,\n                           group_name: str = \"default\",\n                           op=types.ReduceOp.SUM):\n    \"\"\"Reducescatter a list of tensors across all GPUs.\n\n    Args:\n        output_tensor_list: the resulted list of tensors, with\n            shape: num_gpus * shape(tensor).\n        input_tensor_lists: the original tensors, with shape:\n            num_gpus * world_size * shape(tensor).\n        group_name (str): the name of the collective group.\n        op: The reduce operation.\n\n    Returns:\n        None.\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"Multigpu calls requires NCCL and Cupy.\")\n    _check_tensor_lists_input(input_tensor_lists)\n    _check_tensor_list_input(output_tensor_list)\n    g = _check_and_get_group(group_name)\n    opts = types.ReduceScatterOptions()\n    opts.reduce_op = op\n    g.reducescatter(output_tensor_list, input_tensor_lists, opts)\n\n\ndef send(tensor, dst_rank: int, group_name: str = \"default\"):\n    \"\"\"Send a tensor to a remote process synchronously.\n\n    Args:\n        tensor: the tensor to send.\n        dst_rank (int): the rank of the destination process.\n        group_name (str): the name of the collective 
group.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n    _check_rank_valid(g, dst_rank)\n    if dst_rank == g.rank:\n        raise RuntimeError(f\"The destination rank '{dst_rank}' is self.\")\n    opts = types.SendOptions()\n    opts.dst_rank = dst_rank\n    g.send([tensor], opts)\n\n\ndef send_multigpu(tensor,\n                  dst_rank: int,\n                  dst_gpu_index: int,\n                  group_name: str = \"default\",\n                  start_pos=0,\n                  n_elements=0):\n    \"\"\"Send a tensor to a remote GPU synchronously.\n\n    The function assumes each process owns >1 GPUs, and the sender\n    process and receiver process have an equal number of GPUs.\n\n    Args:\n        tensor: the tensor to send, located on a GPU.\n        dst_rank (int): the rank of the destination process.\n        dst_gpu_index (int): the destination gpu index.\n        group_name (str): the name of the collective group.\n        start_pos (int): the starting position of the contiguous\n        data to be sent in this tensor.\n        n_elements (int): if specified, send the next n elements\n            from the starting address of tensor.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"send_multigpu call requires NCCL.\")\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n    _check_rank_valid(g, dst_rank)\n    if dst_rank == g.rank:\n        raise RuntimeError(f\"The dst_rank '{dst_rank}' is self. 
Considering \"\n                           \"doing GPU to GPU memcpy instead?\")\n    if n_elements < 0:\n        raise RuntimeError(f\"The n_elements '{n_elements}' should be >= 0.\")\n    opts = types.SendOptions()\n    opts.dst_rank = dst_rank\n    opts.dst_gpu_index = dst_gpu_index\n    opts.start_pos = start_pos\n    opts.n_elements = n_elements\n    g.send([tensor], opts)\n\n\ndef recv(tensor, src_rank: int, group_name: str = \"default\"):\n    \"\"\"Receive a tensor from a remote process synchronously.\n\n    Args:\n        tensor: the received tensor.\n        src_rank (int): the rank of the source process.\n        group_name (str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    _check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n    _check_rank_valid(g, src_rank)\n    if src_rank == g.rank:\n        raise RuntimeError(f\"The source rank '{src_rank}' is self.\")\n    opts = types.RecvOptions()\n    opts.src_rank = src_rank\n    g.recv([tensor], opts)\n\n\ndef recv_multigpu(tensor,\n                  src_rank: int,\n                  src_gpu_index: int,\n                  group_name: str = \"default\",\n                  start_pos=0,\n                  n_elements=0):\n    \"\"\"Receive a tensor from a remote GPU synchronously.\n\n    The function assumes each process owns >1 GPUs, and the sender\n    process and receiver process have an equal number of GPUs.\n\n    Args:\n        tensor: the received tensor, located on a GPU.\n        src_rank (int): the rank of the source process.\n        src_gpu_index (int): the index of the source gpu on the src process.\n        start_pos (int): the starting position of the contiguous\n        data to be sent in this tensor.\n        group_name (str): the name of the collective group.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"recv_multigpu call requires NCCL.\")\n    
_check_single_tensor_input(tensor)\n    g = _check_and_get_group(group_name)\n    _check_rank_valid(g, src_rank)\n    if src_rank == g.rank:\n        raise RuntimeError(f\"The src_rank '{src_rank}' is self. Considering \"\n                           \"doing GPU to GPU memcpy instead?\")\n    if n_elements < 0:\n        raise RuntimeError(f\"The n_elements '{n_elements}' should be >= 0.\")\n    opts = types.RecvOptions()\n    opts.src_rank = src_rank\n    opts.src_gpu_index = src_gpu_index\n    opts.start_pos = start_pos\n    opts.n_elements = n_elements\n    g.recv([tensor], opts)\n\n\ndef synchronize(gpu_id: int):\n    \"\"\"Synchronize the current process to a given device.\n\n    Args:\n        gpu_id (int): the GPU device id to synchronize.\n\n    Returns:\n        None\n    \"\"\"\n    if not types.cupy_available():\n        raise RuntimeError(\"synchronize call requires CUDA and NCCL.\")\n    import cupy as cp  # pylint: disable=import-outside-toplevel\n    cp.cuda.Device(gpu_id).synchronize()\n\n\ndef _check_and_get_group(group_name):\n    \"\"\"Check the existence and return the group handle.\"\"\"\n    _check_inside_actor()\n    if not is_group_initialized(group_name):\n        # try loading from remote info store\n        try:\n            # if the information is stored in an Info object,\n            # get and create the group.\n            name = \"info_\" + group_name\n            info_actor = ray.get_actor(name=name)\n            ids, world_size, rank, backend = ray.get(\n                info_actor.get_info.remote())\n\n            # Recycle the info named actor *proactively* to avoid named actor\n            # leak.\n            if ray.get(info_actor.get_access_counter.remote()) == world_size:\n                ray.kill(info_actor)\n                logger.debug(\n                    \"Information about the collective group has been \"\n                    \"broadcasted. 
The Info actor will go out of context and be \"\n                    \"destroyed.\")\n\n            worker = ray_worker.global_worker\n            id_ = worker.core_worker.get_actor_id()\n            r = rank[ids.index(id_)]\n            _group_mgr.create_collective_group(backend, world_size, r,\n                                               group_name)\n        except ValueError as exc:\n            # check if this group is initialized using options()\n            if (\"collective_group_name\" in os.environ and\n                    os.environ[\"collective_group_name\"] == group_name):\n                rank = int(os.environ[\"collective_rank\"])\n                world_size = int(os.environ[\"collective_world_size\"])\n                backend = os.environ[\"collective_backend\"]\n                _group_mgr.create_collective_group(backend, world_size, rank,\n                                                   group_name)\n            else:\n                raise RuntimeError(\n                    f\"The collective group '{group_name}' is not \"\n                    \"initialized in the process.\") from exc\n    g = _group_mgr.get_group_by_name(group_name)\n    return g\n\n\ncheck_and_get_group = _check_and_get_group\n\n\ndef record_events(group_name, uuids, num_devices, is_send):\n    g = _check_and_get_group(group_name)\n    g.record_events(uuids, num_devices, is_send)\n\n\ndef wait_events(group_name, uuids, num_devices, is_send):\n    g = _check_and_get_group(group_name)\n    g.wait_events(uuids, num_devices, is_send)\n\n\ndef comm_wait_compute(group_name, is_send, is_compute, device_id):\n    g = _check_and_get_group(group_name)\n    g.comm_wait_compute(is_send, is_compute, device_id)\n\n\ndef compute_wait_comm(group_name, is_send, is_compute, device_id):\n    g = _check_and_get_group(group_name)\n    g.compute_wait_comm(is_send, is_compute, device_id)\n\n\ndef _check_single_tensor_input(tensor):\n    \"\"\"Check if the tensor is with a supported type.\"\"\"\n    
if isinstance(tensor, (np.ndarray, xe.DeviceArray)):\n        return\n    if types.cupy_available():\n        if isinstance(tensor, types.cp.ndarray):\n            return\n    if types.torch_available():\n        if isinstance(tensor, types.th.Tensor):\n            return\n    raise RuntimeError(f\"Unrecognized tensor type '{type(tensor)}'. \"\n                       \"Supported types are: np.ndarray, torch.Tensor, \"\n                       \"cupy.ndarray.\")\n\n\ndef _check_backend_availability(backend: types.Backend):\n    \"\"\"Check whether the backend is available.\"\"\"\n    if backend == types.Backend.GLOO:\n        if not gloo_available():\n            raise RuntimeError(\"GLOO is not available.\")\n    elif backend == types.Backend.NCCL:\n        if not nccl_available():\n            raise RuntimeError(\"NCCL is not available.\")\n\n\ndef _check_inside_actor():\n    \"\"\"Check if currently it is inside a Ray actor/task.\"\"\"\n    worker = ray_worker.global_worker\n    if worker.mode == ray.WORKER_MODE:\n        return\n    else:\n        raise RuntimeError(\"The collective APIs shall be only used inside \"\n                           \"a Ray actor or task.\")\n\n\ndef _check_rank_valid(g, rank: int):\n    \"\"\"Check the rank: 0 <= rank < world_size.\"\"\"\n    if rank < 0:\n        raise ValueError(f\"rank '{rank}' is negative.\")\n    if rank >= g.world_size:\n        raise ValueError(f\"rank '{rank}' must be less than world size \"\n                         f\"'{g.world_size}'\")\n\n\ndef _check_tensor_list_input(tensor_list):\n    \"\"\"Check if the input is a list of supported tensor types.\"\"\"\n    if not isinstance(tensor_list, list):\n        raise RuntimeError(\"The input must be a list of tensors. 
\"\n                           f\"Got '{type(tensor_list)}'.\")\n    if not tensor_list:\n        raise RuntimeError(\"Got an empty list of tensors.\")\n    for t in tensor_list:\n        _check_single_tensor_input(t)\n\n\ndef _check_tensor_lists_input(tensor_lists):\n    \"\"\"Check if the input is a list of lists of supported tensor types.\"\"\"\n    if not isinstance(tensor_lists, list):\n        raise RuntimeError(\"The input must be a list of lists of tensors. \"\n                           f\"Got '{type(tensor_lists)}'.\")\n    if not tensor_lists:\n        raise RuntimeError(f\"Did not receive tensors. Got: {tensor_lists}\")\n    for t in tensor_lists:\n        _check_tensor_list_input(t)\n\n\ndef _check_root_tensor_valid(length, root_tensor):\n    \"\"\"Check the root_tensor device is 0 <= root_tensor < length\"\"\"\n    if root_tensor < 0:\n        raise ValueError(f\"root_tensor '{root_tensor}' is negative.\")\n    if root_tensor >= length:\n        raise ValueError(f\"root_tensor '{root_tensor}' is greater \"\n                         f\"than the number of GPUs: '{length}'\")\n"
  },
  {
    "path": "alpa/collective/collective_group/__init__.py",
    "content": ""
  },
  {
    "path": "alpa/collective/collective_group/base_collective_group.py",
    "content": "\"\"\"Abstract class for collective groups.\"\"\"\nfrom abc import ABCMeta\nfrom abc import abstractmethod\nimport logging\nimport datetime\nimport time\n\nimport ray\n\nfrom alpa.collective.const import get_store_name\nfrom alpa.collective.types import (AllReduceOptions, BarrierOptions,\n                                   ReduceOptions, AllGatherOptions,\n                                   BroadcastOptions, ReduceScatterOptions)\n\nlogger = logging.getLogger(__name__)\n\n\nclass Rendezvous:\n    \"\"\"A rendezvous class for different actor/task processes to meet.\n\n    To initialize an NCCL collective communication group, different\n    actors/tasks spawned in Ray in a collective group needs to meet\n    each other to synchronize the NCCLUniqueID. This class guarantees\n    they meet via the NCCLUniqueIDStore, initialized on the rank=0\n    process.\n\n    Args:\n        store_key (str): the unique store key, usually as a concatanation\n            of group_name and communicator key. See `get_nccl_communicator`\n            for more details.\n    \"\"\"\n\n    def __init__(self, store_key):\n        if not store_key:\n            raise ValueError(\n                \"Invalid store_key. The store_key is a concatenation of \"\n                \"'group_name' and the 'communicator_key'. See the \"\n                \"docstring of `get_nccl_communicator` for details.\")\n        self._store_key = store_key\n        self._store_name = None\n        self._store = None\n\n    def meet(self, timeout_s=180):\n        \"\"\"Meet at the named actor store.\n\n        Args:\n            timeout_s (int): timeout in seconds.\n\n        Return:\n            None\n        \"\"\"\n        if timeout_s <= 0:\n            raise ValueError(\"The 'timeout' argument must be positive. 
\"\n                             f\"Got '{timeout_s}'.\")\n        self._store_name = get_store_name(self._store_key)\n        timeout_delta = datetime.timedelta(seconds=timeout_s)\n        elapsed = datetime.timedelta(seconds=0)\n        start_time = datetime.datetime.now()\n        while elapsed < timeout_delta:\n            try:\n                logger.debug(\n                    f\"Trying to meet at the store '{self._store_name}'\")\n                self._store = ray.get_actor(self._store_name)\n            except ValueError:\n                logger.debug(\n                    f\"Failed to meet at the store '{self._store_name}'. \"\n                    \"Trying again...\")\n                time.sleep(1)\n                elapsed = datetime.datetime.now() - start_time\n                continue\n            logger.debug(\"Successful rendezvous!\")\n            break\n        if not self._store:\n            raise RuntimeError(\"Unable to meet other processes \"\n                               \"at the rendezvous store. If you are using \"\n                               \"P2P communication, please check if tensors \"\n                               \"are put in the correct GPU. 
\")\n\n    @property\n    def store(self):\n        return self._store\n\n    def get_nccl_id(self, timeout_s=180):\n        \"\"\"Get the NCCLUniqueID from the store through Ray.\n\n        Args:\n            timeout_s: timeout in seconds.\n\n        Return:\n            uid (str): the NCCLUniqueID if successful.\n        \"\"\"\n        if not self._store:\n            raise ValueError(\"Rendezvous store is not setup.\")\n        uid = None\n        timeout_delta = datetime.timedelta(seconds=timeout_s)\n        elapsed = datetime.timedelta(seconds=0)\n        start_time = datetime.datetime.now()\n        while elapsed < timeout_delta:\n            uid = ray.get(self._store.get_id.remote())\n            if not uid:\n                time.sleep(1)\n                elapsed = datetime.datetime.now() - start_time\n                continue\n            break\n        if not uid:\n            raise RuntimeError(\"Unable to get the NCCLUniqueID from the store.\")\n        return uid\n\n    def get_access_counter(self):\n        \"\"\"Return how many times the NCCLUniqueID has been accessed.\"\"\"\n        return ray.get(self._store.get_access_counter.remote())\n\n    def destroy_store(self):\n        \"\"\"Delete the named actor.\"\"\"\n        self._store = None\n\n\nclass BaseGroup(metaclass=ABCMeta):\n    \"\"\"Abstract class for collective groups.\"\"\"\n\n    def __init__(self, world_size, rank, group_name):\n        \"\"\"Init the process group with basic information.\n\n        Args:\n            world_size (int): The total number of processes in the group.\n            rank (int): The rank of the current process.\n            group_name (str): The group name.\n        \"\"\"\n        self._world_size = world_size\n        self._rank = rank\n        self._group_name = group_name\n\n    @property\n    def rank(self):\n        \"\"\"Return the rank of the current process.\"\"\"\n        return self._rank\n\n    @property\n    def world_size(self):\n        
\"\"\"Return the number of processes in this group.\"\"\"\n        return self._world_size\n\n    @property\n    def group_name(self):\n        \"\"\"Return the group name of this group.\"\"\"\n        return self._group_name\n\n    @classmethod\n    def backend(cls):\n        \"\"\"The backend of this collective group.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def barrier(self, barrier_options=BarrierOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def reduce(self, tensors, reduce_options=ReduceOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def allgather(self,\n                  tensor_lists,\n                  tensors,\n                  allgather_options=AllGatherOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def reducescatter(self,\n                      tensors,\n                      tensor_lists,\n                      reducescatter_options=ReduceScatterOptions()):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def send(self, tensors, send_options):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def recv(self, tensors, recv_options):\n        raise NotImplementedError()\n"
  },
  {
    "path": "alpa/collective/collective_group/cuda_stream.py",
    "content": "\"\"\"CUDA stream pool.\"\"\"\nimport logging\nimport threading\n\nimport cupy\nfrom alpa.collective.collective_group import nccl_util\nfrom alpa.collective.const import ENV\n\nNCCL_STREAM_POOL_SIZE = 32\nMAX_GPU_PER_ACTOR = 16\n\nlogger = logging.getLogger(__name__)\n\n\nclass StreamPool:\n    \"\"\"The class that represents a stream pool associated with a GPU.\n\n    When multistream is enabled, we will allocate a pool of streams for each\n    GPU, and get available stream from this pool when a collective kernel is\n    initialized. This enables overlapping computation/communication kernels\n    using multiple CUDA streams, given that the streams a appropriately\n    synchronized. The class is thread-safe.\n\n\n    Args:\n        device_idx (int): the absolute index of the device for this pool.\n    \"\"\"\n\n    def __init__(self, device_idx):\n        self.device_idx = device_idx\n\n        self._initialized = False\n        self._initialized_lock = threading.Lock()\n\n        self._pool = [None] * NCCL_STREAM_POOL_SIZE\n        self._counter = 0\n        self._pool_lock = threading.Lock()\n        self._init_flag = False\n\n    def get_stream(self):\n        \"\"\"Get an available stream from the pool.\n\n        The function locks the stream pool and releases the lock before\n        returning.\n\n        Returns:\n            stream (cupy.cuda.Stream): the returned stream from pool.\n        \"\"\"\n\n        # check the flag\n        with self._initialized_lock:\n            if not self._initialized:\n                self._init_once()\n\n        # Get the stream from the pool.\n        with self._pool_lock:\n            stream = self._pool[self._counter]\n            self._counter = (self._counter + 1) % NCCL_STREAM_POOL_SIZE\n        return stream\n\n    def _init_once(self):\n        \"\"\"Initialize the stream pool only for once.\"\"\"\n        with nccl_util.Device(self.device_idx):\n            for i in range(NCCL_STREAM_POOL_SIZE):\n   
             # this is the only place where self._pool will be written.\n                if ENV.NCCL_USE_MULTISTREAM.val:\n                    logger.debug(\"NCCL multistream enabled.\")\n                    self._pool[i] = cupy.cuda.Stream(null=False,\n                                                     non_blocking=False)\n                else:\n                    logger.debug(\"NCCL multistream disabled.\")\n                    self._pool[i] = cupy.cuda.Stream.null\n        self._init_flag = True\n\n\n# This is a map from GPU index to its stream pool.\n# It is supposed to be READ-ONLY outside of this file.\n_device_stream_pool_map = {}\n\n\ndef _init_stream_pool():\n    for i in range(MAX_GPU_PER_ACTOR):\n        _device_stream_pool_map[i] = StreamPool(i)\n\n\ndef get_stream_pool(device_idx):\n    \"\"\"Get the CUDA stream pool of a GPU device.\"\"\"\n    # In case there will be multiple threads writing to the pool.\n    lock = threading.Lock()\n    with lock:\n        if not _device_stream_pool_map:\n            _init_stream_pool()\n    return _device_stream_pool_map[device_idx]\n"
  },
  {
    "path": "alpa/collective/collective_group/gloo_collective_group.py",
    "content": "\"\"\"Gloo-based collective operations.\"\"\"\nimport logging\nimport datetime\nimport time\nimport os\nimport shutil\n\nimport numpy\nimport ray\nfrom ray import ray_constants\nimport pygloo\n\nfrom alpa.collective.collective_group import gloo_util\nfrom alpa.collective.collective_group.base_collective_group import BaseGroup\nfrom alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend,\n                                   ReduceOptions, BroadcastOptions,\n                                   AllGatherOptions, ReduceScatterOptions,\n                                   SendOptions, RecvOptions)\nfrom alpa.collective.const import get_store_name\nfrom alpa.util import try_import_ray_worker\n\nray_worker = try_import_ray_worker()\n\nlogger = logging.getLogger(__name__)\n\n\nclass Rendezvous:\n    \"\"\"A rendezvous class for different actor/task processes to meet.\n\n    To initialize an GLOO collective communication group, different\n    actors/tasks spawned in Ray in a collective group needs to meet\n    each other to synchronize the GLOOUniqueID. 
This class guarantees\n    they meet via the GLOOUniqueIDStore, initialized on the rank=0\n    process.\n\n    Args:\n        group_name (str): the unique user-specified group name.\n    \"\"\"\n\n    def __init__(self, group_name, context, store_type, device_type):\n        self._group_name = group_name\n        self._context = context\n        self._redis_ip_address, self._redis_port = (\n            ray_worker._global_node.redis_address.split(\":\"))\n        self._process_ip_address = (ray.util.get_node_ip_address())\n        logger.debug(f\"Redis address: {self._redis_ip_address}, \"\n                     f\"port: {self._redis_port}, \"\n                     f\"this actor address: {self._process_ip_address}.\")\n        self._store_type = store_type\n        self._device_type = device_type\n        self._store = None\n        self._device = None\n        self.create_store(store_type)\n        self.create_device(device_type)\n\n    def create_store(self, store_type):\n        if store_type == \"redis\":\n            redis_store = pygloo.rendezvous.RedisStore(self._redis_ip_address,\n                                                       int(self._redis_port))\n            redis_password = ray_constants.REDIS_DEFAULT_PASSWORD\n            redis_store.authorize(redis_password)\n            self._store = redis_store\n        elif store_type == \"file\":\n            store_name = get_store_name(self._group_name)\n            store_path = gloo_util.get_gloo_store_path(store_name)\n            if self._context.rank == 0:\n                if not os.path.exists(store_path):\n                    os.makedirs(store_path)\n                elif os.listdir(store_path) and os.listdir(store_path):\n                    shutil.rmtree(store_path)\n                    os.makedirs(store_path)\n            else:\n                while not os.path.exists(store_path):\n                    time.sleep(0.1)\n            # Note: multi-machines needs a shared NFS.\n            file_store = 
pygloo.rendezvous.FileStore(store_path)\n            self._store = pygloo.rendezvous.PrefixStore(self._group_name,\n                                                        file_store)\n        elif store_type == \"hash\":\n            raise NotImplementedError(\"No implementation for hash store.\")\n        else:\n            raise RuntimeError(f\"Unrecognized store type: {store_type}.\")\n\n    def create_device(self, device_type):\n        if device_type == \"tcp\":\n            attr = pygloo.transport.tcp.attr(self._process_ip_address)\n            self._device = pygloo.transport.tcp.CreateDevice(attr)\n        elif device_type == \"uv\":\n            raise NotImplementedError(\"No implementation for uv.\")\n\n    def meet(self, timeout_s=180):\n        \"\"\"Meet at the named actor store.\n\n        Args:\n            timeout_s (int): timeout in seconds.\n\n        Return:\n            None\n        \"\"\"\n        if timeout_s <= 0:\n            raise ValueError(\"The 'timeout' argument must be positive. 
\"\n                             f\"Got '{timeout_s}'.\")\n\n        timeout_delta = datetime.timedelta(seconds=timeout_s)\n        elapsed = datetime.timedelta(seconds=0)\n        start_time = datetime.datetime.now()\n        q, s = None, None\n\n        if self._store_type == \"redis\":\n            while elapsed < timeout_delta:\n                try:\n                    q = ray.get_actor(\"gloo_queue\")\n                    s = ray.get_actor(f\"gloo_{self._group_name}_signal\")\n                    break\n                except ValueError:\n                    if self._context.rank == 0:\n                        if not q:\n                            ray.remote(gloo_util.GlooQueue).options(\n                                name=\"gloo_queue\",\n                                lifetime=\"detached\").remote(1000)\n                        if not s:\n                            gloo_util.SignalActor.options(\n                                name=f\"gloo_{self._group_name}_signal\",\n                                lifetime=\"detached\").remote(self._context.size)\n                    else:\n                        time.sleep(0.1)\n                elapsed = datetime.datetime.now() - start_time\n            if not q:\n                raise RuntimeError(\"Unable to get gloo_queue.\")\n            if self._context.rank == 0:\n                ray.get(q.put_nowait.remote(self._group_name))\n            while ray.get(q.index.remote(self._group_name)):\n                time.sleep(0.1)\n            self._context.connectFullMesh(self._store, self._device)\n            ray.get(s.send.remote(self._context.rank))\n            if self._context.rank == 0:\n                ray.get(s.wait.remote())\n                keys = []\n                keys += [f\"rank_{i}\" for i in range(self._context.size)]\n                keys += [f\"{i}\" for i in range(self._context.size)]\n                self._store.delKeys(keys)\n                group_name = ray.get(q.get_nowait.remote())\n          
      assert group_name == self._group_name\n                ray.kill(s)\n\n    @property\n    def store_type(self):\n        return self._store_type\n\n    @property\n    def store(self):\n        return self._store\n\n    @property\n    def device_type(self):\n        return self._device_type\n\n    @property\n    def device(self):\n        return self._device\n\n    def destroy(self):\n        \"\"\"GC the store and device used by this rendezvous.\"\"\"\n        self._device = None\n\n\nclass GLOOGroup(BaseGroup):\n    \"\"\"Gloo-based collective operations.\"\"\"\n\n    def __init__(self,\n                 world_size,\n                 rank,\n                 group_name,\n                 store_type=\"redis\",\n                 device_type=\"tcp\"):\n        \"\"\"Init a GLOO collective group.\n\n        Args:\n            world_size (int): The number of processes.\n            rank (int): The id of process\n            group_name (str): The unique user-specified group name.\n            store_type (str): The store type. 
Optional: \"redis\",\n                              \"file\", \"hash\".\n            device_type (str): The device type to transport.\n                               Optional: \"tcp\", \"uv\".\n        \"\"\"\n        super().__init__(world_size, rank, group_name)\n        self._gloo_context = gloo_util.create_gloo_context(\n            self.rank, self.world_size)\n        self._rendezvous = Rendezvous(self.group_name, self._gloo_context,\n                                      store_type, device_type)\n        self._rendezvous.meet()\n\n    def destroy_group(self):\n        \"\"\"Destroy the group and release GLOO communicators.\"\"\"\n        self._rendezvous.destroy()\n\n        if self._gloo_context is not None:\n            pygloo.barrier(self._gloo_context)\n            # destroy the communicator\n            self._gloo_context = None\n\n        if self.rank == 0 and self._rendezvous.store_type == \"file\":\n            store_name = get_store_name(self._group_name)\n            store_path = gloo_util.get_gloo_store_path(store_name)\n            if os.path.exists(store_path):\n                shutil.rmtree(store_path)\n\n    @classmethod\n    def backend(cls):\n        return Backend.GLOO\n\n    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):\n        \"\"\"AllReduce a list of tensors following options.\n\n        Args:\n            tensor: the tensor to be reduced, each tensor locates on CPU\n            allreduce_options:\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, context):\n            pygloo.allreduce(\n                context, gloo_util.get_tensor_ptr(input_tensor),\n                gloo_util.get_tensor_ptr(output_tensor),\n                gloo_util.get_tensor_n_elements(input_tensor),\n                gloo_util.get_gloo_tensor_dtype(input_tensor),\n                gloo_util.get_gloo_reduce_op(allreduce_options.reduce_op))\n\n        self._collective(tensors, tensors, 
collective_fn)\n\n    def barrier(self, barrier_options=BarrierOptions()):\n        \"\"\"Blocks until all processes reach this barrier.\n\n        Args:\n            barrier_options: barrier options.\n\n        Returns:\n            None\n        \"\"\"\n        barrier_tensor = numpy.array([1])\n        self.allreduce([barrier_tensor])\n\n    def reduce(self, tensors, reduce_options=ReduceOptions()):\n        \"\"\"Reduce tensors following options.\n\n        Args:\n            tensors (List): the list of tensors to be reduced,\n                            this list only have one tensor.\n            reduce_options: reduce options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = reduce_options.root_rank\n\n        def collective_fn(input_tensor, output_tensor, context):\n            pygloo.reduce(\n                context, gloo_util.get_tensor_ptr(input_tensor),\n                gloo_util.get_tensor_ptr(output_tensor),\n                gloo_util.get_tensor_n_elements(input_tensor),\n                gloo_util.get_gloo_tensor_dtype(input_tensor),\n                gloo_util.get_gloo_reduce_op(reduce_options.reduce_op),\n                root_rank)\n\n        self._collective(tensors, tensors, collective_fn)\n\n    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):\n        \"\"\"Broadcast tensors to all other processes following options.\n\n        Args:\n            tensors (List): tensors to be broadcast or received.\n            broadcast_options: broadcast options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = broadcast_options.root_rank\n\n        def collective_fn(input_tensor, output_tensor, context):\n            pygloo.broadcast(context, gloo_util.get_tensor_ptr(input_tensor),\n                             gloo_util.get_tensor_ptr(output_tensor),\n                             gloo_util.get_tensor_n_elements(input_tensor),\n                             
gloo_util.get_gloo_tensor_dtype(input_tensor),\n                             root_rank)\n\n        self._collective(tensors, tensors, collective_fn)\n\n    def allgather(self,\n                  tensor_lists,\n                  tensors,\n                  allgather_options=AllGatherOptions()):\n        \"\"\"Allgather tensors on CPU into a list of tensors.\n\n        Args:\n            tensor_lists (List[List[Tensor]]): allgathered tensors.\n            tensors: the list of tensors to allgather across the group.\n                     Each tensor must be located on CPU.\n            allgather_options: allgather options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, context):\n            pygloo.allgather(context, gloo_util.get_tensor_ptr(input_tensor),\n                             gloo_util.get_tensor_ptr(output_tensor),\n                             gloo_util.get_tensor_n_elements(input_tensor),\n                             gloo_util.get_gloo_tensor_dtype(input_tensor))\n\n        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)\n        output_flattened = [\n            _flatten_for_scatter_gather(tensor_list, copy=False)\n            for tensor_list in tensor_lists\n        ]\n\n        def postprocess_fn():\n            for i, tensor_list in enumerate(tensor_lists):\n                for j, tensor in enumerate(tensor_list):\n                    gloo_util.copy_tensor(tensor, output_flattened[i][j])\n\n        self._collective(tensors,\n                         output_flattened,\n                         collective_fn,\n                         postprocess_fn=postprocess_fn)\n\n    def reducescatter(self,\n                      tensors,\n                      tensor_lists,\n                      reducescatter_options=ReduceScatterOptions()):\n        \"\"\"Reduce then scatter a list of tensors across the group.\n\n        Args:\n            tensors (List): the output tensors (could 
be unspecified), each\n                            located on CPU.\n            tensor_lists (List[List]): the list of tensors to be reduced then\n                                       scattered.\n            reducescatter_options: reduce-scatter options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, context):\n            size = gloo_util.get_tensor_n_elements(input_tensor)\n            world_size = self._gloo_context.size\n            pygloo.reduce_scatter(\n                context, gloo_util.get_tensor_ptr(input_tensor),\n                gloo_util.get_tensor_ptr(output_tensor), size,\n                [size // world_size for _ in range(world_size)],\n                gloo_util.get_gloo_tensor_dtype(output_tensor),\n                gloo_util.get_gloo_reduce_op(reducescatter_options.reduce_op))\n\n        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)\n        input_flattened = [\n            _flatten_for_scatter_gather(tensor_list, copy=False)\n            for tensor_list in tensor_lists\n        ]\n\n        def preprocess_fn():\n            for i, tensor_list in enumerate(tensor_lists):\n                for j, tensor in enumerate(tensor_list):\n                    gloo_util.copy_tensor(input_flattened[i][j], tensor)\n\n        self._collective(input_flattened,\n                         tensors,\n                         collective_fn,\n                         preprocess_fn=preprocess_fn)\n\n    def send(self, tensors, send_options=SendOptions()):\n        \"\"\"Send a tensor to a destination rank in the group.\n\n        Args:\n            tensors (List): the tensor to send.\n            send_options: send options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def p2p_fn(tensor, context, peer):\n            pygloo.send(context, gloo_util.get_tensor_ptr(tensor),\n                        gloo_util.get_tensor_n_elements(tensor),\n                        
gloo_util.get_gloo_tensor_dtype(tensor), peer)\n\n        self._point2point(tensors, p2p_fn, send_options.dst_rank)\n\n    def recv(self, tensors, recv_options=RecvOptions()):\n        \"\"\"Receive a tensor from a source rank in the group.\n\n        Args:\n            tensors (List): the received tensor.\n            recv_options: Receive options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def p2p_fn(tensor, context, peer):\n            pygloo.recv(context, gloo_util.get_tensor_ptr(tensor),\n                        gloo_util.get_tensor_n_elements(tensor),\n                        gloo_util.get_gloo_tensor_dtype(tensor), peer)\n\n        self._point2point(tensors, p2p_fn, recv_options.src_rank)\n\n    def _collective(self,\n                    input_tensors,\n                    output_tensors,\n                    collective_fn,\n                    preprocess_fn=None,\n                    postprocess_fn=None):\n        \"\"\"A method to encapsulate all collective calls.\n\n        Args:\n            input_tensors: the list of the input tensors.\n            output_tensors: the list of the output tensors.\n            collective_fn: the collective function call.\n            preprocess_fn: preprocess procedures before collective calls.\n            postprocess_fn: postprocess procedures after collective calls.\n\n        Returns:\n            None\n        \"\"\"\n        _check_cpu_tensors(input_tensors)\n        _check_cpu_tensors(output_tensors)\n\n        if preprocess_fn:\n            preprocess_fn()\n        collective_fn(input_tensors[0], output_tensors[0], self._gloo_context)\n        if postprocess_fn:\n            postprocess_fn()\n\n    def _point2point(self, tensors, p2p_fn, peer_rank: int):\n        \"\"\"A method to encapsulate all peer-to-peer calls (i.e., send/recv).\n\n        Args:\n            tensors: the tensor to send or receive.\n            p2p_fn: the p2p function call.\n            peer_rank (int): the rank of the 
peer process.\n\n        Returns:\n            None\n        \"\"\"\n        _check_cpu_tensors(tensors)\n\n        p2p_fn(tensors[0], self._gloo_context, peer_rank)\n\n\ndef _check_cpu_tensors(tensors):\n    \"\"\"Check that there is only one tensor and that it is located on CPU.\"\"\"\n    if not tensors or not isinstance(tensors, list):\n        raise RuntimeError(\"'tensors' must be a nonempty list.\")\n    if len(tensors) != 1:\n        raise RuntimeError(\"Gloo only accept one tensor in the tensor list.\"\n                           f\" Got {len(tensors)} != 1.\")\n    d = gloo_util.get_tensor_device(tensors[0])\n    if d != \"cpu\":\n        raise RuntimeError(\"Gloo only accept cpu tensor.\"\n                           f\" Got {d}.\")\n\n\ndef _flatten_for_scatter_gather(tensor_list, copy=False):\n    \"\"\"Flatten the tensor for gather/scatter operations.\n\n    Args:\n        tensor_list: the list of tensors to be scattered/gathered.\n        copy: whether to copy the tensors in tensor_list into the buffer.\n\n    Returns:\n        The flattened tensor buffer.\n    \"\"\"\n    if not tensor_list:\n        raise RuntimeError(\"Received an empty list.\")\n\n    t = tensor_list[0]\n    # note we need a numpy dtype here.\n    dtype = gloo_util.get_numpy_tensor_dtype(t)\n    buffer_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(t)\n\n    buffer = numpy.empty(buffer_shape, dtype=dtype)\n    if copy:\n        for i, tensor in enumerate(tensor_list):\n            gloo_util.copy_tensor(buffer[i], tensor)\n    return buffer\n\n\ndef _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):\n    \"\"\"Check the compatibility between tensor input and tensor list input.\"\"\"\n    if not tensors or not isinstance(tensors, list):\n        raise RuntimeError(\n            \"The first argument 'tensors' expects a list of tensors.\")\n\n    if len(tensors) != 1:\n        raise RuntimeError(\n            \"Gloo only accept one tensor in the first argument 'tensors'.\"\n 
           f\" Got {len(tensors)} != 1.\")\n\n    if not tensor_lists or not isinstance(tensor_lists, list):\n        raise RuntimeError(\"The second argument 'tensor_lists' \"\n                           \"expects a list of tensor list.\")\n\n    if len(tensor_lists) != 1:\n        raise RuntimeError(\"Gloo only accept one tensor list \"\n                           \"in the second argument 'tensor_lists'.\"\n                           f\" Got {len(tensor_lists)} != 1.\")\n\n    dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])\n    shape = gloo_util.get_tensor_shape(tensors[0])\n\n    # check all tensors in `tensor_lists` match.\n    for t in tensor_lists[0]:\n        # check dtype\n        dt = gloo_util.get_gloo_tensor_dtype(t)\n        if dt != dtype:\n            raise RuntimeError(\n                \"All tensor operands to scatter/gather must \"\n                f\"have the same dtype. Got '{dt}' and '{dtype}'.\")\n        s = gloo_util.get_tensor_shape(t)\n        if s != shape:\n            raise RuntimeError(\"All tensor operands to scatter/gather must \"\n                               f\"have the same shape. Got '{s}' and '{shape}'.\")\n"
  },
  {
    "path": "alpa/collective/collective_group/gloo_util.py",
    "content": "\"\"\"Code to wrap some GLOO API calls.\"\"\"\nimport asyncio\nimport numpy\ntry:\n    import pygloo\nexcept ImportError as ie:\n    raise ImportError(\n        \"Can not import pygloo.\"\n        \"Please run 'pip install pygloo' to install pygloo.\") from ie\n\nimport ray\nfrom ray.util.queue import _QueueActor\nfrom alpa.collective.types import ReduceOp, torch_available\n\nGLOO_REDUCE_OP_MAP = {\n    ReduceOp.SUM: pygloo.ReduceOp.SUM,\n    ReduceOp.PRODUCT: pygloo.ReduceOp.PRODUCT,\n    ReduceOp.MIN: pygloo.ReduceOp.MIN,\n    ReduceOp.MAX: pygloo.ReduceOp.MAX,\n}\n\nNUMPY_GLOO_DTYPE_MAP = {\n    # INT types\n    numpy.uint8: pygloo.glooDataType_t.glooUint8,\n    numpy.uint32: pygloo.glooDataType_t.glooUint32,\n    numpy.uint64: pygloo.glooDataType_t.glooUint64,\n    numpy.int8: pygloo.glooDataType_t.glooInt8,\n    numpy.int32: pygloo.glooDataType_t.glooInt32,\n    numpy.int64: pygloo.glooDataType_t.glooInt64,\n    # FLOAT types\n    numpy.half: pygloo.glooDataType_t.glooFloat16,\n    numpy.float16: pygloo.glooDataType_t.glooFloat16,\n    numpy.float32: pygloo.glooDataType_t.glooFloat32,\n    numpy.float64: pygloo.glooDataType_t.glooFloat64,\n    numpy.double: pygloo.glooDataType_t.glooFloat64,\n}\n\nif torch_available():\n    import torch\n    TORCH_GLOO_DTYPE_MAP = {\n        torch.int: pygloo.glooDataType_t.glooInt32,\n        torch.uint8: pygloo.glooDataType_t.glooUint8,\n        torch.int8: pygloo.glooDataType_t.glooInt8,\n        torch.int32: pygloo.glooDataType_t.glooInt32,\n        torch.int64: pygloo.glooDataType_t.glooInt64,\n        torch.long: pygloo.glooDataType_t.glooInt64,\n        # FLOAT types\n        torch.half: pygloo.glooDataType_t.glooFloat16,\n        torch.float: pygloo.glooDataType_t.glooFloat32,\n        torch.float16: pygloo.glooDataType_t.glooFloat16,\n        torch.float32: pygloo.glooDataType_t.glooFloat32,\n        torch.float64: pygloo.glooDataType_t.glooFloat64,\n        torch.double: 
pygloo.glooDataType_t.glooFloat64,\n    }\n\n    TORCH_NUMPY_DTYPE_MAP = {\n        # INT types\n        torch.int: numpy.int32,\n        torch.uint8: numpy.uint8,\n        torch.int8: numpy.int8,\n        torch.int32: numpy.int32,\n        torch.int64: numpy.int64,\n        torch.long: numpy.int64,\n        # FLOAT types\n        torch.half: numpy.half,\n        torch.float: numpy.float32,\n        torch.float16: numpy.float16,\n        torch.float32: numpy.float32,\n        torch.float64: numpy.float64,\n    }\n\n\ndef create_gloo_context(rank, world_size):\n    \"\"\"Create a GLOO context using GLOO APIs.\n\n    Args:\n        rank (int): the rank of this process.\n        world_size (int): the number of processes of this collective group.\n\n    Returns:\n        context (pygloo.Context): a GLOO context.\n    \"\"\"\n    context = pygloo.rendezvous.Context(rank, world_size)\n    return context\n\n\ndef get_gloo_reduce_op(reduce_op):\n    \"\"\"Map the reduce op to GLOO reduce op type.\n\n    Args:\n        reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX).\n\n    Returns:\n        (pygloo.ReduceOp): the mapped GLOO reduce op.\n    \"\"\"\n    if reduce_op not in GLOO_REDUCE_OP_MAP:\n        raise RuntimeError(f\"Gloo does not support reduce op: '{reduce_op}'.\")\n    return GLOO_REDUCE_OP_MAP[reduce_op]\n\n\ndef get_gloo_tensor_dtype(tensor):\n    \"\"\"Return the corresponded GLOO dtype given a tensor.\"\"\"\n    if isinstance(tensor, numpy.ndarray):\n        return NUMPY_GLOO_DTYPE_MAP[tensor.dtype.type]\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            if not tensor.is_cuda:\n                return TORCH_GLOO_DTYPE_MAP[tensor.dtype]\n            else:\n                raise ValueError(\"Expect torch CPU tensor. \"\n                                 f\"Got {tensor.device}.\")\n    raise ValueError(\"Unsupported tensor type. 
\"\n                     f\"Got: {type(tensor)}.\")\n\n\ndef get_numpy_tensor_dtype(tensor):\n    \"\"\"Return the corresponded Cupy dtype given a tensor.\"\"\"\n    if isinstance(tensor, numpy.ndarray):\n        return tensor.dtype.type\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return TORCH_NUMPY_DTYPE_MAP[tensor.dtype]\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported CPU tensor types are: torch.Tensor, \"\n                     \"numpy.ndarray.\")\n\n\ndef get_tensor_ptr(tensor):\n    \"\"\"Return the pointer to the underlying memory storage of a tensor.\"\"\"\n    if isinstance(tensor, numpy.ndarray):\n        return tensor.ctypes.data\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            if tensor.is_cuda:\n                raise RuntimeError(\"Torch tensor must be on CPU \"\n                                   \"when using GLOO collectives.\")\n            return tensor.data_ptr()\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported CPU tensor types are: torch.Tensor, \"\n                     \"numpy.ndarray.\")\n\n\ndef get_tensor_n_elements(tensor):\n    \"\"\"Return the number of elements in a tensor.\"\"\"\n    if isinstance(tensor, numpy.ndarray):\n        return tensor.size\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return torch.numel(tensor)\n    raise ValueError(\"Unsupported tensor type. 
\"\n                     f\"Got: {type(tensor)}.\")\n\n\ndef get_gloo_store_path(store_name):\n    from ray._private.utils import get_ray_temp_dir  # pylint: disable=import-outside-toplevel\n    store_path = f\"{get_ray_temp_dir()}_collective/gloo/{store_name}\"\n    return store_path\n\n\ndef get_tensor_device(tensor):\n    if isinstance(tensor, numpy.ndarray):\n        return \"cpu\"\n    elif torch_available() and isinstance(tensor, torch.Tensor):\n        if not tensor.is_cuda:\n            return \"cpu\"\n        else:\n            return \"cuda\"\n    else:\n        raise RuntimeError(\"Unrecognized tensor type: \"\n                           f\"'{type(tensor)}'.\")\n\n\ndef get_tensor_shape(tensor):\n    \"\"\"Return the shape of the tensor as a list.\"\"\"\n    if isinstance(tensor, numpy.ndarray):\n        return list(tensor.shape)\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return list(tensor.size())\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. 
\"\n                     \"Supported CPU tensor types are: torch.Tensor, \"\n                     \"numpy.ndarray.\")\n\n\ndef copy_tensor(dst_tensor, src_tensor):\n    \"\"\"Copy the content from src_tensor to dst_tensor.\n\n    Args:\n        dst_tensor: the tensor to copy from.\n        src_tensor: the tensor to copy to.\n\n    Returns:\n        None\n    \"\"\"\n    copied = True\n    if (isinstance(dst_tensor, numpy.ndarray) and\n            isinstance(src_tensor, numpy.ndarray)):\n        numpy.copyto(dst_tensor, src_tensor)\n    elif torch_available():\n        if isinstance(dst_tensor, torch.Tensor) and isinstance(\n                src_tensor, torch.Tensor):\n            dst_tensor.copy_(src_tensor)\n        elif isinstance(dst_tensor, torch.Tensor) and isinstance(\n                src_tensor, numpy.ndarray):\n            t = torch.Tensor(src_tensor)\n            dst_tensor.copy_(t)\n        elif isinstance(dst_tensor, numpy.ndarray) and isinstance(\n                src_tensor, torch.Tensor):\n            t = src_tensor.numpy()\n            numpy.copyto(dst_tensor, t)\n        else:\n            copied = False\n    else:\n        copied = False\n    if not copied:\n        raise ValueError(\n            f\"Unsupported tensor type. Got: {type(dst_tensor)} and \"\n            f\"{type(src_tensor)}. 
Supported CPU tensor types are: \"\n            f\"torch.Tensor, numpy.ndarray.\")\n\n\n# Note(Hao): this requires Ray >= 1.2.0,\n# otherwise _QueueActor is an actor class.\nclass GlooQueue(_QueueActor):\n\n    def index(self, group_name):\n        try:\n            return self.queue._queue.index(group_name)  # pylint: disable=protected-access\n        except ValueError:\n            return -1\n\n\n@ray.remote(num_cpus=0)\nclass SignalActor:\n    \"\"\"An actor that can be used for sending signals.\"\"\"\n\n    def __init__(self, world_size):\n        self.ready_events = [asyncio.Event() for _ in range(world_size)]\n        self.world_size = world_size\n\n    def send(self, rank, clear=False):\n        self.ready_events[rank].set()\n        if clear:\n            self.ready_events[rank].clear()\n\n    async def wait(self, should_wait=True):\n        if should_wait:\n            for i in range(self.world_size):\n                await self.ready_events[i].wait()\n"
  },
  {
    "path": "alpa/collective/collective_group/nccl_collective_group.py",
    "content": "\"\"\"NCCL-based collective operations.\"\"\"\nimport logging\n\nimport ray\nimport cupy\nfrom jax._src.lib import xla_extension as xe\n\nfrom alpa.collective.const import ENV\nfrom alpa.collective.collective_group import nccl_util\nfrom alpa.collective.collective_group.base_collective_group import (BaseGroup,\n                                                                    Rendezvous)\nfrom alpa.collective.const import get_store_name\nfrom alpa.collective.types import (AllReduceOptions, BarrierOptions, Backend,\n                                   ReduceOptions, BroadcastOptions,\n                                   AllGatherOptions, ReduceScatterOptions,\n                                   SendOptions, RecvOptions)\nfrom alpa.collective.collective_group.cuda_stream import get_stream_pool\nfrom alpa.monkey_patch import override_get_backend\n\nlogger = logging.getLogger(__name__)\n\n\n# FIXME: should not assume that each worker has the same number of devices\nclass NCCLGroup(BaseGroup):\n    \"\"\"NCCL-based collective operations.\"\"\"\n\n    def __init__(self, world_size, rank, group_name):\n        \"\"\"Init an NCCL collective group.\"\"\"\n        super().__init__(world_size, rank, group_name)\n\n        # communicator and stream cache.\n        # TODO (Hao): we need a lock here...\n        self._barrier_tensor = None\n        self._dev_comm_map = {}\n        self._dev_streams_map = {}\n        self._xla_comm_keys = set()\n\n        # record the used GPU IDs.\n        self._used_gpu_indices = set()\n\n        # TODO(Fu): might need an event map\n        self._dev_event_map = {}\n        # This is only for cross-mesh all-reduce to use\n        backend = override_get_backend()\n        self.xla_comm_group = xe.CommGroup(backend)\n\n        if nccl_util.get_nccl_build_version() < 2000:\n            raise RuntimeError(\"NCCL in Ray requires NCCL >= 2.0.\")\n        if nccl_util.get_nccl_runtime_version() < 2704:\n            logger.warning(\"NCCL 
send/recv calls requires NCCL>=2.7.4\")\n\n    def destroy_group(self):\n        \"\"\"Destroy the group and release NCCL communicators.\"\"\"\n        if len(self._dev_comm_map.keys()) > 0:\n\n            # TODO(Hao): check this barrier call\n            # self.barrier()\n\n            # Destroy the communicators and streams.\n            for comm_key, comms in self._dev_comm_map.items():\n                for c in comms:\n                    # FIXME(yonghao): comms created in XLA should be destroyed\n                    if hasattr(c, \"destroy\"):\n                        c.destroy()\n                self._dev_comm_map[comm_key] = None\n\n        if self.rank == 0:\n            for comm_key in self._dev_comm_map:\n                assert not self._dev_comm_map[comm_key]\n                group_key = self._generate_group_key(comm_key)\n                self._destroy_store(group_key)\n        self._barrier_tensor = None\n        self._dev_comm_map = None\n        self._dev_streams_map = None\n\n    @classmethod\n    def backend(cls):\n        return Backend.NCCL\n\n    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):\n        \"\"\"AllReduce tensors across the collective group following options.\n\n        Args:\n            tensors (List): the list of tensors to be reduced. 
Each tensor must\n                            reside on one GPU of the current process.\n            allreduce_options: allreduce options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, comm, stream):\n            comm.allReduce(\n                nccl_util.get_tensor_ptr(input_tensor),\n                nccl_util.get_tensor_ptr(output_tensor),\n                nccl_util.get_tensor_n_elements(input_tensor),\n                nccl_util.get_nccl_tensor_dtype(input_tensor),\n                nccl_util.get_nccl_reduce_op(allreduce_options.reduce_op),\n                stream.ptr)\n\n        self._collective(tensors, tensors, collective_fn)\n\n    def barrier(self, barrier_options=BarrierOptions()):\n        \"\"\"Blocks until all processes reach this barrier.\n\n        Args:\n            barrier_options: barrier options.\n\n        Returns:\n            None\n        \"\"\"\n        # Get the device list.\n        if self._used_gpu_indices:\n            devices = list(self._used_gpu_indices)\n        else:\n            devices = list(range(nccl_util.get_num_gpus()))\n        barrier_tensors = [None] * len(devices)\n        for i, d in enumerate(devices):\n            with nccl_util.Device(d):\n                barrier_tensors[i] = cupy.array([1])\n        self.allreduce(barrier_tensors)\n\n    def reduce(self, tensors, reduce_options=ReduceOptions()):\n        \"\"\"Reduce tensors to a destination gpu following options.\n\n        Args:\n            tensors (List): the list of tensors to be reduced, each tensor\n                            must reside on one gpu of the current process.\n            reduce_options: reduce options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = (len(tensors) * reduce_options.root_rank +\n                     reduce_options.root_tensor)\n\n        def collective_fn(input_tensor, output_tensor, comm, stream):\n            
comm.reduce(nccl_util.get_tensor_ptr(input_tensor),\n                        nccl_util.get_tensor_ptr(output_tensor),\n                        nccl_util.get_tensor_n_elements(input_tensor),\n                        nccl_util.get_nccl_tensor_dtype(input_tensor),\n                        nccl_util.get_nccl_reduce_op(reduce_options.reduce_op),\n                        root_rank, stream.ptr)\n\n        self._collective(tensors, tensors, collective_fn)\n\n    def broadcast_partialgpu(self,\n                             tensors,\n                             broadcast_options=BroadcastOptions()):\n        \"\"\"Broadcast tensors to all other gpus following options.\n        It will only involve subset of gpu in this worker.\n\n        Args:\n            tensors (List): tensors to be broadcast or received.\n            broadcast_options: broadcast options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = 0\n\n        def collective_fn(input_tensor, output_tensor, comm, stream):\n            comm.broadcast(\n                nccl_util.get_tensor_ptr(input_tensor),\n                nccl_util.get_tensor_ptr(output_tensor),\n                broadcast_options.n_elements if broadcast_options.n_elements > 0\n                else nccl_util.get_tensor_n_elements(input_tensor),\n                nccl_util.get_nccl_tensor_dtype(input_tensor), root_rank,\n                stream.ptr)\n\n        _check_gpu_tensors(tensors)\n\n        key = broadcast_options.comm_key\n        comms = self._get_nccl_broadcast_communicator(\n            key, broadcast_options.world_size, broadcast_options.devices_ids,\n            broadcast_options.devices_global_rank)\n        streams = self._dev_streams_map[key]\n        events = self._dev_event_map[key]\n        self._sync_streams(broadcast_options.devices_ids, events, streams)\n\n        nccl_util.groupStart()\n        for i, tensor in enumerate(tensors):\n            collective_fn(tensor, tensor, comms[i], streams[i])\n        
nccl_util.groupEnd()\n\n    def _get_nccl_broadcast_communicator(self,\n                                         comm_key,\n                                         world_size,\n                                         devices_ids,\n                                         devices_global_rank,\n                                         nccl_uid=None):\n        \"\"\"Create or retrieve an NCCL communicator for broadcast from cache.\n        Here we only use partial devices in a host, so we create this function\n        besides _get_nccl_collective_communicator.\n\n        If the communicator is found in cache, return the communicator. If not,\n        a communicator and a stream will be created and put in cache.\n\n        Args:\n            comm_key (str): the key to query the communicator cache.\n            world_size (int): the number of devices in this collective\n                              communicator.\n            devices_ids (List): a list of GPU devices of the current process\n                                that participates into the collective.\n            devices_global_rank (List): the corresponding global rank for device\n                                        in devices_ids.\n            nccl_uid : If it is None, we will create a nccl_uid here.\n\n        Returns:\n            communicator: the NCCL communicator corresponded to the devices.\n        \"\"\"\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_map:\n            return self._dev_comm_map[comm_key]\n\n        for d in devices_ids:\n            self._used_gpu_indices.add(d)\n\n        nccl_uid = self._rendezvous_nccl_uid(devices_global_rank[0], comm_key,\n                                             self.world_size, nccl_uid)\n\n        # Now create the communicators\n        comms = [None] * len(devices_ids)\n        streams = [None] * len(devices_ids)\n   
     events = [None] * len(devices_ids)\n        nccl_util.groupStart()\n        for i, (global_rank,\n                device_id) in enumerate(zip(devices_global_rank, devices_ids)):\n            with nccl_util.Device(device_id):\n                comms[i] = nccl_util.create_nccl_communicator(\n                    world_size, nccl_uid, global_rank)\n                streams[i] = get_stream_pool(device_id).get_stream()\n                events[i] = cupy.cuda.Event()\n        nccl_util.groupEnd()\n        self._dev_comm_map[comm_key] = comms\n        self._dev_streams_map[comm_key] = streams\n        self._dev_event_map[comm_key] = events\n        return comms\n\n    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):\n        \"\"\"Broadcast tensors to all other gpus following options.\n\n        Args:\n            tensors (List): tensors to be broadcast or received.\n            broadcast_options: broadcast options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = (len(tensors) * broadcast_options.root_rank +\n                     broadcast_options.root_tensor)\n\n        def collective_fn(input_tensor, output_tensor, comm, stream):\n            comm.broadcast(nccl_util.get_tensor_ptr(input_tensor),\n                           nccl_util.get_tensor_ptr(output_tensor),\n                           nccl_util.get_tensor_n_elements(input_tensor),\n                           nccl_util.get_nccl_tensor_dtype(input_tensor),\n                           root_rank, stream.ptr)\n\n        self._collective(tensors, tensors, collective_fn)\n\n    def allgather(self,\n                  tensor_lists,\n                  tensors,\n                  allgather_options=AllGatherOptions()):\n        \"\"\"Allgather tensors across gpus into a list of tensors.\n\n        Args:\n            tensor_lists (List[List[Tensor]]): allgathered tensors.\n            tensors: the list of tensors to allgather across the group.\n                     Each tensor 
must be located on a GPU of the process.\n            allgather_options: allgather options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, comm, stream):\n            comm.allGather(nccl_util.get_tensor_ptr(input_tensor),\n                           nccl_util.get_tensor_ptr(output_tensor),\n                           nccl_util.get_tensor_n_elements(input_tensor),\n                           nccl_util.get_nccl_tensor_dtype(input_tensor),\n                           stream.ptr)\n\n        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)\n        output_flattened = [\n            _flatten_for_scatter_gather(tensor_list, copy=False)\n            for tensor_list in tensor_lists\n        ]\n\n        def postprocess_fn(stream):\n            # pylint: disable=unused-argument\n            # TODO(Hao): designate a copy stream.\n            for i, tensor_list in enumerate(tensor_lists):\n                for j, tensor in enumerate(tensor_list):\n                    nccl_util.copy_tensor(tensor, output_flattened[i][j])\n\n        self._collective(tensors,\n                         output_flattened,\n                         collective_fn,\n                         postprocess_fn=postprocess_fn)\n\n    def reducescatter(self,\n                      tensors,\n                      tensor_lists,\n                      reducescatter_options=ReduceScatterOptions()):\n        \"\"\"Reduce then scatter a list of tensors across the group.\n\n        Args:\n            tensors (List): the output tensors (could be unspecified), each\n                            located on a GPU of the current process.\n            tensor_lists (List[List]): the list of tensors to be reduced then\n                                       scattered.\n            reducescatter_options: reduce-scatter options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def collective_fn(input_tensor, output_tensor, comm, 
stream):\n            comm.reduceScatter(\n                nccl_util.get_tensor_ptr(input_tensor),\n                nccl_util.get_tensor_ptr(output_tensor),\n                nccl_util.get_tensor_n_elements(output_tensor),\n                nccl_util.get_nccl_tensor_dtype(output_tensor),\n                nccl_util.get_nccl_reduce_op(reducescatter_options.reduce_op),\n                stream.ptr)\n\n        _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists)\n        input_flattened = [\n            _flatten_for_scatter_gather(tensor_list, copy=False)\n            for tensor_list in tensor_lists\n        ]\n\n        def preprocess_fn(stream):\n            # pylint: disable=unused-argument\n            for i, tensor_list in enumerate(tensor_lists):\n                for j, tensor in enumerate(tensor_list):\n                    nccl_util.copy_tensor(input_flattened[i][j], tensor)\n\n        self._collective(input_flattened,\n                         tensors,\n                         collective_fn,\n                         preprocess_fn=preprocess_fn)\n\n    def send(self, tensors, send_options=SendOptions()):\n        \"\"\"Send a tensor to a destination gpu in the group.\n\n        Args:\n            tensors (List): the tensor to send.\n            send_options: send options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def p2p_fn(tensor, comm, stream, peer):\n            comm.send(\n                nccl_util.get_tensor_ptr(tensor),\n                send_options.n_elements if send_options.n_elements > 0 else\n                nccl_util.get_tensor_n_elements(tensor),\n                nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)\n\n        self._point2point(tensors, p2p_fn, send_options.dst_rank,\n                          send_options.dst_gpu_index)\n\n    def recv(self, tensors, recv_options=RecvOptions()):\n        \"\"\"Receive a tensor from a source gpu in the group.\n\n        Args:\n            tensors (List): 
the received tensor.\n            recv_options: Receive options.\n\n        Returns:\n            None\n        \"\"\"\n\n        def p2p_fn(tensor, comm, stream, peer):\n            comm.recv(\n                nccl_util.get_tensor_ptr(tensor),\n                recv_options.n_elements if recv_options.n_elements > 0 else\n                nccl_util.get_tensor_n_elements(tensor),\n                nccl_util.get_nccl_tensor_dtype(tensor), peer, stream.ptr)\n\n        self._point2point(tensors, p2p_fn, recv_options.src_rank,\n                          recv_options.src_gpu_index)\n\n    def _get_nccl_collective_communicator(self, comm_key, device_list):\n        \"\"\"Create or retrieve an NCCL communicator from cache.\n\n        If the communicator is found in cache, return the communicator. If not,\n        a communicator and a stream will be created and put in cache.\n        TODO(Hao): this function is not thread-safe now.\n\n        Args:\n            comm_key (str): the key to query the communicator cache.\n            device_list (List): a list of GPU devices of the current process\n                                that participates into the collective.\n\n        Returns:\n            communicator: the NCCL communicator corresponded to the devices.\n        \"\"\"\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_map:\n            return self._dev_comm_map[comm_key]\n\n        for d in device_list:\n            self._used_gpu_indices.add(d)\n\n        nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,\n                                             self.world_size)\n\n        # Now create the communicators\n        actual_world_size = len(device_list) * self.world_size\n        comms = [None] * len(device_list)\n        streams = [None] * len(device_list)\n        events = [None] * len(device_list)\n\n        
nccl_util.groupStart()\n        for i, device in enumerate(device_list):\n            actual_rank = self.rank * len(device_list) + i\n            with nccl_util.Device(device):\n                comms[i] = nccl_util.create_nccl_communicator(\n                    actual_world_size, nccl_uid, actual_rank)\n                # request a stream from the pool\n                # note the device_idx is absolute index.\n                streams[i] = get_stream_pool(device).get_stream()\n                # TODO(Fu): double check the parameters\n                events[i] = cupy.cuda.Event()\n        nccl_util.groupEnd()\n        # TODO(Fu): lock\n        self._dev_comm_map[comm_key] = comms\n        self._dev_streams_map[comm_key] = streams\n        self._dev_event_map[comm_key] = events\n        return comms\n\n    def create_nccl_collective_communicator(self, devices):\n        key = _get_comm_key_from_devices(devices)\n        self._get_nccl_collective_communicator(key, devices)\n\n    def create_and_set_xla_communicators(self, devices, key):\n        comm_key = _get_comm_key_from_devices(devices)\n        if comm_key in self._xla_comm_keys:\n            return\n        for d in devices:\n            self._used_gpu_indices.add(d)\n\n        nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,\n                                             self.world_size)\n\n        # Now create the communicators\n        actual_world_size = len(devices) * self.world_size\n        # FIXME: pass the start rank at the initial point\n        start_rank = self.rank * len(devices)\n        actual_ranks = [start_rank + i for i in range(len(devices))]\n        local_ids = list(range(len(devices)))\n        self.xla_comm_group.nccl_create_communicators(actual_world_size,\n                                                      actual_ranks, local_ids,\n                                                      nccl_uid)\n\n        xe.set_comm_group_info(key, self.xla_comm_group, nccl_uid)\n        
self._xla_comm_keys.add(comm_key)\n\n    @staticmethod\n    def _sync_streams(device_list, events, streams):\n        \"\"\"Let NCCL streams wait for current streams for every device.\"\"\"\n        # TODO(Fu): recordStream besides calling this function?\n        if ENV.NCCL_USE_MULTISTREAM.val:\n            for i, device in enumerate(device_list):\n                with nccl_util.Device(device):\n                    events[i].record(cupy.cuda.get_current_stream())\n                    streams[i].wait_event(events[i])\n\n    def _get_nccl_p2p_communicator(self,\n                                   comm_key,\n                                   my_gpu_idx,\n                                   peer_rank,\n                                   peer_gpu_idx,\n                                   nccl_uid=None):\n        \"\"\"Create or retrieve an NCCL communicator for p2p tasks.\n\n        Note(Hao): this function is not thread-safe now.\n\n        Args:\n            comm_key (str): communicator key.\n            my_gpu_idx (int): the gpu index on the current process.\n            peer_rank (int): the rank of the destination process.\n            peer_gpu_idx (int): the gpu index on the peer process.\n        Returns:\n            communicator\n        \"\"\"\n        # pylint: disable=unused-argument\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_map:\n            return self._dev_comm_map[comm_key]\n\n        # Note (Hao): This is a bit complex so I decide to take a note here.\n        # Here we need to consider three cases:\n        # Case 1: src_rank != dst_rank, hence the send and recv happen on\n        # different process (actors/tasks); each process makes independent\n        # collective calls and manages corresponding communicators.\n        # Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for\n        # this case, we 
simply throw a RuntimeError;\n        # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which\n        # means the send and recv will be called on the same process. We\n        # DO NOT support this case for now. We need to properly scope:\n        # (1) communicators creation, and\n        # (2) send/recv calls\n        # using groupStart() and groupEnd() calls to avoid deadlocks.\n        if self.rank < peer_rank:\n            my_p2p_rank = 0\n        elif self.rank > peer_rank:\n            my_p2p_rank = 1\n        else:\n            raise RuntimeError(\n                \"Send and recv happens on the same process! \"\n                \"alpa.collective does not support this case as of now. \"\n                \"Alternatively, consider doing GPU to GPU memcpy?\")\n        nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid)\n\n        # create the p2p communicators\n        with nccl_util.Device(my_gpu_idx):\n            comm = nccl_util.create_nccl_communicator(2, nccl_uid, my_p2p_rank)\n            stream = get_stream_pool(my_gpu_idx).get_stream()\n            event = cupy.cuda.Event()\n\n        self._dev_comm_map[comm_key] = [comm]\n        self._dev_streams_map[comm_key] = [stream]\n        self._dev_event_map[comm_key] = [event]\n        return [comm]\n\n    def _generate_group_key(self, comm_key):\n        \"\"\"Generate a unique key used to initialize the KV store.\n\n        The group key is a concatenation of the communicator key and\n        the group name, following: [comm_key]@[group_name].\n        \"\"\"\n        return comm_key + \"@\" + self.group_name\n\n    @staticmethod\n    def _destroy_store(group_key):\n        \"\"\"Destroy the KV store (Ray named actor).\n\n        Args:\n            group_key (str): the unique key to retrieve the KV store.\n\n        Returns:\n            None\n        \"\"\"\n        store_name = get_store_name(group_key)\n        try:\n            store = ray.get_actor(store_name)\n         
   ray.kill(store)\n        except ValueError:\n            logger.info(f\"The store with name {store_name} has been destroyed \"\n                        f\"somewhere else.\")\n\n    @staticmethod\n    def generate_nccl_uid():\n        group_uid = nccl_util.get_nccl_unique_id()\n        return group_uid\n\n    def _generate_nccl_uid(self, key):\n        \"\"\"Generate an NCCL unique ID for initializing communicators.\n\n        The method will also create a KV store using Ray named actor and store\n        the NCCLUniqueID in the store. The store needs to be garbage collected\n        when destroying the collective group.\n\n        Args:\n            key (str): the key used to name the KV store.\n\n        Returns:\n            NCCLUniqueID (str): NCCL unique ID.\n        \"\"\"\n        group_uid = nccl_util.get_nccl_unique_id()\n        store_name = get_store_name(key)\n        # Avoid a potential circular dependency in ray/actor.py\n        from alpa.collective.util import NCCLUniqueIDStore  # pylint: disable=import-outside-toplevel\n        self._store = NCCLUniqueIDStore.options(\n            name=store_name).remote(store_name)\n        ray.get([self._store.set_id.remote(group_uid)])\n        return group_uid\n\n    def _collective(self,\n                    input_tensors,\n                    output_tensors,\n                    collective_fn,\n                    preprocess_fn=None,\n                    postprocess_fn=None):\n        \"\"\"A method to encapsulate all collective calls.\n\n        Args:\n            input_tensors: the list of the input tensors.\n            output_tensors: the list of the output tensors.\n            collective_fn: the collective function call.\n            preprocess_fn: preprocess procedures before collective calls.\n            postprocess_fn: postprocess procedures after collective calls.\n\n        Returns:\n            None\n        \"\"\"\n        _check_gpu_tensors(input_tensors)\n        _check_gpu_tensors(output_tensors)\n\n        
devices = nccl_util.get_tensor_device_list(input_tensors)\n        key = _get_comm_key_from_devices(devices)\n        comms = self._get_nccl_collective_communicator(key, devices)\n        streams = self._dev_streams_map[key]\n        events = self._dev_event_map[key]\n\n        # TODO(Hao): sync streams and events\n        self._sync_streams(devices, events, streams)\n\n        # Make the collective call\n        if preprocess_fn:\n            preprocess_fn(streams)\n\n        nccl_util.groupStart()\n        # TODO(Fu): how to recordStreams as there are no library functions\n        # We also need to make sure input tensors are not freed before their\n        # usages on ncclStreams finish. This can be achieved by calling\n        # c10::cuda::CUDACachingAllocator::recordStream, which remembers the\n        # usage stream (ncclStream), creates an event on the usage stream\n        # when GC attempts to free the input tensor, and delays GC until that\n        # event is done.\n        for i, tensor in enumerate(input_tensors):\n            collective_fn(tensor, output_tensors[i], comms[i], streams[i])\n        nccl_util.groupEnd()\n        if postprocess_fn:\n            postprocess_fn(streams)\n\n    def create_p2p_communicator(self,\n                                my_gpu_idx: int,\n                                peer_rank: int,\n                                peer_gpu_idx: int,\n                                nccl_uid: str = None):\n        \"\"\"A public method to create p2p communicators\n\n        Args:\n            my_gpu_idx (int): the gpu index on self rank.\n            peer_rank (int): the rank of the peer process.\n            peer_gpu_idx (int): the index of the gpu on the peer process.\n            nccl_uid (str, optional): optionally to provide the NCCLUniqueID in\n                advance.\n\n        Returns:\n            None\n        \"\"\"\n        comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,\n                             
              peer_gpu_idx)\n        self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,\n                                        peer_gpu_idx, nccl_uid)\n\n    def create_nccl_broadcast_communicator(self,\n                                           comm_key,\n                                           world_size,\n                                           devices_ids,\n                                           devices_global_rank,\n                                           nccl_uid=None):\n        self._get_nccl_broadcast_communicator(comm_key, world_size, devices_ids,\n                                              devices_global_rank, nccl_uid)\n\n    def _point2point(self, tensors, p2p_fn, peer_rank: int, peer_gpu_idx: int):\n        \"\"\"A method to encapsulate all peer-to-peer calls (i.e., send/recv).\n\n        Args:\n            tensors: the tensor to send or receive.\n            p2p_fn: the p2p function call.\n            peer_rank (int): the rank of the peer process.\n            peer_gpu_idx (int): the index of the gpu on the peer process.\n\n        Returns:\n            None\n        \"\"\"\n        # check send/recv availability.\n        if nccl_util.get_nccl_runtime_version() < 2704:\n            raise RuntimeError(\"P2p send/recv requires NCCL >= 2.7.4. 
\"\n                               f\"Got '{nccl_util.get_nccl_runtime_version()}'.\")\n        _check_gpu_tensors(tensors)\n\n        # we currently only support single device to single device send/recv.\n        assert len(tensors) == 1\n        my_gpu_idx = nccl_util.get_tensor_device(tensors[0])\n        comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n        comms = self._get_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,\n                                                peer_gpu_idx)\n        streams = self._dev_streams_map[comm_key]\n        events = self._dev_event_map[comm_key]\n\n        # TODO(Hao): sync streams and events\n        self._sync_streams([my_gpu_idx], events, streams)\n\n        # We have made sure that self.rank != peer_rank during API check.\n        peer_p2p_rank = 0 if self.rank > peer_rank else 1\n        for i, t in enumerate(tensors):\n            p2p_fn(t, comms[i], streams[i], peer_p2p_rank)\n\n    def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None):\n        group_key = self._generate_group_key(comm_key)\n        if rank == 0:\n            if nccl_uid is None:\n                nccl_uid = self._generate_nccl_uid(group_key)\n        else:\n            if nccl_uid is None:\n                rendezvous = Rendezvous(group_key)\n                rendezvous.meet()\n                nccl_uid = rendezvous.get_nccl_id()\n\n                # Recycle the NCCLUniqueIDStore named actor *proactively* to\n                # avoid named actor leak.\n                if rendezvous.get_access_counter() == max_counter:\n                    logger.debug(\n                        \"NCCLUniqueID has been broadcasted. 
The \"\n                        \"NCCLUniqueIDStore will go out of context and be \"\n                        \"destroyed.\")\n                    rendezvous.destroy_store()\n        return nccl_uid\n\n\ndef _flatten_for_scatter_gather(tensor_list, copy=False):\n    \"\"\"Flatten the tensor for gather/scatter operations.\n\n    Args:\n        tensor_list: the list of tensors to be scattered/gathered.\n        copy: whether to copy the tensors in tensor_list into the buffer.\n\n    Returns:\n        The flattened tensor buffer.\n    \"\"\"\n    if not tensor_list:\n        raise RuntimeError(\"Received an empty list.\")\n    t = tensor_list[0]\n    # note we need a cupy dtype here.\n    dtype = nccl_util.get_cupy_tensor_dtype(t)\n    buffer_shape = [len(tensor_list)] + nccl_util.get_tensor_shape(t)\n    device = nccl_util.get_tensor_device(t)\n    with nccl_util.Device(device):\n        buffer = cupy.empty(buffer_shape, dtype=dtype)\n    if copy:\n        for i, tensor in enumerate(tensor_list):\n            nccl_util.copy_tensor(buffer[i], tensor)\n    return buffer\n\n\ndef _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):\n    \"\"\"Check the compatibility between tensor input and tensor list input.\"\"\"\n    if not tensors or not isinstance(tensors, list):\n        raise RuntimeError(\n            \"The first argument 'tensors' expects a list of tensors.\")\n    if not tensor_lists or not isinstance(tensor_lists, list):\n        raise RuntimeError(\"The second argument 'tensor_lists' \"\n                           \"expects a list of tensor list.\")\n    dtype = nccl_util.get_nccl_tensor_dtype(tensors[0])\n    shape = nccl_util.get_tensor_shape(tensors[0])\n    for i, tl in enumerate(tensor_lists):\n        # check all tensors in `tensors` match.\n        dt = nccl_util.get_nccl_tensor_dtype(tensors[i])\n        if dt != dtype:\n            raise RuntimeError(\n                \"All tensor operands to scatter/gather must \"\n                
f\"have the same dtype. Got '{dt}' and '{dtype}'.\")\n        # Note: typically CCL libraries only requires they have the same\n        # number of elements; Here we make it more strict -- we require\n        # exact shape match.\n        s = nccl_util.get_tensor_shape(tensors[i])\n        if s != shape:\n            raise RuntimeError(\"All tensor operands to scatter/gather must \"\n                               f\"have the same shape. Got '{s}' and '{shape}'.\")\n        # check all tensors in `tensor_lists` match.\n        for t in tl:\n            # check dtype\n            dt = nccl_util.get_nccl_tensor_dtype(t)\n            if dt != dtype:\n                raise RuntimeError(\n                    \"All tensor operands to scatter/gather must \"\n                    f\"have the same dtype. Got '{dt}' and '{dtype}'.\")\n            s = nccl_util.get_tensor_shape(t)\n            if s != shape:\n                raise RuntimeError(\n                    \"All tensor operands to scatter/gather must \"\n                    f\"have the same shape. Got '{s}' and '{shape}'.\")\n\n\ndef _check_gpu_tensors(tensors):\n    \"\"\"Check all tensors are distributed on different GPUs.\"\"\"\n    if not tensors or not isinstance(tensors, list):\n        raise RuntimeError(\"'tensors' must be a nonempty list.\")\n    if len(tensors) > nccl_util.get_num_gpus():\n        raise RuntimeError(\"Tensor list cannot be larger than the number\"\n                           f\"of available GPUs. 
Got {len(tensors)} > \"\n                           f\"{nccl_util.get_num_gpus()}.\")\n    t0 = tensors[0]\n    dt = nccl_util.get_nccl_tensor_dtype(t0)\n    s = nccl_util.get_tensor_shape(t0)\n    d = nccl_util.get_tensor_device(t0)\n    for i, t in enumerate(tensors):\n        if i == 0:\n            continue\n        # We need to check the following:\n        # (1) tensor is cuda (already checked during API)\n        # (2) tensor dtype\n        # (3) tensor shape match\n        # (4) each tensor is on a different GPU\n        dtype = nccl_util.get_nccl_tensor_dtype(t)\n        if dt != dtype:\n            raise RuntimeError(\n                f\"Tensors must have identical dtypes. Got: '{dtype}'.\")\n        shape = nccl_util.get_tensor_shape(t)\n        if s != shape:\n            raise RuntimeError(\n                f\"Tensors must have identical shapes. Got: '{shape}'.\")\n        device = nccl_util.get_tensor_device(t)\n        if device == d:\n            raise RuntimeError(\"Tensor must be on distinct GPUs.\")\n\n\ndef _get_comm_key_from_devices(devices):\n    \"\"\"Return a key from a list of devices for collective calls.\n\n    For example, if the tensors are on gpus 0, 1, 2, 3,\n    then the key would be \"0,1,2,3\".\n\n    Args:\n        devices(list): a list of GPU device indices\n\n    Returns:\n        str: a string represents the key to query the communicator cache.\n\n    \"\"\"\n    return \",\".join([str(d) for d in devices])\n\n\ndef _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):\n    \"\"\"Return a key given source and destination ranks for p2p tasks.\n\n    The p2p key is in the following form:\n                [min_rank]_[gpu_index]:[max_rank]_[gpu_index].\n\n    Args:\n        my_rank (int): the rank of the source process.\n        my_gpu_idx (int): the source gpu index on the process.\n        peer_rank (int): the rank of the destination process.\n        peer_gpu_idx (int): the destination gpu index on the 
process.\n\n    Returns:\n        comm_key (str): a string key to query the communication cache.\n    \"\"\"\n    if my_rank < peer_rank:\n        lower_key = str(my_rank) + \"_\" + str(my_gpu_idx)\n        higher_key = str(peer_rank) + \"_\" + str(peer_gpu_idx)\n    elif my_rank > peer_rank:\n        lower_key = str(peer_rank) + \"_\" + str(peer_gpu_idx)\n        higher_key = str(my_rank) + \"_\" + str(my_gpu_idx)\n    else:\n        raise RuntimeError(\n            \"Send and recv happens on the same process. alpa.collective \"\n            \"does not support this case as of now. Alternatively, consider \"\n            \"doing GPU to GPU memcpy?\")\n    comm_key = lower_key + \":\" + higher_key\n    return comm_key\n"
  },
  {
    "path": "alpa/collective/collective_group/nccl_util.py",
    "content": "\"\"\"Code to wrap some NCCL API calls.\"\"\"\nimport numpy\n\nfrom alpa.collective.types import ReduceOp, torch_available\nfrom alpa.global_env import global_config\n\nif global_config.has_cuda:\n    try:\n        import cupy\n        from cupy.cuda import nccl\n        from cupy.cuda import Device  # pylint: disable=unused-import\n        from cupy.cuda.nccl import get_version\n        from cupy.cuda.nccl import get_build_version\n        from cupy.cuda.nccl import NcclCommunicator\n        from cupy.cuda.nccl import groupStart  # pylint: disable=unused-import\n        from cupy.cuda.nccl import groupEnd  # pylint: disable=unused-import\n    except ImportError:\n        # pylint: disable=raise-missing-from\n        raise ImportError(\n            \"Please install nccl library following the above instructions\")\n\n    NCCL_REDUCE_OP_MAP = {\n        ReduceOp.SUM: nccl.NCCL_SUM,\n        ReduceOp.PRODUCT: nccl.NCCL_PROD,\n        ReduceOp.MIN: nccl.NCCL_MIN,\n        ReduceOp.MAX: nccl.NCCL_MAX,\n    }\n\n    # cupy types are the same with numpy types\n    NUMPY_NCCL_DTYPE_MAP = {\n        # INT types\n        numpy.uint8: nccl.NCCL_UINT8,\n        numpy.uint32: nccl.NCCL_UINT32,\n        numpy.uint64: nccl.NCCL_UINT64,\n        numpy.int8: nccl.NCCL_INT8,\n        numpy.int32: nccl.NCCL_INT32,\n        numpy.int64: nccl.NCCL_INT64,\n        # FLOAT types\n        numpy.half: nccl.NCCL_HALF,\n        numpy.float16: nccl.NCCL_FLOAT16,\n        numpy.float32: nccl.NCCL_FLOAT32,\n        numpy.float64: nccl.NCCL_FLOAT64,\n        numpy.double: nccl.NCCL_DOUBLE\n    }\n\nif torch_available():\n    import torch\n    import torch.utils.dlpack\n\n    if global_config.has_cuda:\n        TORCH_NCCL_DTYPE_MAP = {\n            # INT types\n            torch.int: nccl.NCCL_INT,\n            torch.uint8: nccl.NCCL_UINT8,\n            torch.int8: nccl.NCCL_INT8,\n            torch.int32: nccl.NCCL_INT32,\n            torch.int64: nccl.NCCL_INT64,\n            
torch.long: nccl.NCCL_INT64,\n            # FLOAT types\n            torch.half: nccl.NCCL_HALF,\n            torch.float: nccl.NCCL_FLOAT,\n            torch.float16: nccl.NCCL_FLOAT16,\n            torch.float32: nccl.NCCL_FLOAT32,\n            torch.float64: nccl.NCCL_FLOAT64,\n            torch.double: nccl.NCCL_DOUBLE,\n        }\n\n    TORCH_NUMPY_DTYPE_MAP = {\n        # INT types\n        torch.int: numpy.int32,\n        torch.uint8: numpy.uint8,\n        torch.int8: numpy.int8,\n        torch.int32: numpy.int32,\n        torch.int64: numpy.int64,\n        torch.long: numpy.int64,\n        # FLOAT types\n        torch.half: numpy.half,\n        torch.float: numpy.float32,\n        torch.float16: numpy.float16,\n        torch.float32: numpy.float32,\n        torch.float64: numpy.float64,\n    }\n\n\ndef get_num_gpus():\n    \"\"\"Returns the number of compute-capable GPUs.\"\"\"\n    return cupy.cuda.runtime.getDeviceCount()\n\n\ndef get_nccl_build_version():\n    return get_build_version()\n\n\ndef get_nccl_runtime_version():\n    return get_version()\n\n\ndef get_nccl_unique_id():\n    return nccl.get_unique_id()\n\n\ndef create_nccl_communicator(world_size, nccl_unique_id, rank):\n    \"\"\"Create an NCCL communicator using NCCL APIs.\n\n    Args:\n        world_size (int): the number of processes of this communicator group.\n        nccl_unique_id (str): the NCCLUniqueID for this group.\n        rank (int): the rank of this process.\n    Returns:\n        comm (nccl.ncclComm_t): an NCCL communicator.\n    \"\"\"\n    comm = NcclCommunicator(world_size, nccl_unique_id, rank)\n    return comm\n\n\ndef get_nccl_reduce_op(reduce_op):\n    \"\"\"Map the reduce op to NCCL reduce op type.\n\n    Args:\n        reduce_op (ReduceOp): ReduceOp Enum (SUM/PRODUCT/MIN/MAX).\n    Returns:\n        (nccl.ncclRedOp_t): the mapped NCCL reduce op.\n    \"\"\"\n    if reduce_op not in NCCL_REDUCE_OP_MAP:\n        raise RuntimeError(f\"NCCL does not support reduce op: 
'{reduce_op}'.\")\n    return NCCL_REDUCE_OP_MAP[reduce_op]\n\n\ndef get_nccl_tensor_dtype(tensor):\n    \"\"\"Return the corresponded NCCL dtype given a tensor.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        return NUMPY_NCCL_DTYPE_MAP[tensor.dtype.type]\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return TORCH_NCCL_DTYPE_MAP[tensor.dtype]\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_cupy_tensor_dtype(tensor):\n    \"\"\"Return the corresponded Cupy dtype given a tensor.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        return tensor.dtype.type\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return TORCH_NUMPY_DTYPE_MAP[tensor.dtype]\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_tensor_ptr(tensor):\n    \"\"\"Return the pointer to the underlying memory storage of a tensor.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        return tensor.data.ptr\n    if isinstance(tensor, numpy.ndarray):\n        return tensor.data\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            if not tensor.is_cuda:\n                raise RuntimeError(\"Torch tensor must be on GPU \"\n                                   \"when using NCCL collectives.\")\n            return tensor.data_ptr()\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. 
\"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_tensor_n_elements(tensor):\n    \"\"\"Return the number of elements in a tensor.\"\"\"\n    if isinstance(tensor, (cupy.ndarray, numpy.ndarray)):\n        return tensor.size\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return torch.numel(tensor)\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_tensor_shape(tensor):\n    \"\"\"Return the shape of the tensor as a list.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        return list(tensor.shape)\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return list(tensor.size())\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. \"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_tensor_strides(tensor):\n    \"\"\"Return the strides of the tensor as a list.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        return [\n            int(stride / tensor.dtype.itemsize) for stride in tensor.strides\n        ]\n    if torch_available():\n        if isinstance(tensor, torch.Tensor):\n            return list(tensor.stride())\n    raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}. 
\"\n                     \"Supported GPU tensor types are: torch.Tensor, \"\n                     \"cupy.ndarray.\")\n\n\ndef get_tensor_device(tensor):\n    \"\"\"Return the GPU index of a tensor.\"\"\"\n    if isinstance(tensor, cupy.ndarray):\n        try:\n            device = tensor.device.id\n        except AttributeError as e:\n            raise RuntimeError(\"The tensor is not on a valid GPU.\") from e\n    elif torch_available() and isinstance(tensor, torch.Tensor):\n        device = tensor.device.index\n        if not isinstance(device, int):\n            raise RuntimeError(\"The tensor is not on a valid GPU.\")\n    else:\n        raise ValueError(f\"Unsupported tensor type. Got: {type(tensor)}.\")\n    return device\n\n\ndef copy_tensor(dst_tensor, src_tensor):\n    \"\"\"Copy the content from src_tensor to dst_tensor.\n\n    Args:\n        dst_tensor: the tensor to copy to.\n        src_tensor: the tensor to copy from.\n\n    Returns:\n        None\n    \"\"\"\n    copied = True\n    if (isinstance(dst_tensor, cupy.ndarray) and\n            isinstance(src_tensor, cupy.ndarray)):\n        cupy.copyto(dst_tensor, src_tensor)\n    elif torch_available():\n        if isinstance(dst_tensor, torch.Tensor) and isinstance(\n                src_tensor, torch.Tensor):\n            dst_tensor.copy_(src_tensor)\n        elif isinstance(dst_tensor, torch.Tensor) and isinstance(\n                src_tensor, cupy.ndarray):\n            t = torch.utils.dlpack.from_dlpack(src_tensor.toDlpack())\n            dst_tensor.copy_(t)\n        elif isinstance(dst_tensor, cupy.ndarray) and isinstance(\n                src_tensor, torch.Tensor):\n            t = cupy.fromDlpack(torch.utils.dlpack.to_dlpack(src_tensor))\n            cupy.copyto(dst_tensor, t)\n        else:\n            copied = False\n    else:\n        copied = False\n    if not copied:\n        raise ValueError(\n            f\"Unsupported tensor type. Got: {type(dst_tensor)} and \"\n            f\"{type(src_tensor)}. Supported GPU tensor types are: \"\n            f\"torch.Tensor, cupy.ndarray.\")\n\n\ndef get_tensor_device_list(tensors):\n    \"\"\"Returns the gpu devices of the list of input tensors.\n\n    Args:\n        tensors(list): a list of tensors, each locates on a GPU.\n\n    Returns:\n        list: the list of GPU devices.\n\n    \"\"\"\n    if not isinstance(tensors, list):\n        raise RuntimeError(\n            \"Expect a list of tensors each locates on a GPU device. \"\n            f\"Got: '{type(tensors)}'.\")\n    devices = [get_tensor_device(t) for t in tensors]\n    return devices\n"
  },
  {
    "path": "alpa/collective/collective_group/xla_nccl_collective_group.py",
    "content": "\"\"\"NCCL-based collective operations with apis from xla extension.\"\"\"\nimport logging\n\nimport ray\nfrom jax._src.lib import xla_extension as xe\n\nfrom alpa.collective.collective_group import xla_nccl_util\nfrom alpa.collective.collective_group.base_collective_group import BaseGroup, Rendezvous\nfrom alpa.collective.const import get_store_name\nfrom alpa.collective.types import (Backend, BroadcastOptions, AllReduceOptions,\n                                   BarrierOptions, ReduceOptions,\n                                   AllGatherOptions, ReduceScatterOptions,\n                                   SendOptions, RecvOptions)\n\nfrom alpa.global_env import global_config\nfrom alpa.monkey_patch import override_get_backend\n\nlogger = logging.getLogger(__name__)\n\n\nclass XLANCCLGroup(BaseGroup):\n    \"\"\"NCCL-based collective operations with apis from xla extension.\"\"\"\n\n    def __init__(self, world_size, rank, group_name):\n        \"\"\"Init an NCCL collective group.\"\"\"\n        super().__init__(world_size, rank, group_name)\n\n        self.use_default_stream = not global_config.enable_overlapping\n        self._dev_comm_uids = {}\n\n        # record the used GPU IDs.\n        self._used_gpu_indices = set()\n\n        backend = override_get_backend()\n        self.xla_comm_group = xe.CommGroup(backend)\n\n        if xla_nccl_util.get_nccl_runtime_version() < 2704:\n            logger.warning(\"NCCL send/recv calls requires NCCL>=2.7.4\")\n\n    def destroy_group(self):\n        \"\"\"Destroy the group and release NCCL communicators.\"\"\"\n        if len(self._dev_comm_uids) > 0:\n\n            # Destroy the communicators and streams.\n            for comm_key in self._dev_comm_uids:\n                key = self._dev_comm_uids[comm_key]\n                self.xla_comm_group.nccl_destroy_comms(key)\n\n        if self.rank == 0:\n            for comm_key in self._dev_comm_uids:\n                group_key = 
self._generate_group_key(comm_key)\n                self._destroy_store(group_key)\n        self._dev_comm_uids = None\n\n    # functions to get communicator:\n    def create_nccl_broadcast_communicator(self,\n                                           comm_key,\n                                           world_size,\n                                           devices_ids,\n                                           devices_global_rank,\n                                           nccl_uid=None):\n        \"\"\"Create or retrieve a list of NCCL communicators for\n        broadcast from cache. Here we only use partial devices in a host, so\n        we create this function besides _create_nccl_collective_communicator.\n\n        If the communicator is found in cache, return the communicator. If not,\n        a communicator and a stream will be created and put in cache.\n\n        Args:\n            comm_key (str): the key to query the communicator cache.\n            world_size (int): the number of devices in this collective\n                              communicator.\n            devices_ids (List): a list of GPU devices of the current process\n                                that participates into the collective.\n            devices_global_rank (List): the corresponding global rank for\n                                device in devices_ids.\n            nccl_uid : If it is None, we will create a nccl_uid here.\n\n        Returns:\n            communicator: the NCCL communicator corresponded to the devices.\n        \"\"\"\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_uids:\n            return\n\n        for d in devices_ids:\n            self._used_gpu_indices.add(d)\n\n        nccl_uid = self._rendezvous_nccl_uid(devices_global_rank[0], comm_key,\n                                             self.world_size, nccl_uid)\n\n        
self.xla_comm_group.nccl_create_communicators(world_size,\n                                                      devices_global_rank,\n                                                      devices_ids, nccl_uid)\n        self._dev_comm_uids[comm_key] = nccl_uid\n\n    def _create_nccl_collective_communicator(self, comm_key, device_list):\n        \"\"\"Create or retrieve an NCCL communicator from cache.\n\n        If the communicator is found in cache, return the communicator. If not,\n        a communicator and a stream will be created and put in cache.\n        TODO(Hao): this function is not thread-safe now.\n\n        Args:\n            comm_key (str): the key to query the communicator cache.\n            device_list (List): a list of GPU devices of the current process\n                                that participates into the collective.\n\n        Returns:\n            communicator: the NCCL communicator corresponded to the devices.\n        \"\"\"\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_uids:\n            return\n\n        for d in device_list:\n            self._used_gpu_indices.add(d)\n\n        nccl_uid = self._rendezvous_nccl_uid(self.rank, comm_key,\n                                             self.world_size)\n\n        # Now create the communicators\n        actual_world_size = len(device_list) * self.world_size\n\n        # FIXME: pass the start rank at the initial point\n        start_rank = self.rank * len(device_list)\n        actual_ranks = [start_rank + i for i in range(len(device_list))]\n        local_ids = list(range(len(device_list)))\n        self.xla_comm_group.nccl_create_communicators(actual_world_size,\n                                                      actual_ranks, local_ids,\n                                                      nccl_uid)\n\n        self._dev_comm_uids[comm_key] = 
nccl_uid\n\n    def create_nccl_collective_communicator(self, devices):\n        key = _get_comm_key_from_devices(devices)\n        self._create_nccl_collective_communicator(key, devices)\n\n    def _create_nccl_p2p_communicator(self,\n                                      comm_key,\n                                      my_gpu_idx,\n                                      peer_rank,\n                                      peer_gpu_idx,\n                                      nccl_uid=None):\n        \"\"\"Create or retrieve an NCCL communicator for p2p tasks.\n\n        Args:\n            comm_key (str): communicator key.\n            my_gpu_idx (int): the gpu index on the current process.\n            peer_rank (int): the rank of the destination process.\n            peer_gpu_idx (int): the gpu index on the peer process.\n        Returns:\n            communicator\n        \"\"\"\n        # pylint: disable=unused-argument\n        if not comm_key:\n            raise RuntimeError(\"Got empty communicator key.\")\n\n        # TODO(Hao): lock the _dev_comm_map here.\n        if comm_key in self._dev_comm_uids:\n            return\n\n        # Note (Hao): This is a bit complex so I decide to take a note here.\n        # Here we need to consider three cases:\n        # Case 1: src_rank != dst_rank, hence the send and recv happen on\n        # different process (actors/tasks); each process makes independent\n        # collective calls and manages corresponding communicators.\n        # Case 2: src_rank == dst_rank, src_gpu_idx == dst_gpu_idx; for\n        # this case, we simply throw a RuntimeError;\n        # Case 3: src_rank == dst_rank, src_gpu_idx != dst_gpu_idx, which\n        # means the send and recv will be called on the same process. We\n        # DO NOT support this case for now. 
We need to properly scope:\n        # (1) communicators creation, and\n        # (2) send/recv calls\n        # using groupStart() and groupEnd() calls to avoid deadlocks.\n        if self.rank < peer_rank:\n            my_p2p_rank = 0\n        elif self.rank > peer_rank:\n            my_p2p_rank = 1\n        else:\n            raise RuntimeError(\n                \"Send and recv happens on the same process! \"\n                \"alpa.collective does not support this case as of now. \"\n                \"Alternatively, consider doing GPU to GPU memcpy?\")\n        nccl_uid = self._rendezvous_nccl_uid(my_p2p_rank, comm_key, 2, nccl_uid)\n\n        self.xla_comm_group.nccl_create_communicators(2, [my_p2p_rank],\n                                                      [my_gpu_idx], nccl_uid)\n        self._dev_comm_uids[comm_key] = nccl_uid\n\n    def create_p2p_communicator(self,\n                                my_gpu_idx: int,\n                                peer_rank: int,\n                                peer_gpu_idx: int,\n                                nccl_uid: str = None):\n        \"\"\"A public method to create p2p communicators\n\n        Args:\n            my_gpu_idx (int): the gpu index on self rank.\n            peer_rank (int): the rank of the peer process.\n            peer_gpu_idx (int): the index of the gpu on the peer process.\n            nccl_uid (str, optional): optionally to provide the NCCLUniqueID in\n                advance.\n\n        Returns:\n            None\n        \"\"\"\n        comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n        self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx, nccl_uid)\n\n    def create_and_set_xla_communicators(self, devices, key):\n        comm_key = _get_comm_key_from_devices(devices)\n        self._create_nccl_collective_communicator(comm_key, 
devices)\n        nccl_uid = self._dev_comm_uids[comm_key]\n        xe.set_comm_group_info(key, self.xla_comm_group, nccl_uid)\n\n    # communicate operations\n    def broadcast_partialgpu(self,\n                             tensors,\n                             broadcast_options=BroadcastOptions()):\n        \"\"\"Broadcast tensors to all other gpus following options.\n        It will only involve subset of gpu in this worker.\n\n        Args:\n            tensors (List): tensors to be broadcast or received.\n            broadcast_options: broadcast options.\n\n        Returns:\n            None\n        \"\"\"\n        root_rank = 0\n\n        self.create_nccl_broadcast_communicator(\n            broadcast_options.comm_key, broadcast_options.world_size,\n            broadcast_options.devices_ids,\n            broadcast_options.devices_global_rank)\n        key = self._dev_comm_uids[broadcast_options.comm_key]\n        is_receiver = broadcast_options.devices_global_rank[0] != 0\n        self.xla_comm_group.nccl_broadcast_partial_gpus(\n            key, tensors, broadcast_options.local_start_pos_list,\n            broadcast_options.n_elements, root_rank, is_receiver,\n            self.use_default_stream)\n\n    def send(self, tensors, send_options=SendOptions()):\n        \"\"\"Send a tensor to a destination gpu in the group.\n\n        Args:\n            tensors (List): the tensor to send.\n            send_options: send options.\n\n        Returns:\n            None\n        \"\"\"\n\n        buffer = tensors[0]\n        my_gpu_idx = xe.get_buffer_device_id(buffer)\n        peer_rank, peer_gpu_idx = \\\n            send_options.dst_rank, send_options.dst_gpu_index\n        comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n        self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n\n        key = 
self._dev_comm_uids[comm_key]\n        peer_p2p_rank = 0 if self.rank > peer_rank else 1\n        self.xla_comm_group.nccl_send(key, buffer, send_options.start_pos,\n                                      send_options.n_elements, peer_p2p_rank,\n                                      self.use_default_stream)\n\n    def recv(self, tensors, recv_options=RecvOptions()):\n        \"\"\"Receive a tensor from a source gpu in the group.\n\n        Args:\n            tensors (List): the received tensor.\n            recv_options: Receive options.\n\n        Returns:\n            None\n        \"\"\"\n\n        buffer = tensors[0]\n        my_gpu_idx = xe.get_buffer_device_id(buffer)\n        peer_rank, peer_gpu_idx = \\\n            recv_options.src_rank, recv_options.src_gpu_index\n        comm_key = _get_comm_key_send_recv(self.rank, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n        self._create_nccl_p2p_communicator(comm_key, my_gpu_idx, peer_rank,\n                                           peer_gpu_idx)\n\n        peer_p2p_rank = 0 if self.rank > peer_rank else 1\n        key = self._dev_comm_uids[comm_key]\n        self.xla_comm_group.nccl_recv(key, buffer, recv_options.start_pos,\n                                      recv_options.n_elements, peer_p2p_rank,\n                                      self.use_default_stream)\n\n    def record_events(self, uuids, num_devices, is_send):\n        \"\"\"Record events for all devices on send/recv streams.\"\"\"\n        self.xla_comm_group.record_events(uuids, num_devices, is_send)\n\n    def wait_events(self, uuids, num_devices, is_send):\n        \"\"\"Wait events for all devices on send/recv streams.\"\"\"\n        self.xla_comm_group.wait_events(uuids, num_devices, is_send)\n\n    def comm_wait_compute(self, is_send, is_compute, device_id):\n        self.xla_comm_group.comm_wait_compute(is_send, is_compute, device_id)\n\n    def compute_wait_comm(self, is_send, is_compute, 
device_id):\n        self.xla_comm_group.compute_wait_comm(is_send, is_compute, device_id)\n\n    # helper functions to build communicatiors\n    def _generate_group_key(self, comm_key):\n        \"\"\"Generate a unique key used to initialize the KV store.\n\n        The group key is a concatenation of the communicator key and\n        the group name, following: [comm_key]@[group_name].\n        \"\"\"\n        return comm_key + \"@\" + self.group_name\n\n    @staticmethod\n    def _destroy_store(group_key):\n        \"\"\"Destroy the KV store (Ray named actor).\n\n        Args:\n            group_key (str): the unique key to retrieve the KV store.\n\n        Returns:\n            None\n        \"\"\"\n        store_name = get_store_name(group_key)\n        try:\n            store = ray.get_actor(store_name)\n            ray.kill(store)\n        except ValueError:\n            logger.info(f\"The store with name {store_name} has been destroyed \"\n                        f\"somewhere else.\")\n\n    @staticmethod\n    def generate_nccl_uid():\n        group_uid = xla_nccl_util.get_nccl_unique_id()\n        return group_uid\n\n    def _generate_nccl_uid(self, key):\n        \"\"\"Generate an NCCL unique ID for initializing communicators.\n\n        The method will also create a KV store using Ray named actor and store\n        the NCCLUniqueID in the store. 
The store needs to be garbage collected\n        when destroying the collective group.\n\n        Args:\n            key (str): the key for storage of NCCLUniqueID.\n\n        Returns:\n            NCCLUniqueID (str): NCCL unique ID.\n        \"\"\"\n        group_uid = xla_nccl_util.get_nccl_unique_id()\n        store_name = get_store_name(key)\n        # Avoid a potential circular dependency in ray/actor.py\n        from alpa.collective.util import NCCLUniqueIDStore  # pylint: disable=import-outside-toplevel\n        self._store = NCCLUniqueIDStore.options(\n            name=store_name).remote(store_name)\n        ray.get([self._store.set_id.remote(group_uid)])\n        return group_uid\n\n    # unimplemented\n    def allreduce(self, tensors, allreduce_options=AllReduceOptions()):\n        raise NotImplementedError()\n\n    def barrier(self, barrier_options=BarrierOptions()):\n        raise NotImplementedError()\n\n    def reduce(self, tensors, reduce_options=ReduceOptions()):\n        raise NotImplementedError()\n\n    def allgather(self,\n                  tensor_lists,\n                  tensors,\n                  allgather_options=AllGatherOptions()):\n        raise NotImplementedError()\n\n    def broadcast(self, tensors, broadcast_options=BroadcastOptions()):\n        raise NotImplementedError()\n\n    def reducescatter(self,\n                      tensors,\n                      tensor_lists,\n                      reducescatter_options=ReduceScatterOptions()):\n        raise NotImplementedError()\n\n    @classmethod\n    def backend(cls):\n        return Backend.NCCL\n\n    def _rendezvous_nccl_uid(self, rank, comm_key, max_counter, nccl_uid=None):\n        group_key = self._generate_group_key(comm_key)\n        if rank == 0:\n            if nccl_uid is None:\n                nccl_uid = self._generate_nccl_uid(group_key)\n        else:\n            if nccl_uid is None:\n                rendezvous = Rendezvous(group_key)\n                
rendezvous.meet(timeout_s=3000)\n                nccl_uid = rendezvous.get_nccl_id()\n\n                # Recycle the NCCLUniqueIDStore named actor *proactively* to\n                # avoid named actor leak.\n                if rendezvous.get_access_counter() == max_counter:\n                    logger.debug(\n                        \"NCCLUniqueID has been broadcasted. The \"\n                        \"NCCLUniqueIDStore will go out of context and be \"\n                        \"destroyed.\")\n                    rendezvous.destroy_store()\n        return nccl_uid\n\n\ndef _get_comm_key_from_devices(devices):\n    \"\"\"Return a key from a list of devices for collective calls.\n\n    For example, if the tensors are on gpus 0, 1, 2, 3,\n    then the key would be \"0,1,2,3\".\n\n    Args:\n        devices(list): a list of GPU device indices\n\n    Returns:\n        str: a string represents the key to query the communicator cache.\n\n    \"\"\"\n    return \",\".join([str(d) for d in devices])\n\n\ndef _get_comm_key_send_recv(my_rank, my_gpu_idx, peer_rank, peer_gpu_idx):\n    \"\"\"Return a key given source and destination ranks for p2p tasks.\n\n    The p2p key is in the following form:\n                [min_rank]_[gpu_index]:[max_rank]_[gpu_index].\n\n    Args:\n        my_rank (int): the rank of the source process.\n        my_gpu_idx (int): the source gpu index on the process.\n        peer_rank (int): the rank of the destination process.\n        peer_gpu_idx (int): the destination gpu index on the process.\n\n    Returns:\n        comm_key (str): a string key to query the communication cache.\n    \"\"\"\n    if my_rank < peer_rank:\n        lower_key = str(my_rank) + \"_\" + str(my_gpu_idx)\n        higher_key = str(peer_rank) + \"_\" + str(peer_gpu_idx)\n    elif my_rank > peer_rank:\n        lower_key = str(peer_rank) + \"_\" + str(peer_gpu_idx)\n        higher_key = str(my_rank) + \"_\" + str(my_gpu_idx)\n    else:\n        raise RuntimeError(\n         
   \"Send and recv happens on the same process. alpa.collective \"\n            \"does not support this case as of now. Alternatively, consider \"\n            \"doing GPU to GPU memcpy?\")\n    comm_key = lower_key + \":\" + higher_key\n    return comm_key\n"
  },
  {
    "path": "alpa/collective/collective_group/xla_nccl_util.py",
    "content": "\"\"\"Code to wrap NCCL API calls from XLA extension.\"\"\"\nfrom jax._src.lib import xla_extension as xe\n\n\ndef get_nccl_runtime_version():\n    return xe.nccl_get_version()\n\n\ndef get_nccl_unique_id():\n    return xe.nccl_get_unique_id()\n"
  },
  {
    "path": "alpa/collective/const.py",
    "content": "\"\"\"\nConstants.\n\nContains constants used to set up collective groups.\n\"\"\"\nimport hashlib\nimport os\nfrom enum import Enum, auto\n\n\ndef get_store_name(group_name):\n    \"\"\"Generate the unique name for the NCCLUniqueID store (named actor).\n\n    Args:\n        group_name (str): unique user name for the store.\n    Returns:\n        str: MD5-hexlified name for the store.\n    \"\"\"\n    if not group_name:\n        raise ValueError(\"group_name is None.\")\n    hexlified_name = hashlib.md5(group_name.encode()).hexdigest()\n    return hexlified_name\n\n\nclass ENV(Enum):\n    \"\"\"Environment variables.\"\"\"\n\n    NCCL_USE_MULTISTREAM = auto(), lambda v: (v or \"True\") == \"True\"\n\n    @property\n    def val(self):\n        \"\"\"Return the output of the lambda against the system's env value.\"\"\"\n        _, default_fn = self.value  # pylint: disable=unpacking-non-sequence\n        return default_fn(os.getenv(self.name))\n"
  },
  {
    "path": "alpa/collective/requirements.txt",
    "content": "cupy-cuda111"
  },
  {
    "path": "alpa/collective/types.py",
    "content": "\"\"\"Types conversion between different backends.\"\"\"\nfrom enum import Enum\nfrom dataclasses import dataclass\nfrom datetime import timedelta\n\n_NUMPY_AVAILABLE = True\n_TORCH_AVAILABLE = False\n_CUPY_AVAILABLE = True\n\ntry:\n    import cupy as cp  # pylint: disable=unused-import\nexcept ImportError:\n    _CUPY_AVAILABLE = False\n\n\ndef cupy_available():\n    return _CUPY_AVAILABLE\n\n\ndef torch_available():\n    return _TORCH_AVAILABLE\n\n\nclass Backend:\n    \"\"\"A class to represent different backends.\"\"\"\n    NCCL = \"nccl\"\n    MPI = \"mpi\"\n    GLOO = \"gloo\"\n    UNRECOGNIZED = \"unrecognized\"\n\n    def __new__(cls, name: str):\n        backend = getattr(Backend, name.upper(), Backend.UNRECOGNIZED)\n        if backend == Backend.UNRECOGNIZED:\n            raise ValueError(f\"Unrecognized backend: '{name}'. \"\n                             \"Only NCCL is supported\")\n        if backend == Backend.MPI:\n            raise RuntimeError(\"Ray does not support MPI backend.\")\n        return backend\n\n\nclass ReduceOp(Enum):\n    SUM = 0\n    PRODUCT = 1\n    MIN = 2\n    MAX = 3\n\n\nunset_timeout_ms = timedelta(milliseconds=-1)\n\n\n@dataclass\nclass AllReduceOptions:\n    reduce_op = ReduceOp.SUM\n    timeout_ms = unset_timeout_ms\n\n\n@dataclass\nclass BarrierOptions:\n    timeout_ms = unset_timeout_ms\n\n\n@dataclass\nclass ReduceOptions:\n    reduce_op = ReduceOp.SUM\n    root_rank = 0\n    root_tensor = 0  # index for multi-gpu reduce operations\n    timeout_ms = unset_timeout_ms\n\n\n@dataclass\nclass AllGatherOptions:\n    timeout_ms = unset_timeout_ms\n\n\n#\n# @dataclass\n# class GatherOptions:\n#     root_rank = 0\n#     timeout = unset_timeout\n\n\n@dataclass\nclass BroadcastOptions:\n    comm_key = \"\"\n    world_size = 0\n    devices_ids = []\n    devices_global_rank = []\n    n_elements = 0\n    timeout_ms = unset_timeout_ms\n    local_start_pos_list = []\n\n\n@dataclass\nclass ReduceScatterOptions:\n    
reduce_op = ReduceOp.SUM\n    timeout_ms = unset_timeout_ms\n\n\n@dataclass\nclass SendOptions:\n    dst_rank = 0\n    dst_gpu_index = 0\n    n_elements = 0\n    timeout_ms = unset_timeout_ms\n    start_pos = 0\n\n\n@dataclass\nclass RecvOptions:\n    src_rank = 0\n    src_gpu_index = 0\n    n_elements = 0\n    unset_timeout_ms = unset_timeout_ms\n    start_pos = 0\n"
  },
  {
    "path": "alpa/collective/util.py",
    "content": "\"\"\"Some utility class for Collectives.\"\"\"\nimport logging\nimport ray\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.DEBUG)\n\n\n@ray.remote\nclass NCCLUniqueIDStore:\n    \"\"\"NCCLUniqueID Store as a named actor class.\n\n    Args:\n        name (str): the unique name for this named actor.\n\n    Attributes:\n        name (str): the unique name for this named actor.\n        nccl_id (str): the NCCLUniqueID held in this store.\n    \"\"\"\n\n    def __init__(self, name):\n        self.name = name\n        self.nccl_id = None\n\n        # A counter for this actor to auto-destroy itself.\n        self.access_counter = 1\n\n    def set_id(self, uid):\n        \"\"\"\n        Initialize the NCCL unique ID for this store.\n\n        Args:\n            uid (str): the unique ID generated via the NCCL get_unique_id API.\n\n        Returns:\n            str: the NCCL unique ID that was just stored.\n        \"\"\"\n        self.nccl_id = uid\n        return self.nccl_id\n\n    def get_id(self):\n        \"\"\"Get the NCCL unique ID held in this store.\"\"\"\n        if not self.nccl_id:\n            logger.debug(\"The NCCL ID has not been set yet \"\n                         f\"for store {self.name} by rank-0 process.\")\n            return None\n        else:\n            self.access_counter += 1\n            return self.nccl_id\n\n    def get_access_counter(self):\n        return self.access_counter\n\n\n@ray.remote\nclass Info:\n    \"\"\"Store the group information created via `create_collective_group`.\n\n    Note: Should be used as a NamedActor.\n    \"\"\"\n\n    def __init__(self):\n        self.ids = None\n        self.world_size = -1\n        self.rank = -1\n        self.backend = None\n        self.access_counter = 0\n\n    def set_info(self, ids, world_size, rank, backend):\n        \"\"\"Store collective information.\"\"\"\n        self.ids = ids\n        self.world_size = world_size\n        self.rank = rank\n        self.backend = backend\n\n    def 
get_info(self):\n        \"\"\"Get previously stored collective information.\"\"\"\n        self.access_counter += 1\n        return self.ids, self.world_size, self.rank, self.backend\n\n    def get_access_counter(self):\n        return self.access_counter\n"
  },
  {
    "path": "alpa/collective/worker_nccl_util.py",
    "content": "\"\"\"Unified Nccl APIs for cross-mesh resharding.\"\"\"\nfrom typing import Sequence\n\nimport alpa.collective.worker_nccl_util_cupy as cupy_impl\nimport alpa.collective.worker_nccl_util_xla as xla_impl\nfrom alpa.global_env import global_config\n\n\ndef _switch_impl(cupy_fn, xla_fn, *args):\n    if global_config.nccl_mode == \"cupy\":\n        return cupy_fn(*args)\n    elif global_config.nccl_mode == \"xla_extension\":\n        return xla_fn(*args)\n    else:\n        raise ValueError(f\"nccl mode {global_config.nccl_mode} is illegal\")\n\n\ndef send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],\n              dst_rank: int, dst_gpu_idx: int, group_name: str):\n    return _switch_impl(cupy_impl.send_tile, xla_impl.send_tile, worker, uuid,\n                        device_id, offset, dst_rank, dst_gpu_idx, group_name)\n\n\ndef recv_tile(worker, uuid: int, device_id: int,\n              indices_in_dst_tile: Sequence[slice], src_rank: int,\n              src_gpu_idx: int, group_name: str):\n    return _switch_impl(cupy_impl.recv_tile, xla_impl.recv_tile, worker, uuid,\n                        device_id, indices_in_dst_tile, src_rank, src_gpu_idx,\n                        group_name)\n\n\ndef broadcast(worker, uuid: int, comm_key: str, world_size: int,\n              devices_ids: Sequence[int], devices_global_rank: Sequence[int],\n              tensor_slices: Sequence[Sequence[slice]], group_name: str):\n    return _switch_impl(cupy_impl.broadcast, xla_impl.broadcast, worker, uuid,\n                        comm_key, world_size, devices_ids, devices_global_rank,\n                        tensor_slices, group_name)\n\n\ndef allgather(worker, uuid: int, device_ids: Sequence[int],\n              tensor_slices: Sequence[Sequence[slice]], output_slice):\n    return _switch_impl(cupy_impl.allgather, xla_impl.allgather, worker, uuid,\n                        device_ids, tensor_slices, output_slice)\n\n\ndef to_signal_buffer(jax_tensor):\n    
return _switch_impl(cupy_impl.to_signal_buffer, xla_impl.to_signal_buffer,\n                        jax_tensor)\n"
  },
  {
    "path": "alpa/collective/worker_nccl_util_cupy.py",
    "content": "\"\"\"Utility functions for device mesh workers to call nccl APIs.\"\"\"\nimport logging\nfrom typing import Sequence\n\nimport cupy\nimport jax.numpy as jnp\nfrom jax import device_put\nfrom jax._src.dlpack import from_dlpack, to_dlpack\nfrom jax._src.lib import xla_bridge as xb, xla_client as xc\nimport numpy as np\n\nimport alpa.collective as col\nfrom alpa.collective.collective_group import nccl_util\nfrom alpa.util import (jax_tensor_set, jax_tensor_index,\n                       xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer,\n                       is_continuous_subset, infer_offset_and_n_elements)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\n# Note: in this device mesh code, we will use 3 types of tensors:\n# (1) JAX high-level _DeviceArray, which is index-able, has __cuda_array__\n#     interface\n# (2) XLA low-level PyLocalBuffer, which is not index-able\n# (3) cupy array, which is an intermediate format for ray collective\ndef send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],\n              dst_rank: int, dst_gpu_idx: int, group_name: str):\n    \"\"\"\n    Send a slice of a source buffer to a target GPU.\n\n    Args:\n        uuid: the uuid of the xla buffers.\n        device_id: the device where the buffer is sent.\n        offset: the slice to be sent in the buffer.\n        dst_rank: destination rank to send.\n        dst_gpu_idx: the gpu index on the destination rank.\n        group_name: collective group name\n    \"\"\"\n    buffer = worker.buffers[uuid][device_id]\n    tensor_shape = buffer.shape\n    if is_continuous_subset(offset, tensor_shape):\n        # fast path, two cases: (1) same shape, (2) continuous subset.\n        slice_shape = tuple(ind.stop - ind.start for ind in offset)\n        to_send = xla_buffer_to_cupy(buffer)\n        if slice_shape == tensor_shape:\n            col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name)\n        else:\n            ind, 
n_elements = infer_offset_and_n_elements(offset)\n            col.send_multigpu(to_send[ind],\n                              dst_rank,\n                              dst_gpu_idx,\n                              group_name,\n                              n_elements=n_elements)\n    else:\n        # slower path, because of indexing.\n        logger.debug(\"Send goes along the slowest path. \"\n                     \"If this is for transformers, please check the resharding \"\n                     \"specs.\")\n        start_indices = tuple(o.start for o in offset)\n        slice_sizes = tuple(o.stop - o.start for o in offset)\n        src_buffer = jax_tensor_index(xla_buffer_to_jax_tensor(buffer),\n                                      start_indices, slice_sizes)\n        to_send = jax_tensor_to_cupy(src_buffer)\n        col.send_multigpu(to_send, dst_rank, dst_gpu_idx, group_name)\n\n\ndef recv_tile(worker, uuid: int, device_id: int,\n              indices_in_dst_tile: Sequence[slice], src_rank: int,\n              src_gpu_idx: int, group_name: str):\n    \"\"\"\n    Receive a slice from a source GPU and in-place write it on the target\n    buffer.\n\n    Args:\n        uuid: the uuid of the xla buffers.\n        device_id: the device where the buffer is received, used to allocate\n            tmp buffer.\n        indices_in_dst_tile: the slice index to be written on destination\n            buffer.\n        src_rank: source rank to receive from.\n        src_gpu_idx: the sender gpu index on the source rank.\n        group_name: collective group name.\n    \"\"\"\n\n    buffer = worker.buffers[uuid][device_id]\n    tensor_shape = buffer.shape\n    slice_shape = tuple(ind.stop - ind.start for ind in indices_in_dst_tile)\n    is_bool = buffer.dtype == np.bool_\n    if is_continuous_subset(indices_in_dst_tile, tensor_shape):\n        to_recv = xla_buffer_to_cupy(buffer, take_ownership=True)\n        if slice_shape == tensor_shape:\n            col.recv_multigpu(to_recv, 
src_rank, src_gpu_idx, group_name)\n        else:\n            ind, n_elements = infer_offset_and_n_elements(indices_in_dst_tile)\n            col.recv_multigpu(to_recv[ind],\n                              src_rank,\n                              src_gpu_idx,\n                              group_name,\n                              n_elements=n_elements)\n        new_buffer = cupy_to_xla_buffer(to_recv)\n    else:\n        # The following call will allocate memory and cause a few H2D and\n        # D2D kernels.\n        # See: https://github.com/alpa-projects/alpa/issues/145\n        logger.debug(\"Recv goes along the slowest path. \"\n                     \"If this is for transformers, please check the resharding \"\n                     \"specs.\")\n        tmp_buffer = device_put(jnp.ones(slice_shape, dtype=buffer.dtype),\n                                worker.local_devices[device_id])\n        to_recv = jax_tensor_to_cupy(tmp_buffer, take_ownership=True)\n        col.recv_multigpu(to_recv, src_rank, src_gpu_idx, group_name)\n        recv_tensor = cupy_to_jax_tensor(to_recv)\n        start_indices = tuple(\n            ind_in_dst.start for ind_in_dst in indices_in_dst_tile)\n\n        # The following in-place write will cause a D2D copy kernel\n        # See: https://github.com/alpa-projects/alpa/issues/144\n        # It is unavoidable, but it is better than:\n        # new_buffer = dynamic_update_slice(src_buf, update, start_indices)\n        # which is not in-place and will cause extra allocation-related\n        # kernels.\n        new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer),\n                                    recv_tensor, start_indices)\n        new_buffer = jax_tensor_to_xla_buffer(new_buffer)\n    if is_bool:\n        new_buffer = _uint8_to_bool(new_buffer)\n    worker.buffers[uuid][device_id] = new_buffer\n\n\ndef allgather(worker, uuid: int, device_ids: Sequence[int],\n              tensor_slices: Sequence[Sequence[slice]], 
output_slice):\n    cupy_buffers = []\n    communicators = worker.allgather_communicators[repr(sorted(device_ids))]\n    relative_idx = dict(zip(sorted(device_ids), range(len(device_ids))))\n    output_idx, _ = infer_offset_and_n_elements(output_slice)\n    is_bool = worker.buffers[uuid][0].dtype == np.bool_\n    nccl_util.groupStart()\n    for device_id, tensor_slice in zip(device_ids, tensor_slices):\n        xla_buffer = worker.buffers[uuid][device_id]\n        cupy_buffer = xla_buffer_to_cupy(xla_buffer, take_ownership=True)\n        ind, n_elements = infer_offset_and_n_elements(tensor_slice)\n        cupy_slice = cupy_buffer[ind]\n        cupy_output_slice = cupy_buffer[output_idx]\n        communicators[relative_idx[device_id]].allGather(\n            nccl_util.get_tensor_ptr(cupy_slice),\n            nccl_util.get_tensor_ptr(cupy_output_slice), n_elements,\n            nccl_util.get_nccl_tensor_dtype(cupy_buffer),\n            cupy.cuda.Stream.null.ptr)\n        cupy_buffers.append(cupy_buffer)\n    nccl_util.groupEnd()\n    for device_id, cupy_buffer in zip(device_ids, cupy_buffers):\n        buf = cupy_to_xla_buffer(cupy_buffer)\n        if is_bool:\n            buf = _uint8_to_bool(buf)\n        worker.buffers[uuid][device_id] = buf\n\n\ndef broadcast(worker, uuid, comm_key, world_size, devices_ids,\n              devices_global_rank, tensor_slices, group_name):\n    to_use = []\n    for_buffer = []\n    is_bool = worker.buffers[uuid][devices_ids[0]].dtype == np.bool_\n    for device_id, global_rank, tensor_slice in zip(devices_ids,\n                                                    devices_global_rank,\n                                                    tensor_slices):\n        buffer = worker.buffers[uuid][device_id]\n        tensor_shape = buffer.shape\n        slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n        if is_continuous_subset(tensor_slice, tensor_shape):\n            # fast path, two cases: (1) same shape, (2) 
continuous subset.\n            tmp = xla_buffer_to_cupy(buffer)\n            if slice_shape != tensor_shape:\n                ind, _ = infer_offset_and_n_elements(tensor_slice)\n                to_use.append(tmp[ind])\n            else:\n                to_use.append(tmp)\n            for_buffer.append(tmp)\n        else:\n            tmp = None\n            if global_rank == 0:\n                start_indices = tuple(o.start for o in tensor_slice)\n                tmp = jax_tensor_index(xla_buffer_to_jax_tensor(buffer),\n                                       start_indices, slice_shape)\n                tmp = jax_tensor_to_cupy(tmp)\n            else:\n                tmp = device_put(jnp.ones(slice_shape, dtype=buffer.dtype),\n                                 worker.local_devices[device_id])\n                tmp = jax_tensor_to_cupy(tmp, take_ownership=True)\n            to_use.append(tmp)\n            for_buffer.append(tmp)\n\n    _, n_elements = infer_offset_and_n_elements(tensor_slices[0])\n    col.broadcast_partialgpu(to_use, n_elements, comm_key, world_size,\n                             devices_ids, devices_global_rank, group_name)\n\n    for for_buffer_tensor, device_id, global_rank, tensor_slice in zip(\n            for_buffer, devices_ids, devices_global_rank, tensor_slices):\n        if global_rank == 0:\n            continue\n        buffer = worker.buffers[uuid][device_id]\n        tensor_shape = buffer.shape\n        slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n        if is_continuous_subset(tensor_slice, tensor_shape):\n            new_buffer = cupy_to_xla_buffer(for_buffer_tensor)\n        else:\n            recv_tensor = cupy_to_jax_tensor(for_buffer_tensor)\n            start_indices = tuple(\n                ind_in_dst.start for ind_in_dst in tensor_slice)\n            new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer),\n                                        recv_tensor, start_indices)\n            new_buffer = 
jax_tensor_to_xla_buffer(new_buffer)\n        if is_bool:\n            new_buffer = _uint8_to_bool(new_buffer)\n        worker.buffers[uuid][device_id] = new_buffer\n\n\ndef to_signal_buffer(jax_tensor):\n    return jax_tensor_to_cupy(jax_tensor, take_ownership=True)\n\n\ndef xla_buffer_to_cupy(xla_buf, take_ownership=False):\n    \"\"\"Convert an xla buffer directly to cupy, w/o transitioning from jax\n    buffer.\"\"\"\n    return cupy.fromDlpack(\n        xc._xla.buffer_to_dlpack_managed_tensor(  # pylint: disable=protected-access\n            xla_buf,\n            take_ownership=take_ownership))\n\n\ndef cupy_to_xla_buffer(tensor):\n    \"\"\"Convert cupy tensors to XLA buffers.\"\"\"\n    if isinstance(tensor, list):\n        return list(map(cupy_to_xla_buffer, tensor))\n    cpu_backend = xb.get_backend(\"cpu\")\n    try:\n        gpu_backend = xb.get_backend(\"gpu\")\n    except RuntimeError:\n        gpu_backend = None\n    buf = xc._xla.dlpack_managed_tensor_to_buffer(  # pylint: disable=protected-access\n        tensor.toDlpack(), cpu_backend, gpu_backend)\n    return buf\n\n\ndef jax_tensor_to_cupy(tensors, take_ownership=False):\n    \"\"\"Convert a Jax DeviceArray to cupy tensor; zero copy.\"\"\"\n    if isinstance(tensors, list):\n        return list(map(jax_tensor_to_cupy, tensors))\n    return cupy.fromDlpack(to_dlpack(tensors, take_ownership=take_ownership))\n\n\ndef cupy_to_jax_tensor(tensors):\n    \"\"\"Convert cupy tensors to JAX tensors.\"\"\"\n    if isinstance(tensors, list):\n        return list(map(cupy_to_jax_tensor, tensors))\n    return from_dlpack(tensors.toDlpack())\n\n\n# in XLA pred(bool) and uint8 are different, but xla->dlpack->xla\n# turns a bool into uint8. This implementation is slow.\ndef _uint8_to_bool(xla_buffer):\n    buf = xla_buffer_to_jax_tensor(xla_buffer).astype(np.bool_)\n    return jax_tensor_to_xla_buffer(buf)\n"
  },
  {
    "path": "alpa/collective/worker_nccl_util_xla.py",
    "content": "\"\"\"Utility functions for device mesh workers to call nccl APIs.\"\"\"\nimport logging\nfrom typing import Sequence\n\nimport jax.numpy as jnp\nfrom jax import device_put\nfrom jax._src.lib import xla_extension as xe\nimport numpy as np\n\nimport alpa.collective as col\nfrom alpa.util import (jax_tensor_set, jax_tensor_index,\n                       xla_buffer_to_jax_tensor, jax_tensor_to_xla_buffer,\n                       is_continuous_subset, infer_offset_and_n_elements,\n                       infer_start_pos_and_n_elements)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef send_tile(worker, uuid: int, device_id: int, offset: Sequence[slice],\n              dst_rank: int, dst_gpu_idx: int, group_name: str):\n    buffer = worker.buffers[uuid][device_id]\n    tensor_shape = buffer.shape\n    if is_continuous_subset(offset, tensor_shape):\n        start_pos, n_elements = (infer_start_pos_and_n_elements(\n            tensor_shape, offset))\n        col.send_multigpu(buffer,\n                          dst_rank,\n                          dst_gpu_idx,\n                          group_name,\n                          start_pos=start_pos,\n                          n_elements=n_elements)\n    else:\n        # slower path, because of indexing.\n        logger.debug(\"Send goes along the slowest path. 
\"\n                     \"If this is for transformers, please check the resharding \"\n                     \"specs.\")\n        start_indices = tuple(o.start for o in offset)\n        slice_sizes = tuple(o.stop - o.start for o in offset)\n        src_buffer = jax_tensor_index(xla_buffer_to_jax_tensor(buffer),\n                                      start_indices, slice_sizes)\n        to_send = jax_tensor_to_xla_buffer(src_buffer)\n        n_elements = np.prod(slice_sizes)\n        # dummy_compute_on_default_stream(device_id)\n\n        # let send stream wait for compute stream\n        col.comm_wait_compute(group_name, True, True, device_id)\n\n        col.send_multigpu(to_send,\n                          dst_rank,\n                          dst_gpu_idx,\n                          group_name,\n                          start_pos=0,\n                          n_elements=n_elements)\n\n\ndef recv_tile(worker, uuid: int, device_id: int,\n              indices_in_dst_tile: Sequence[slice], src_rank: int,\n              src_gpu_idx: int, group_name: str):\n    buffer = worker.buffers[uuid][device_id]\n    tensor_shape = buffer.shape\n    slice_shape = tuple(ind.stop - ind.start for ind in indices_in_dst_tile)\n    if is_continuous_subset(indices_in_dst_tile, tensor_shape):\n        start_pos, n_elements = infer_start_pos_and_n_elements(\n            tensor_shape, indices_in_dst_tile)\n        col.recv_multigpu(buffer,\n                          src_rank,\n                          src_gpu_idx,\n                          group_name,\n                          start_pos=start_pos,\n                          n_elements=n_elements)\n    else:\n        tmp_buffer = device_put(jnp.ones(slice_shape, dtype=buffer.dtype),\n                                worker.local_devices[device_id])\n        to_recv = jax_tensor_to_xla_buffer(tmp_buffer)\n        n_elements = np.prod(slice_shape)\n        # let recv stream wait for d2d stream\n        col.comm_wait_compute(group_name, 
False, False, device_id)\n        # let recv stream wait for compute stream\n        col.comm_wait_compute(group_name, False, True, device_id)\n\n        col.recv_multigpu(to_recv,\n                          src_rank,\n                          src_gpu_idx,\n                          group_name,\n                          start_pos=0,\n                          n_elements=n_elements)\n        # let compute stream wait for recv stream\n        col.compute_wait_comm(group_name, False, True, device_id)\n\n        start_indices = tuple(\n            ind_in_dst.start for ind_in_dst in indices_in_dst_tile)\n        new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer),\n                                    xla_buffer_to_jax_tensor(to_recv),\n                                    start_indices)\n        worker.buffers[uuid][device_id] = jax_tensor_to_xla_buffer(new_buffer)\n\n\ndef allgather(worker, uuid: int, device_ids: Sequence[int],\n              tensor_slices: Sequence[Sequence[slice]], output_slice):\n    # FIXME: handle the case that local device ids are the same but global ids\n    # are different\n    communicators = worker.allgather_communicators[repr(sorted(device_ids))]\n    tensor_shape = worker.buffers[uuid][device_ids[0]].shape\n    global_start_pos, _ = infer_start_pos_and_n_elements(\n        tensor_shape, output_slice)\n\n    buffers = []\n    local_start_pos_list = []\n    for device_id, tensor_slice in zip(device_ids, tensor_slices):\n        xla_buffer = worker.buffers[uuid][device_id]\n        start_pos, _ = infer_start_pos_and_n_elements(tensor_shape,\n                                                      tensor_slice)\n        buffers.append(xla_buffer)\n        local_start_pos_list.append(start_pos)\n\n    _, local_n_elements = infer_offset_and_n_elements(tensor_slices[0])\n    xe.nccl_local_all_gather(communicators, buffers, local_start_pos_list,\n                             global_start_pos, local_n_elements)\n\n    for device_id, buf in 
zip(device_ids, buffers):\n        worker.buffers[uuid][device_id] = buf\n\n\ndef broadcast(worker, uuid, comm_key, world_size, devices_ids,\n              devices_global_rank, tensor_slices, group_name):\n    buffers = []\n    local_start_pos_list = []\n    _, n_elements = infer_offset_and_n_elements(tensor_slices[0])\n    for device_id, global_rank, tensor_slice in zip(devices_ids,\n                                                    devices_global_rank,\n                                                    tensor_slices):\n        buffer = worker.buffers[uuid][device_id]\n        tensor_shape = buffer.shape\n        slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n        if is_continuous_subset(tensor_slice, tensor_shape):\n            # fast path, two cases: (1) same shape, (2) continuous subset.\n            start_pos, _ = infer_start_pos_and_n_elements(\n                tensor_shape, tensor_slice)\n            local_start_pos_list.append(start_pos)\n            buffers.append(buffer)\n        else:\n            tmp = None\n            if global_rank == 0:\n                start_indices = tuple(o.start for o in tensor_slice)\n                tmp = jax_tensor_index(xla_buffer_to_jax_tensor(buffer),\n                                       start_indices, slice_shape)\n            else:\n                tmp = device_put(jnp.ones(slice_shape, dtype=buffer.dtype),\n                                 worker.local_devices[device_id])\n            # let communicate stream wait for compute stream\n            is_send = global_rank == 0\n            col.comm_wait_compute(group_name, is_send, True, device_id)\n            # let communicate stream wait for d2d stream\n            col.comm_wait_compute(group_name, is_send, False, device_id)\n\n            local_start_pos_list.append(0)\n            buffers.append(jax_tensor_to_xla_buffer(tmp))\n\n    col.broadcast_partialgpu(buffers, n_elements, comm_key, world_size,\n                             
devices_ids, devices_global_rank, group_name,\n                             local_start_pos_list)\n\n    for xla_buffer, device_id, global_rank, tensor_slice in zip(\n            buffers, devices_ids, devices_global_rank, tensor_slices):\n        if global_rank == 0:\n            continue\n        buffer = worker.buffers[uuid][device_id]\n        tensor_shape = buffer.shape\n        slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n        if is_continuous_subset(tensor_slice, tensor_shape):\n            new_buffer = xla_buffer\n        else:\n            start_indices = tuple(\n                ind_in_dst.start for ind_in_dst in tensor_slice)\n            # let compute stream wait for communicator stream\n            is_send = global_rank == 0\n            col.compute_wait_comm(group_name, is_send, True, device_id)\n            new_buffer = jax_tensor_set(xla_buffer_to_jax_tensor(buffer),\n                                        xla_buffer_to_jax_tensor(xla_buffer),\n                                        start_indices)\n            new_buffer = jax_tensor_to_xla_buffer(new_buffer)\n        worker.buffers[uuid][device_id] = new_buffer\n\n\nto_signal_buffer = jax_tensor_to_xla_buffer\n"
  },
  {
    "path": "alpa/create_state_parallel.py",
    "content": "\"\"\"Compile executables for creating training state distributedly.\"\"\"\nfrom collections import defaultdict, deque\nfrom typing import Sequence, Optional\n\nfrom jax.core import Var\nfrom jax.interpreters import pxla\nfrom jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef\nimport numpy as np\n\nfrom alpa.device_mesh import ReplicatedDistributedArray, PhysicalDeviceMeshGroup\nfrom alpa.global_env import global_config\nfrom alpa.mesh_executable import (NormalMeshDriverExecutable,\n                                  GradAccMeshDriverExecutable)\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable_internal\nfrom alpa.pipeline_parallel.layer_construction import add_pipeline_marks_for_sliced_eqns\nfrom alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable\nfrom alpa.pipeline_parallel.runtime_emitter import PipeshardConfig\nfrom alpa.pipeline_parallel.stage_construction import UniformStageOption\nfrom alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,\n                                               AutoShardingOption)\nfrom alpa.util import jaxpr_to_hlo, trace_jaxpr_with_micro_batch\n\n\nclass CreateStateExecutable(PipeshardDriverExecutable):\n    \"\"\"\n    A distributed executable that creates a training state for a function\n    parallelized by PipeshardParallel.\n    \"\"\"\n\n    def __init__(self,\n                 mesh_group: PhysicalDeviceMeshGroup,\n                 pipeshard_config: PipeshardConfig,\n                 target_placement_specs: Sequence[PlacementSpec],\n                 in_tree: PyTreeDef,\n                 out_tree: Optional[PyTreeDef] = None,\n                 static_argnums: Optional[Sequence[int]] = None):\n        super().__init__(mesh_group=mesh_group,\n                         pipeshard_config=pipeshard_config,\n                         num_batch=1,\n                         layer_option=None,\n     
                    in_tree=in_tree,\n                         out_tree=out_tree,\n                         static_argnums=static_argnums)\n        self.target_placement_specs = target_placement_specs\n\n    def launch_on_driver(self, *args):\n        outputs = super().launch_on_driver(*args)\n\n        # Handle the creation of ReplicatedDistributedArray\n        for idx, (array,\n                  spec) in enumerate(zip(outputs, self.target_placement_specs)):\n            assert array.device_mesh.mesh_id == spec.mesh_ids[0]\n            assert array.indices == pxla.spec_to_indices(\n                array.shape, spec.sharding_specs[0])\n\n            if len(spec.mesh_ids) > 1:\n                meshes = tuple(self.mesh_group[i] for i in spec.mesh_ids)\n                distributed_arrays = [array]\n                for mesh_id, sharding_spec in zip(spec.mesh_ids[1:],\n                                                  spec.sharding_specs[1:]):\n                    indices = pxla.spec_to_indices(array.shape, sharding_spec)\n                    dis_array = self.mesh_group[mesh_id].shard_args_to_arrays(\n                        (array.aval,), (indices,), (sharding_spec,),\n                        (np.asarray(array),))[0]\n                    distributed_arrays.append(dis_array)\n                outputs[idx] = ReplicatedDistributedArray(\n                    meshes, distributed_arrays)\n\n        return outputs\n\n\ndef compile_create_state_executable(fun, in_tree, out_tree_thunk,\n                                    static_argnums, donated_invars, train_step,\n                                    other_args, *avals):\n    # Trace to get jaxpr and HloModule\n    closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, [False] * len(avals), 1,\n                                                   avals)\n    out_avals = [v.aval for v in closed_jaxpr.jaxpr.outvars]\n    jaxpr = closed_jaxpr.jaxpr\n\n    name = f\"{fun.__name__}_create_state_parallel\"\n    hlo = jaxpr_to_hlo(name, 
closed_jaxpr, donated_invars)\n\n    # Compile train_step to get the placement specs.\n    out_tree = out_tree_thunk()\n    state_aval = tree_unflatten(out_tree, out_avals)\n    executable = train_step.get_executable(state_aval, other_args)\n    placement_specs = executable.get_input_placement_specs()[0]\n    placement_specs, _ = tree_flatten(placement_specs)\n\n    if (not isinstance(executable, NormalMeshDriverExecutable) and\n            global_config.backend == \"tpu\"):\n        raise NotImplementedError(f\"{type(executable)} is not supported in tpu\")\n    if isinstance(executable,\n                  (NormalMeshDriverExecutable, GradAccMeshDriverExecutable)):\n        sharding_protos = []\n        for spec in placement_specs:\n            assert len(spec.mesh_ids) == 1\n            sharding_protos.append(spec.sharding_specs[0].sharding_proto())\n\n        physical_mesh = executable.physical_mesh\n\n        # Run sharding propagation\n        hlo.set_output_shardings(sharding_protos)\n        hlo, stage_plan = run_auto_sharding_pass(\n            hlo,\n            physical_mesh.get_logical_mesh(\n                executable.stage_plan.logical_mesh_shape), \"single\", 1,\n            AutoShardingOption(enable_auto_sharding=False))\n\n        return NormalMeshDriverExecutable(physical_mesh, hlo, stage_plan, avals,\n                                          out_avals, [False] * len(avals),\n                                          static_argnums, in_tree, out_tree)\n    else:\n        # Construct a new pipelined jaxpr\n        outvars = jaxpr.outvars\n\n        var2mesh = {}  # Dict[var -> mesh_id]\n        eqn2mesh = {}  # Dict[eqn_idx -> mesh_id]\n\n        output_shardings = []\n        for var, spec in zip(outvars, placement_specs):\n            if isinstance(var, Var):\n                var2mesh[var] = spec.mesh_ids[0]\n            output_shardings.append(spec.sharding_specs[0])\n\n        num_meshes = len(executable.mesh_group)\n\n        
propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh)\n        sliced_eqns = slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh,\n                                                       num_meshes)\n        new_jaxpr = add_pipeline_marks_for_sliced_eqns(closed_jaxpr,\n                                                       sliced_eqns)\n\n        # Compile a pipeshard executable with predefined output shardings\n        pipeshard_config = compile_pipeshard_executable_internal(\n            new_jaxpr, None, 1, [False] * len(avals), [False] * len(avals),\n            executable.mesh_group.parent, 1, \"inference\",\n            AutoShardingOption(enable_auto_sharding=False),\n            UniformStageOption(), name, None, output_shardings, None, None)\n\n        return CreateStateExecutable(mesh_group=executable.mesh_group,\n                                     pipeshard_config=pipeshard_config,\n                                     target_placement_specs=placement_specs,\n                                     in_tree=in_tree,\n                                     out_tree=out_tree_thunk(),\n                                     static_argnums=static_argnums)\n\n\ndef propagate_mesh_assignment(jaxpr, var2mesh, eqn2mesh):\n    \"\"\"Propagate mesh assignment for all variables and equations.\n\n    Note that this is different from the propagation in apply_grad.\n    create_state_parallel: always assign one equation to one mesh.\n      If one equation is used by multiple meshes, use send/recv to\n      pass the value.\n    apply_grad: can assign one equation to multiple meshes.\n      If one equation is used by multiple meshes, replicate the\n      computation on all meshes.\n    \"\"\"\n    def_eqn = {}  # Dict[var -> eqn_idx]\n\n    for idx, eqn in enumerate(jaxpr.eqns):\n        for var in eqn.outvars:\n            def_eqn[var] = idx\n\n    mesh2vars = defaultdict(list)\n    for var, mesh_idx in var2mesh.items():\n        mesh2vars[mesh_idx].append(var)\n\n    mesh_indices = 
list(mesh2vars.keys())\n    mesh_indices.sort()\n\n    for mesh_idx in mesh_indices:\n        for var in mesh2vars[mesh_idx]:\n            eqn_idx = def_eqn[var]\n            if eqn_idx not in eqn2mesh:\n                # Propagate from the definition equation to\n                # all related equations\n                queue = deque((eqn_idx,))\n\n                while queue:\n                    eqn_idx = queue.popleft()\n                    eqn2mesh[eqn_idx] = mesh_idx\n\n                    for var in jaxpr.eqns[eqn_idx].invars:\n                        if isinstance(var, Var):\n                            eqn_idx = def_eqn[var]\n                            if eqn_idx not in eqn2mesh:\n                                queue.append(eqn_idx)\n\n\ndef slice_jaxpr_with_mesh_assignment(jaxpr, eqn2mesh, num_meshes):\n    sliced_eqns = [[] for _ in range(num_meshes)]\n\n    for idx, eqn in enumerate(jaxpr.eqns):\n        if idx in eqn2mesh:\n            sliced_eqns[eqn2mesh[idx]].append(eqn)\n\n    return sliced_eqns\n"
  },
  {
    "path": "alpa/data_loader.py",
    "content": "\"\"\"\"Distributed data loaders for loading data into device meshes.\"\"\"\nimport collections\nimport itertools\n\nimport jax\nfrom jax.interpreters import pxla\nimport numpy as np\nimport ray\n\nfrom alpa.device_mesh import (DistributedArray, LocalPhysicalDeviceMesh,\n                              get_global_physical_mesh,\n                              create_remote_array_refs)\n\n\nclass DataLoader:\n    \"\"\"A driver-only dataloader that loads data on the driver process and\n    sends the data to all workers.\"\"\"\n\n    def __init__(self, input_iter, placement_specs, prefetch_size=1):\n        self.input_iter = input_iter\n        self.prefetch_size = prefetch_size\n\n        self.physical_mesh = get_global_physical_mesh()\n        self.avals = []\n        self.indices = []\n        self.sharding_specs = []\n        for ps in jax.tree_util.tree_leaves(placement_specs):\n            assert len(ps.mesh_ids) == 1\n            assert ps.mesh_ids[0] == self.physical_mesh.mesh_id\n\n            self.avals.append(ps.aval)\n            self.sharding_specs.append(ps.sharding_specs[0])\n            self.indices.append(\n                tuple(ps.sharding_specs[0].indices(ps.aval.shape).flatten()))\n\n        self.queue = collections.deque()\n\n    def enqueue(self, num_batches):\n        for batch in itertools.islice(self.input_iter, num_batches):\n            flatten_args, tree = jax.tree_flatten(batch)\n            new_args = self.physical_mesh.shard_args_to_arrays(\n                self.avals, self.indices, self.sharding_specs, flatten_args)\n            self.queue.append(jax.tree_unflatten(tree, new_args))\n\n    def __iter__(self):\n        if self.prefetch_size:\n            self.enqueue(self.prefetch_size)\n            while self.queue:\n                yield self.queue.popleft()\n                self.enqueue(1)\n        else:\n            while True:\n                self.enqueue(1)\n                if self.queue:\n                    yield 
self.queue.popleft()\n                else:\n                    break\n\n\n# The global executable and buffer counter.\nmesh_data_loader_counter = 0\n\n\ndef next_mesh_data_loader_uuid():\n    \"\"\"Return the next uuid of a mesh data loader.\"\"\"\n    global mesh_data_loader_counter\n    mesh_data_loader_counter = (mesh_data_loader_counter + 1) % (1 << 60)\n    return mesh_data_loader_counter\n\n\ndef get_num_devices_for_whole_batch(sharding_spec, batch_dim=0):\n    \"\"\"Get the number of devices for a whole batch.\"\"\"\n    num_devices = 1\n    for sharding in sharding_spec.sharding:\n        if isinstance(sharding, pxla.Chunked):\n            num_devices *= np.prod(sharding.chunks)\n\n    for assignment in sharding_spec.mesh_mapping:\n        if isinstance(assignment, pxla.Replicated):\n            num_devices *= assignment.replicas\n\n    sharding = sharding_spec.sharding[batch_dim]\n\n    num_data_chunk = 1\n    if isinstance(sharding, pxla.Chunked):\n        num_data_chunk = np.prod(sharding.chunks)\n\n        # Assert the data chunk is mapped to the first dim of device mesh\n        for assignment in sharding_spec.mesh_mapping:\n            if isinstance(assignment, pxla.ShardedAxis):\n                assert assignment.axis == 0\n                break\n\n    return num_devices / num_data_chunk\n\n\nclass MeshDriverDataLoader:\n    \"\"\"The driver part of a distributed data loader. 
The driver part creates\n    distributed arrays and sends commands to let workers load the data in\n    parallel.\n\n    Args:\n        batch_size: The global batch size.\n        num_samples: The number of samples in the whole dataset.\n        input_iter_func: A function with the following signature.\n          func(start: int, end: int, batch_size: int) -> Iterator\n          It returns dataset[start:end] one batch by one batch.\n        placement_specs: The placement specs of batch arguments.\n        prefetch_size: The number of batches to prefetch.\n        repeat: If true, repeat the dataset indefinitely. The\n          returned iterator will never stop.\n\n    Note:\n        Currently, this only works for ShardParallel without\n        gradient accumulation.\n    \"\"\"\n\n    def __init__(self,\n                 batch_size,\n                 num_samples,\n                 input_iter_func,\n                 placement_specs,\n                 prefetch_size=1,\n                 repeat=False):\n        self.repeat = repeat\n\n        physical_mesh = get_global_physical_mesh()\n        assert not isinstance(physical_mesh, LocalPhysicalDeviceMesh), (\n            \"Please use alpa.DataLoader instead of alpa.MeshWorkerDataLoader \"\n            \"for local physical device mesh.\")\n\n        avals = []\n        sharding_specs = []\n        indices = []\n        for ps in jax.tree_util.tree_leaves(placement_specs):\n            avals.append(ps.aval)\n            assert len(ps.mesh_ids) == 1\n            assert ps.mesh_ids[0] == physical_mesh.mesh_id\n            sharding_specs.append(ps.sharding_specs[0])\n            indices.append(np.ravel(ps.sharding_specs[0].indices(\n                ps.aval.shape)))\n\n        self.uuid = next_mesh_data_loader_uuid()\n        self.physical_mesh = physical_mesh\n\n        # Create output DistributedArray\n        ary_refs, ary_uuids = create_remote_array_refs(physical_mesh,\n                                                      
 len(avals))\n        self.output_uuids = ary_uuids\n        self.output_arrays = []\n        for i in range(len(avals)):\n            self.output_arrays.append(\n                DistributedArray(physical_mesh, avals[i], sharding_specs[i],\n                                 ary_refs[i]))\n\n        # Create worker part data loaders\n        self.worker_data_loaders = []\n        self.num_batches = num_samples // batch_size\n\n        # Adjust sharding indices\n        # Basic idea:\n        # 1. For each host, assign a contiguous range of the whole dataset to it\n        # 2. Adjust the per-device view of sharding indices to per-host view.\n        for i in range(physical_mesh.num_hosts):\n            host_indices = []\n            for j in range(len(avals)):\n                batch_size = avals[j].shape[0]\n                num_devices_for_one_batch = get_num_devices_for_whole_batch(\n                    sharding_specs[j])\n                num_hosts_for_one_batch = max(\n                    1, num_devices_for_one_batch /\n                    physical_mesh.num_devices_per_host)\n                assert float(num_hosts_for_one_batch).is_integer(\n                ), f\"{num_hosts_for_one_batch}\"\n                num_hosts_for_one_batch = int(num_hosts_for_one_batch)\n\n                batch_size_per_host = batch_size / (physical_mesh.num_hosts /\n                                                    num_hosts_for_one_batch)\n                assert batch_size_per_host.is_integer()\n                batch_size_per_host = int(batch_size_per_host)\n\n                num_samples_per_host = self.num_batches * batch_size_per_host\n\n                start = (i // num_hosts_for_one_batch) * num_samples_per_host\n                end = (\n                    (i // num_hosts_for_one_batch) + 1) * num_samples_per_host\n\n                host_indices.append([])\n                for k in range(physical_mesh.num_devices_per_host):\n                    device_id = i * 
physical_mesh.num_devices_per_host + k\n                    tmp_indices = list(indices[j][device_id])\n                    offset = i // num_hosts_for_one_batch * batch_size_per_host\n                    if tmp_indices[0].start is not None:\n                        tmp_indices[0] = slice(tmp_indices[0].start - offset,\n                                               tmp_indices[0].stop - offset,\n                                               tmp_indices[0].step)\n                    host_indices[-1].append(tuple(tmp_indices))\n\n            args = (input_iter_func, (start, end, batch_size_per_host),\n                    self.output_uuids, host_indices, prefetch_size)\n            physical_mesh.workers[i].put_data_loader.remote(self.uuid, *args)\n\n    def __iter__(self):\n        # Create the iterators on workers\n        for w in self.physical_mesh.workers:\n            w.data_loader_iter.remote(self.uuid)\n\n        # Yield the next batch\n        while True:\n            for _ in range(self.num_batches):\n                for w in self.physical_mesh.workers:\n                    w.data_loader_next.remote(self.uuid)\n                for a in self.output_arrays:\n                    a.flush()\n                yield self.output_arrays\n\n            if not self.repeat:\n                break\n\n    def __del__(self):\n        physical_mesh = self.physical_mesh\n        if physical_mesh.workers is None or not ray.is_initialized():\n            return\n\n        for i in range(physical_mesh.num_hosts):\n            physical_mesh.workers[i].delete_data_loader.remote(self.uuid)\n\n\nclass MeshWorkerDataLoader:\n    \"\"\"The worker part of a distributed data loader. 
The driver part creates\n    distributed arrays and sends commands to let workers load the data in\n    parallel.\"\"\"\n\n    def __init__(self, mesh_host_worker, input_iter_func, input_iter_args,\n                 output_uuids, shard_indices, prefetch_size):\n        self.input_iter = input_iter_func(*input_iter_args)\n        self.output_uuids = output_uuids\n        self.shard_indices = shard_indices\n        self.prefetch_size = prefetch_size\n\n        self.devices = mesh_host_worker.local_devices\n        self.buffers = mesh_host_worker.buffers\n\n        # A queue for prefetching\n        self.queue = collections.deque()\n\n    def enqueue(self, num_batches):\n        for args in itertools.islice(self.input_iter, num_batches):\n            batch = []\n            for i in range(len(args)):\n                shards = [\n                    args[i][self.shard_indices[i][k]]\n                    for k in range(len(self.devices))\n                ]\n                buffers = [\n                    jax.device_put(x, d) for x, d in zip(shards, self.devices)\n                ]\n                batch.append(buffers)\n\n            self.queue.append(batch)\n\n    def pop_left(self):\n        batch = self.queue.popleft()\n        for i, shards in enumerate(batch):\n            self.buffers[self.output_uuids[i]] = shards\n\n    def __iter__(self):\n        if self.prefetch_size:\n            self.enqueue(self.prefetch_size)\n            while self.queue:\n                yield self.pop_left()\n                self.enqueue(1)\n        else:\n            while True:\n                self.enqueue(1)\n                if self.queue:\n                    yield self.pop_left()\n                else:\n                    break\n"
  },
  {
    "path": "alpa/device_mesh.py",
    "content": "# pylint: disable=protected-access\n\"\"\"The device mesh runtime that manages buffers and runs computation\ndistributedly.\n\nThe hierarchy of classes defined in this file:\n\nDeviceCluster  (the whole ray cluster)\n|\nPhysicalDeviceMeshGroup  (multiple device meshes)\n|\nPhysicalDeviceMesh  (one device mesh)\n|\nMeshHostWorker  (one host in a device mesh)\n\nBesides, we have two additional classes: VirtualPhysicalMesh and\nLogicalDeviceMesh. They are only used during compilation time. They are used to\nmanipulate meshes flexibly without allocating real resources during compilation\ntime.\n\"\"\"\nfrom abc import ABC, abstractmethod\nimport asyncio\nfrom collections import defaultdict, namedtuple\nfrom collections.abc import Iterable\nimport logging\nfrom operator import attrgetter\nimport os\nimport pickle\nimport shutil\nimport threading\nimport time\nfrom typing import Any, List, Union, Sequence, Tuple, Optional\n\nfrom jax import core, xla, device_put\nfrom jax._src.api import ShapeDtypeStruct\nfrom jax._src.lib import xla_bridge as xb, xla_extension as xe\nfrom jax._src.tree_util import tree_leaves\nfrom jax.abstract_arrays import array_types\nfrom jax.core import ShapedArray\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import (ShardingSpec, _hashable_index,\n                                   ShardedDeviceArray, Index)\nfrom jax.lib import xla_client\nimport jax.numpy as jnp\nimport numpy as np\nimport ray\nfrom ray.util.placement_group import remove_placement_group\n\nfrom alpa import mesh_profiling\nimport alpa.collective as col\nfrom alpa.global_env import global_config\nfrom alpa.monkey_patch import set_override_backend\nfrom alpa.shard_parallel.auto_sharding import (LogicalDeviceMesh)\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.timer import timers, tracer\nfrom alpa.util import (benchmark_func, list_gpu_info, OrderedSet,\n                       update_jax_platform, is_ray_node_resource,\n                    
   try_import_ray_worker, create_placement_group,\n                       get_bundle_idx, retrieve_placement_group, get_bundle2ip,\n                       check_server_port)\n\nray_worker = try_import_ray_worker()\n\nif global_config.backend == \"gpu\" and global_config.has_cuda:\n    from alpa.collective import worker_nccl_util\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\nReshardingTileSpec = namedtuple(\"ReshardingTileSpec\",\n                                [\"offset\", \"rank\", \"gpu_idx\"])\nReshardingSendSpec = namedtuple(\"ReshardingSendSpec\",\n                                [\"device_id\", \"tile_spec\"])\nReshardingSendTask = namedtuple(\"ReshardingSendTask\",\n                                [\"tile_specs\", \"group_name\"])\nReshardingRecvSpec = namedtuple(\"ReshardingRecvSpec\",\n                                [\"device_id\", \"shape\", \"dtype\", \"tile_specs\"])\nReshardingRecvTask = namedtuple(\"ReshardingRecvTask\",\n                                [\"recv_specs\", \"group_name\"])\nReshardingBroadcastSpec = namedtuple(\"ReshardingBroadcastSpec\", [\n    \"comm_key\", \"world_size\", \"devices_ids\", \"devices_global_rank\",\n    \"tensor_slices\", \"recv_tile_shape\", \"dtype\"\n])\nReshardingBroadcastTask = namedtuple(\"ReshardingBroadcastTask\",\n                                     [\"broadcast_specs\", \"group_name\"])\n\n\n########################################\n# Ray Workers\n########################################\nclass DaemonMoveWorker:\n    \"\"\"\n        A ray actor that moves local checkpoint into the shared\n        filesystem in the background.\n    \"\"\"\n\n    def move(self, from_dir: str, to_dir: str):\n        os.makedirs(to_dir, exist_ok=True)\n        for file in os.listdir(from_dir):\n            from_path = os.path.join(from_dir, file)\n            to_path = os.path.join(to_dir, file)\n            shutil.move(from_path, to_path)\n\n    def sync(self):\n        \"\"\"Noop function used to 
synchronize.\"\"\"\n\n\nclass MeshHostWorker:\n    \"\"\"\n    A ray actor that manages the xla computation and buffers on a single host.\n    \"\"\"\n\n    def __init__(self, server_address: str, num_hosts: int, host_id: int,\n                 mesh_id: int, move_worker: DaemonMoveWorker,\n                 runtime_random_seed: int, worker_global_config: dict):\n        self.num_hosts = num_hosts\n        self.host_id = host_id\n        self.mesh_id = mesh_id\n        self.move_worker = move_worker\n        self.distributed_client = (\n            xla_client._xla.get_distributed_runtime_client(\n                server_address, host_id, use_coordination_service=False))\n        logger.debug(\n            f\"{host_id}: Trying to connect to xla runtime at {server_address}\")\n        self.distributed_client.connect()\n        logger.debug(\n            f\"{host_id}: Success to connect to xla runtime at {server_address}\")\n\n        # Set global config to follow the driver\n        global_config.update_worker_config(worker_global_config)\n        if global_config.backend == \"gpu\":\n            self.backend = xla_client.make_gpu_client(self.distributed_client,\n                                                      node_id=host_id)\n        else:\n            raise NotImplementedError(\n                f\"backend {global_config.backend} is not supported\")\n        # Monkey patch the backend\n        set_override_backend(self.backend)\n        self.local_devices = self.backend.local_devices()\n        self.num_devices = len(self.local_devices)\n        if global_config.enable_overlapping:\n            xe.set_num_device_on_host(self.num_devices)\n\n        self.buffers = {}  # Dict[uuid -> Sequence[DeviceArray]]\n        self.executables = {}  # Dict[uuid -> MeshWorkerExecutable]\n\n        self.send_tasks = {}  # Dict[uuid -> ReshardingSendTask]\n        self.recv_tasks = {}  # Dict[uuid -> ReshardingRecvTask]\n        self.broadcast_tasks = {}  # Dict[uuid -> 
BroadcastTask]\n        self.broadcast_communicators = {}\n\n        self.data_loaders = {}  # Dict[uuid -> MeshWorkerDataLoader]\n        self.data_loader_iters = {}  # Dict[uuid -> iterator]\n\n        self.set_runtime_random_seed(runtime_random_seed)\n\n        if global_config.pipeline_use_signal_send_recv:\n            print(\"Use signal send recv for debugging.\")\n            self.signal_buffers = []\n            for d in self.local_devices:\n                jax_tensor = device_put(jnp.ones((1,), dtype=jnp.int8), d)\n                self.signal_buffers.append(\n                    worker_nccl_util.to_signal_buffer(jax_tensor))\n\n    ##### Buffer Related Functions #####\n    def put_buffers(self,\n                    uuids: Union[int, Sequence[int]],\n                    datas: Sequence[np.ndarray],\n                    num_batch=1,\n                    batch_dim=0):\n        assert len(datas) == self.num_devices\n        if not isinstance(uuids, Iterable):\n            uuids = [uuids]\n        assert len(uuids) == num_batch\n        if num_batch > 1:\n            split_datas = []\n            for data in datas:\n                split_buffers = np.split(data, num_batch, batch_dim)\n                split_datas.extend(split_buffers)\n            datas = split_datas\n        arys = [([None] * self.num_devices) for _ in range(num_batch)]\n        for i, data in enumerate(datas):\n            if data.dtype == np.int64:\n                data = data.astype(np.int32)\n            device_id, batch_id = divmod(i, num_batch)\n            arys[batch_id][device_id] = (self.backend.buffer_from_pyval(\n                data, self.local_devices[device_id]))\n\n        for uuid, ary in zip(uuids, arys):\n            self.buffers[uuid] = ary\n\n    def shard_and_put_non_zero_buffer(self, uuids: Union[Sequence[int], int],\n                                      shape: Sequence[int], dtype: np.dtype,\n                                      indices: Sequence, num_batch: int):\n     
   if isinstance(uuids, int):\n            uuids = [uuids]\n        assert len(uuids) == num_batch\n        assert len(indices) == self.num_devices * num_batch\n        arys = [([None] * self.num_devices) for _ in range(num_batch)]\n        for device_id in range(self.num_devices):\n            for b in range(num_batch):\n                shard_shape = []\n                idx = device_id * num_batch + b\n                for j, s in enumerate(indices[idx]):\n                    filled_slice = s.indices(shape[j])\n                    dim_size = len(range(*filled_slice))\n                    shard_shape.append(dim_size)\n                arys[b][device_id] = (self.backend.buffer_from_pyval(\n                    np.full(shard_shape, 1e-8, dtype),\n                    self.local_devices[device_id]))\n        for uuid, ary in zip(uuids, arys):\n            self.buffers[uuid] = ary\n\n    def _get_buffers_with_local_ids(self, uuid: int, device_ids: Sequence[int]):\n        bufs = self.buffers[uuid]\n        # TODO(yonghao): sync communication events. 
Currently it's safe because\n        # we never get values immediately after a cross-mesh communication.\n        if device_ids is None:\n            return map(np.asarray, bufs)\n        elif not isinstance(device_ids, Iterable):\n            return np.asarray(bufs[device_ids])\n        return [np.asarray(bufs[device_id]) for device_id in device_ids]\n\n    def get_buffers(self,\n                    uuids: Union[Sequence[int], int],\n                    device_indices: Sequence[int] = None):\n        if not isinstance(uuids, Iterable):\n            return self._get_buffers_with_local_ids(uuids, device_indices)\n        if device_indices is not None:\n            assert len(uuids) == len(device_indices)\n        else:\n            device_indices = [None] * len(uuids)\n        return [\n            self._get_buffers_with_local_ids(uuid, local_ids)\n            for uuid, local_ids in zip(uuids, device_indices)\n        ]\n\n    def delete_buffers(self, uuids: Union[Sequence[int], int]):\n        if isinstance(uuids, Iterable):\n            for uuid in uuids:\n                del self.buffers[uuid]\n        else:\n            del self.buffers[uuids]\n\n    def block_until_ready_buffers(self, uuids: Union[Sequence[int], int]):\n        # We have to block all buffers to avoid the last operation is\n        # cross-mesh resharding(not SPMD)\n        if isinstance(uuids, Iterable):\n            for uuid in uuids:\n                for buf in self.buffers[uuid]:\n                    buf.block_until_ready()\n        else:\n            for buf in self.buffers[uuids]:\n                buf.block_until_ready()\n\n    def get_memory_allocated(self):\n        self.sync()\n        return max(d.memory_allocated() for d in self.local_devices)\n\n    def get_max_memory_allocated(self):\n        self.sync()\n        return max(d.max_memory_allocated() for d in self.local_devices)\n\n    def get_available_memory(self):\n        self.sync()\n        return min(d.available_memory() for d 
in self.local_devices)\n\n    def reset_memory_stats(self):\n        self.sync()\n        for device in self.local_devices:\n            device.clear_memory_stats()\n\n    ##### Executable Related Functions #####\n    def put_executable(self, uuid: int,\n                       executable_class: \"MeshWorkerExecutable\", *args):\n        self.executables[uuid] = executable_class(self, uuid, *args)\n\n    def delete_executable(self, uuid: int):\n        if uuid in self.executables:\n            del self.executables[uuid]\n\n    def run_executable(self, uuid: int, *args, **kwargs):\n        self.executables[uuid].execute_on_worker(*args, **kwargs)\n\n    def get_exec_hlo_text(self, uuid: int):\n        return self.executables[uuid].get_hlo_text()\n\n    def get_exec_total_allocation_size(self, uuid: int):\n        return self.executables[uuid].get_total_allocation_size()\n\n    def get_exec_grad_sync_channel_ids(self, uuid: int):\n        return self.executables[uuid].grad_sync_channel_ids\n\n    def set_runtime_random_seed(self, seed: int):\n        seed = seed + (self.mesh_id << 20 if self.mesh_id else 0)\n        for d in self.local_devices:\n            d.set_seed(seed)\n\n    ##### Serialization Related Functions #####\n    def sync_move_worker(self):\n        ray.get(self.move_worker.sync.remote())\n\n    def save_array(self, ckpt_dir: str, local_cache_dir: Union[str, None],\n                   uuid: int, device_ids: Sequence[int],\n                   shard_indices: Sequence[Index], global_shape: Sequence[int]):\n        assert uuid in self.buffers\n        array_buffers = self.buffers[uuid]\n\n        shard_names = [\n            f\"shard_{self.host_id}.{i}\" for i in range(len(device_ids))\n        ]\n\n        metadata = {\n            \"global_shape\": global_shape,\n            \"dtype\": self.buffers[uuid][0].dtype,\n            \"shard_names\": shard_names,\n            \"shard_indices\": shard_indices,\n        }\n\n        # create directories if not 
exist\n        os.makedirs(ckpt_dir, exist_ok=True)\n        if local_cache_dir is not None:\n            os.makedirs(local_cache_dir, exist_ok=True)\n            save_dir = local_cache_dir\n        else:\n            save_dir = ckpt_dir\n\n        for shard_name, device_id in zip(shard_names, device_ids):\n            with open(os.path.join(save_dir, shard_name), \"wb\") as datafile:\n                np.save(datafile, array_buffers[device_id])\n\n        with open(os.path.join(save_dir, f\"metadata_{self.host_id}\"),\n                  \"wb\") as metafile:\n            pickle.dump(metadata, metafile)\n\n        # move data\n        if local_cache_dir is not None:\n            self.move_worker.move.remote(local_cache_dir, ckpt_dir)\n\n    def load_array(self, ckpt_dir: str, uuid: Sequence[int],\n                   device_ids: Sequence[int], shard_indices: Sequence[Index]):\n        metadatas = list(\n            filter(lambda fname: fname.startswith(\"metadata\"),\n                   os.listdir(ckpt_dir)))\n        # pylint: disable=import-outside-toplevel\n        from alpa.serialization import load_sharded_array\n        entire_arr = load_sharded_array(ckpt_dir, metadatas)\n        array_buffers = [None] * self.num_devices\n        for index, device_id in zip(shard_indices, device_ids):\n            data = entire_arr[index]\n            if data.dtype == np.int64:\n                data = data.astype(np.int32)\n            array_buffers[device_id] = (self.backend.buffer_from_pyval(\n                data, self.local_devices[device_id]))\n        self.buffers[uuid] = array_buffers\n\n    ##### Data loader Related Functions #####\n    def put_data_loader(self, uuid: int, *args):\n        # pylint: disable=import-outside-toplevel\n        from alpa.data_loader import MeshWorkerDataLoader\n        self.data_loaders[uuid] = MeshWorkerDataLoader(self, *args)\n\n    def data_loader_iter(self, uuid: int):\n        self.data_loader_iters[uuid] = 
iter(self.data_loaders[uuid])\n\n    def data_loader_next(self, uuid: int):\n        next(self.data_loader_iters[uuid])\n\n    def delete_data_loader(self, uuid: int):\n        del self.data_loaders[uuid]\n\n    ##### Cross Mesh Resharding Related Functions #####\n    @staticmethod\n    def init_collective_group(world_size, rank, backend, group_name):\n        \"\"\"Initialize the collective group eagerly.\"\"\"\n        col.init_collective_group(world_size,\n                                  rank,\n                                  backend=backend,\n                                  group_name=group_name)\n\n    @staticmethod\n    def generate_nccl_uid(group_name):\n        \"\"\"Generate the NCCL unique ID in advance.\"\"\"\n        g = col.check_and_get_group(group_name)\n        uid = g.generate_nccl_uid()\n        return uid\n\n    @staticmethod\n    def init_p2p_communicator(group_name, my_rank, my_gpu_idx, peer_rank,\n                              peer_gpu_idx, nccl_uid):\n        \"\"\"Initialize the P2P communicator from within the mesh workers.\"\"\"\n        assert col.is_group_initialized(group_name)\n        assert col.get_rank(group_name) == my_rank\n        g = col.check_and_get_group(group_name)\n        g.create_p2p_communicator(my_gpu_idx, peer_rank, peer_gpu_idx, nccl_uid)\n\n    @staticmethod\n    def init_broadcast_communicator(group_name, comm_key, world_size,\n                                    device_ids, devices_global_rank, nccl_uid):\n        \"\"\"Initialize the P2P communicator from within the mesh workers.\"\"\"\n        assert col.is_group_initialized(group_name)\n        g = col.check_and_get_group(group_name)\n        g.create_nccl_broadcast_communicator(comm_key, world_size, device_ids,\n                                             devices_global_rank, nccl_uid)\n\n    @staticmethod\n    def destroy_collective_group(group_name: str = \"default\"):\n        col.destroy_collective_group(group_name)\n\n    def 
create_and_set_cross_mesh_communicators(self, world_size, rank, backend,\n                                                group_name, key):\n        \"\"\"Create collective communicators for the cross mesh group.\"\"\"\n        if not col.is_group_initialized(group_name):\n            self.init_collective_group(world_size, rank, backend, group_name)\n        g = col.check_and_get_group(group_name)\n        devices = list(range(self.num_devices))\n        g.create_and_set_xla_communicators(devices, key)\n\n    def put_resharding_send_task(self, uuid, tasks, group_name):\n        self.send_tasks[uuid] = ReshardingSendTask(tile_specs=tasks,\n                                                   group_name=group_name)\n\n    def put_resharding_recv_task(self, uuid, tasks, group_name):\n        self.recv_tasks[uuid] = ReshardingRecvTask(recv_specs=tasks,\n                                                   group_name=group_name)\n\n    def run_resharding_send_task(self, uuid, ary_uuid):\n        task: ReshardingSendTask = self.send_tasks[uuid]\n        group_name = task.group_name\n        if global_config.enable_overlapping:\n            col.wait_events(group_name, [ary_uuid], self.num_devices, True)\n\n        for send_tile_spec in task.tile_specs:\n            send_tile_spec: ReshardingSendSpec\n            self.send_tile(ary_uuid, send_tile_spec.device_id,\n                           send_tile_spec.tile_spec.offset,\n                           send_tile_spec.tile_spec.rank,\n                           send_tile_spec.tile_spec.gpu_idx, task.group_name)\n\n    def run_resharding_recv_task(self, uuid, ary_uuid, set_empty_buffer=True):\n        task: ReshardingRecvTask = self.recv_tasks[uuid]\n        group_name = task.group_name\n        if set_empty_buffer and ary_uuid not in self.buffers:\n            assert not global_config.enable_overlapping, \"Unsupported.\"\n            self.buffers[ary_uuid] = [None] * self.num_devices\n\n        if 
global_config.enable_overlapping:\n            col.wait_events(group_name, [ary_uuid], self.num_devices, False)\n\n        buffers = self.buffers[ary_uuid]\n        for recv_spec in task.recv_specs:\n            recv_spec: ReshardingRecvSpec\n            device_id = recv_spec.device_id\n            if set_empty_buffer:\n                buffers[device_id] = self.backend.buffer_from_pyval(\n                    np.full(recv_spec.shape, 1e-8, recv_spec.dtype),\n                    self.local_devices[device_id])\n\n            for recv_tile_spec in recv_spec.tile_specs:\n                recv_tile_spec: ReshardingTileSpec\n                self.recv_tile(ary_uuid, device_id, recv_tile_spec.offset,\n                               recv_tile_spec.rank, recv_tile_spec.gpu_idx,\n                               task.group_name)\n\n        if global_config.enable_overlapping:\n            col.record_events(group_name, [ary_uuid], self.num_devices, False)\n\n    def send_tile(self, uuid: int, device_id: int, offset: Sequence[slice],\n                  dst_rank: int, dst_gpu_idx: int, group_name: str):\n        if global_config.pipeline_use_signal_send_recv:\n            signal = self.signal_buffers[device_id]\n            col.send_multigpu(signal,\n                              dst_rank,\n                              dst_gpu_idx,\n                              group_name,\n                              start_pos=0,\n                              n_elements=1)\n        else:\n            worker_nccl_util.send_tile(self, uuid, device_id, offset, dst_rank,\n                                       dst_gpu_idx, group_name)\n\n    def recv_tile(self, uuid: int, device_id: int,\n                  indices_in_dst_tile: Sequence[slice], src_rank: int,\n                  src_gpu_idx: int, group_name: str):\n        if uuid not in self.buffers:\n            raise RuntimeError(\"Buffer has not been created.\")\n\n        if global_config.pipeline_use_signal_send_recv:\n            signal = 
self.signal_buffers[device_id]\n            col.recv_multigpu(signal,\n                              src_rank,\n                              src_gpu_idx,\n                              group_name,\n                              start_pos=0,\n                              n_elements=1)\n        else:\n            worker_nccl_util.recv_tile(self, uuid, device_id,\n                                       indices_in_dst_tile, src_rank,\n                                       src_gpu_idx, group_name)\n\n    def put_resharding_broadcast_task(self, uuid, tasks, group_name):\n        self.broadcast_tasks[uuid] = ReshardingBroadcastTask(\n            broadcast_specs=tasks, group_name=group_name)\n\n    def run_resharding_broadcast_task(self,\n                                      uuid,\n                                      ary_uuid,\n                                      set_empty_buffer=True):\n        task: ReshardingBroadcastTask = self.broadcast_tasks[uuid]\n        group_name = task.group_name\n        broadcast_specs = task.broadcast_specs\n        if set_empty_buffer and ary_uuid not in self.buffers:\n            assert not global_config.enable_overlapping, \"Unsupported.\"\n            picked_spec = list(broadcast_specs.values())[0]\n            shape = picked_spec.recv_tile_shape\n            dtype = picked_spec.dtype\n            self.buffers[ary_uuid] = [\n                self.backend.buffer_from_pyval(np.full(shape, 1e-8, dtype),\n                                               self.local_devices[device_id])\n                for device_id in range(self.num_devices)\n            ]\n\n        has_recv = False\n        for group_idx in broadcast_specs:\n            broadcast_spec: ReshardingBroadcastSpec = broadcast_specs[group_idx]\n            is_send = broadcast_spec.devices_global_rank[0] == 0\n            has_recv = has_recv or not is_send\n            if global_config.enable_overlapping:\n                col.wait_events(group_name, [ary_uuid], 
self.num_devices,\n                                is_send)\n\n            worker_nccl_util.broadcast(self, ary_uuid, broadcast_spec.comm_key,\n                                       broadcast_spec.world_size,\n                                       broadcast_spec.devices_ids,\n                                       broadcast_spec.devices_global_rank,\n                                       broadcast_spec.tensor_slices,\n                                       task.group_name)\n        if global_config.enable_overlapping and has_recv:\n            col.record_events(group_name, [ary_uuid], self.num_devices, False)\n\n    ##### Profiling and Debugging Related Functions #####\n    def profile_hlo_ops(self, op_infos: Sequence[Any], cache_filename: str,\n                        single_timeout: float):\n        num_devices = self.num_hosts * len(self.local_devices)\n        return mesh_profiling.profile_hlo_ops(op_infos, self.backend,\n                                              self.local_devices, self.host_id,\n                                              num_devices, cache_filename,\n                                              single_timeout)\n\n    def profile_executable_with_dummy_inputs(self, uuid: int, **kwargs):\n        return self.executables[uuid].profile_with_dummy_inputs(\n            self.backend, self.local_devices, **kwargs)\n\n    def profile_resharding_send_task(self,\n                                     uuid,\n                                     buf_uuids,\n                                     warmup=1,\n                                     repeat=3,\n                                     number=3,\n                                     sync=False):\n        # TODO(yonghao): the sync function should be carefully reconsidered\n        def run_fn():\n            self.run_resharding_send_task(uuid, buf_uuids)\n\n        sync_fn = self.sync if sync else None\n        costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number)\n        return 
np.mean(costs)\n\n    def profile_resharding_recv_task(self,\n                                     uuid,\n                                     buf_uuids,\n                                     warmup=1,\n                                     repeat=3,\n                                     number=3,\n                                     sync=False):\n        set_empty_buffer = True\n\n        def run_fn():\n            nonlocal set_empty_buffer\n            self.run_resharding_recv_task(uuid, buf_uuids, set_empty_buffer)\n            set_empty_buffer = False\n\n        sync_fn = self.sync if sync else None\n        costs = benchmark_func(run_fn, sync_fn, warmup, repeat, number)\n        return np.mean(costs)\n\n    @staticmethod\n    def get_timer(name: str):\n        return timers(name)\n\n    @staticmethod\n    def reset_timer(name: str):\n        timers(name).reset()\n\n    @staticmethod\n    def get_tracer():\n        return tracer\n\n    def get_live_buffer_uuids(self):\n        return list(self.buffers.keys())\n\n    ##### Other Functions #####\n    def sync(self, sync_all_devices=False):\n        # We sync one device instead of all for smaller runtime overhead.\n        # This is correct because of SPMD.\n        if sync_all_devices:\n            for device in self.local_devices:\n                device.synchronize_all_activity()\n        else:\n            self.local_devices[0].synchronize_all_activity()\n\n    def sync_all(self):\n        for device in self.local_devices:\n            device.synchronize_all_activity()\n\n    @staticmethod\n    def check_alive():\n        return True\n\n    def shutdown(self):\n        self.sync()\n        self.buffers.clear()\n        self.executables.clear()\n        self.distributed_client.shutdown()\n        # sync & shutdown DaemonMoveWorker\n        self.sync_move_worker()\n        ray.kill(self.move_worker)\n        self.move_worker = None\n\n\n########################################\n# 
DeviceMeshs\n########################################\nclass PhysicalDeviceMesh(ABC):\n    \"\"\"The base class of physical device mesh.\n\n    A physical device mesh is a 2-dimensional mesh that runs SPMD computation on\n    all devices in the mesh.\n    \"\"\"\n\n    num_hosts: int\n    num_devices_per_host: int\n    mesh_id: int\n    operation_executables: dict\n    one_replica_ids: dict\n\n    def get_signature(self) -> str:\n        \"\"\"Return a signature string that contains the mesh shape and GPU\n        model.\"\"\"\n        gpu_type = list_gpu_info()\n        gpu_name = gpu_type.split(\"\\n\")[0].split(\" (UUID:\")[0][7:]\n        ret = f\"{self.num_hosts},{self.num_devices_per_host},{gpu_name}\"\n        ret = ret.replace(\" \", \"-\")\n        return ret\n\n    def _compute_one_replica_ids(self, indices, aval_shape, sharding_spec):\n        # Tuple (aval_shape, sharding_spec) is 1-1 mapped to indices\n        # used to compute one_replica_ids\n        if (aval_shape, sharding_spec) in self.one_replica_ids:\n            return self.one_replica_ids[(aval_shape, sharding_spec)]\n\n        one_replica_indices = []\n        one_replica_host_local_ids = []\n        seen_index_hashes = set()\n        for i, index in enumerate(indices):\n            hashed_index = _hashable_index(index)\n            if hashed_index not in seen_index_hashes:\n                one_replica_indices.append(i)\n                one_replica_host_local_ids.append(\n                    divmod(i, self.num_devices_per_host))\n                seen_index_hashes.add(hashed_index)\n        self.one_replica_ids[(\n            aval_shape,\n            sharding_spec)] = one_replica_indices, one_replica_host_local_ids\n        return one_replica_indices, one_replica_host_local_ids\n\n    @property\n    def shape(self):\n        return self.num_hosts, self.num_devices_per_host\n\n    @property\n    def num_devices(self):\n        \"\"\"Return the total number of GPUs on this mesh.\"\"\"\n        
return self.num_hosts * self.num_devices_per_host\n\n    ##### Logical Mesh Related Functions #####\n    def get_logical_mesh(self,\n                         mesh_shape: Optional[Sequence[int]] = None,\n                         mesh_alpha: Optional[float] = None,\n                         mesh_beta: Optional[float] = None,\n                         mesh_topology: Optional[str] = None,\n                         intra_host_bandwidth: Optional[float] = None,\n                         inter_host_bandwidth: Optional[float] = None):\n        \"\"\"\n        Return a logical mesh and parameters of the alpha-beta communication\n        cost model. The logical view is used for auto-sharding.\n        \"\"\"\n        if mesh_shape is None:\n            mesh_shape = (self.num_hosts, self.num_devices_per_host)\n\n        id_mesh = np.arange(self.num_devices).reshape(mesh_shape)\n\n        if mesh_topology is None:\n            # Use the provided mesh_alpha and mesh_beta\n            mesh_alpha = mesh_alpha or (1, 1)\n            mesh_beta = mesh_beta or (1, 0.1)\n        elif mesh_topology == \"tree\":\n            # Derive mesh_alpha and mesh_beta from topology,\n            # intra_host_bandwidth and inter_host_bandwidth\n            assert mesh_alpha is None\n            assert mesh_beta is None\n            mesh_alpha = [1] * 2\n            mesh_beta = [None] * 2\n            host_ids = np.tile(\n                np.arange(self.num_hosts).reshape(-1, 1),\n                self.num_devices_per_host)\n            host_ids = host_ids.reshape(mesh_shape)\n\n            # Compute bandwidth of doing communication along dim 0.\n            # 1. 
Compute the number of links between each pair of hosts.\n            #    Assume using ring-based algorithms.\n            host_link_ct = defaultdict(int)\n            for j in range(mesh_shape[1]):\n                for i in range(mesh_shape[0]):\n                    left = host_ids[i][j]\n                    right = host_ids[(i + 1) % mesh_shape[0]][j]\n                    if left != right:\n                        if left > right:\n                            left, right = right, left\n                        host_link_ct[(left, right)] += 1\n\n            j = 0\n            # 2. Bandwidth between two hosts\n            #    = total_bandwidth / number_of_links.\n            #    Bandwidth along a communication dimension\n            #    = min bandwidth of all links.\n            bandwidth = intra_host_bandwidth\n            for i in range(mesh_shape[0]):\n                left = host_ids[i][j]\n                right = host_ids[(i + 1) % mesh_shape[0]][j]\n                if left != right:\n                    if left > right:\n                        left, right = right, left\n                    bandwidth = min(\n                        bandwidth,\n                        inter_host_bandwidth / host_link_ct[(left, right)])\n            mesh_beta[0] = 1 / bandwidth\n\n            # Compute bandwidth of doing communication along dim 1.\n            host_link_ct = defaultdict(int)\n            for i in range(mesh_shape[0]):\n                for j in range(mesh_shape[1]):\n                    left = host_ids[i][j]\n                    right = host_ids[i][(j + 1) % mesh_shape[1]]\n                    if left != right:\n                        if left > right:\n                            left, right = right, left\n                        host_link_ct[(left, right)] += 1\n\n            i = 0\n            bandwidth = intra_host_bandwidth\n            for j in range(mesh_shape[1]):\n                left = host_ids[i][j]\n                right = host_ids[i][(j + 1) % 
mesh_shape[1]]\n                if left != right:\n                    if left > right:\n                        left, right = right, left\n                    bandwidth = min(\n                        bandwidth,\n                        inter_host_bandwidth / host_link_ct[(left, right)])\n            mesh_beta[1] = 1 / bandwidth\n\n        return LogicalDeviceMesh(self, id_mesh, mesh_alpha, mesh_beta)\n\n    ##### Executable Related Functions #####\n    @abstractmethod\n    def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],\n                           donated_invars: Sequence[bool],\n                           batch_invars: Sequence[bool], num_micro_batches: int,\n                           args: Sequence[Any]):\n        \"\"\"Shard high-level arguments as low-level buffers.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def shard_args_to_arrays(self, avals: Sequence[ShapedArray],\n                             shard_indices: Sequence[Sequence[Index]],\n                             sharding_specs: Sequence[ShardingSpec],\n                             args: Sequence[Any]):\n        \"\"\"Shard arguments (np.ndarray) as distributed arrays.\"\"\"\n        raise NotImplementedError()\n\n    def shard_args_to_arrays_ps(self, placement_specs: PlacementSpec,\n                                args: Sequence[Any]):\n        \"\"\"\n        Shard arguments (np.ndarray) as distributed arrays according to\n        PlacementSpec.\n        \"\"\"\n        avals = tuple(x.aval for x in placement_specs)\n        assert all(\n            len(x.mesh_ids) == 1 and x.mesh_ids[0] == self.mesh_id\n            for x in placement_specs)\n        specs = tuple(x.sharding_specs[0] for x in placement_specs)\n        indices = tuple(\n            pxla.spec_to_indices(aval.shape, spec)\n            for aval, spec in zip(avals, specs))\n        return self.shard_args_to_arrays(avals, indices, specs, args)\n\n    @abstractmethod\n    def 
get_outputs_handler(self, avals: Sequence[ShapedArray],\n                            sharding_specs: Sequence[ShardingSpec]):\n        \"\"\"\n        Get a function that wraps low-level buffers to high-level output arrays.\n        \"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def set_runtime_random_seed(self, seed: int):\n        raise NotImplementedError()\n\n    ##### Profiling Related Functions #####\n    @abstractmethod\n    def get_remote_timer(self, timer_name: str):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def reset_remote_timer(self, timer_name: str):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def get_remote_tracer(self):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def get_memory_allocated(self):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def get_max_memory_allocated(self):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def get_available_memory(self):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def reset_memory_stats(self):\n        raise NotImplementedError()\n\n    ##### Other Functions #####\n    @abstractmethod\n    def sync_workers(self):\n        \"\"\"Sync device activities on all workers.\"\"\"\n        raise NotImplementedError()\n\n    @abstractmethod\n    def shutdown(self, forced=False):\n        \"\"\"Shut down the mesh.\"\"\"\n        raise NotImplementedError()\n\n\nclass LocalPhysicalDeviceMesh(PhysicalDeviceMesh):\n    \"\"\"\n    A single-host physical device mesh to run computation on local devices.\n    It uses the native XLA runtime.\n    \"\"\"\n\n    def __init__(self, devices: Sequence[\"Device\"] = None):\n        self.devices = devices if devices is not None else xb.local_devices()\n        self.num_hosts = 1\n        self.num_devices_per_host = len(self.devices)\n        self.mesh_id = -1\n        self.device_strs = []\n        self.operation_executables = {}\n        
self.one_replica_ids = {}\n\n        self.backend = xb.get_backend(global_config.backend)\n\n        self.set_runtime_random_seed(global_config.runtime_random_seed)\n\n    ##### Executable Related Functions #####\n    def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],\n                           donated_invars: Sequence[bool],\n                           batch_invars: Sequence[bool], num_micro_batches: int,\n                           args: Sequence[Any]):\n        bufs = []\n        for arg, indices, donated, is_batch_var in zip(args, shard_indices,\n                                                       donated_invars,\n                                                       batch_invars):\n            if is_batch_var:\n                micro_batches = jnp.split(arg, num_micro_batches)\n                bufs.append([\n                    pxla._shard_arg(x, self.devices, indices, None)\n                    for x in micro_batches\n                ])\n            else:\n                if (isinstance(arg, pxla.ShardedDeviceArray) and\n                        arg.indices == indices):\n                    bufs.append(arg.device_buffers)\n                else:\n                    bufs.append(\n                        pxla._shard_arg(arg, self.devices, indices, None))\n\n            if isinstance(arg, xe.DeviceArray) and donated:\n                arg.delete()\n\n        return bufs\n\n    def shard_args_to_arrays(self, avals: Sequence[ShapedArray],\n                             shard_indices: Sequence[Sequence[Index]],\n                             sharding_specs: Sequence[ShardingSpec],\n                             args: Sequence[Any]):\n        arrays = []\n        for i in range(len(avals)):\n            if global_config.use_dummy_value_for_benchmarking:\n                args[i] = np.full(avals[i].shape, 1e-8, avals[i].dtype)\n            shards = [\n                args[i][shard_indices[i][k]] for k in range(len(self.devices))\n            ]\n     
       buffers = [device_put(x, d) for x, d in zip(shards, self.devices)]\n            arrays.append(\n                pxla._ShardedDeviceArray(avals[i], sharding_specs[i], buffers,\n                                         shard_indices[i]))\n        return arrays\n\n    def get_outputs_handler(self, avals: Sequence[ShapedArray],\n                            sharding_specs: Sequence[ShardingSpec]):\n        pmap_specs = pxla._get_pmap_sharding(np.arange(self.num_devices),\n                                             sharding_specs)\n        outs_handler = pxla.local_avals_to_results_handler(avals, pmap_specs)\n        return outs_handler\n\n    def set_runtime_random_seed(self, seed: int):\n        for d in self.devices:\n            if d is not None:\n                d.set_seed(seed)\n\n    ##### Profiling Related Functions #####\n    def get_remote_timer(self, timer_name: str):\n        return timers(timer_name)\n\n    def reset_remote_timer(self, timer_name: str):\n        timers(timer_name).reset()\n\n    def get_remote_tracer(self):\n        return tracer\n\n    def get_memory_allocated(self):\n        return max(d.memory_allocated() for d in self.devices)\n\n    def get_max_memory_allocated(self):\n        return max(d.max_memory_allocated() for d in self.devices)\n\n    def get_available_memory(self):\n        return min(device.available_memory() for device in self.devices)\n\n    def reset_memory_stats(self):\n        for device in self.devices:\n            device.clear_memory_stats()\n\n    ##### Other Functions #####\n    def sync_workers(self):\n        # We sync one device instead of all for smaller runtime overhead.\n        # This is correct because of SPMD.\n        self.devices[0].synchronize_all_activity()\n\n    def shutdown(self, forced=False):\n        self.sync_workers()\n        self.operation_executables.clear()\n\n\ndef device_id_to_str(host_ip, device_id, device_type=\"gpu\"):\n    \"\"\"Convert device id (int) to a canonical device 
string.\"\"\"\n    return f\"{host_ip}:{device_type}:{device_id}\"\n\n\n# Used ports for XLA distributed runtime servers.\nused_port_set = set((None,))\n\n\nclass DistributedPhysicalDeviceMesh(PhysicalDeviceMesh):\n    \"\"\"\n    A multi-host physical device mesh to run computation distributedly.\n    It uses ray actors and the distributed XLA runtime.\n    \"\"\"\n\n    def __init__(self,\n                 host_ids: Sequence[int],\n                 host_info: Sequence[dict],\n                 num_devices_per_host: int,\n                 parent: Optional[\"VirtualPhysicalMesh\"] = None,\n                 devices: Optional[Sequence[Sequence[int]]] = None,\n                 mesh_id: Optional[int] = None,\n                 namespace: Optional[str] = None):\n        # host_ids are the indices of hosts in the global DeviceCluster\n        self.host_ids = host_ids\n        self.host_info = host_info\n        self.num_hosts = len(host_ids)\n        self.num_devices_per_host = num_devices_per_host\n        self.parent = parent\n        self.mesh_id = mesh_id\n        self.workers = None\n        self.service_server = None\n        self.operation_executables = {}\n        self.one_replica_ids = {}\n        self.namespace = namespace\n\n        if devices is not None:\n            if len(devices) != len(host_ids):\n                raise RuntimeError(\n                    \"Please specify the gpu IDs used on each host.\")\n            if not all(len(ids) == num_devices_per_host for ids in devices):\n                raise RuntimeError(\n                    \"Devices specified for each host does not align \"\n                    \"with `num_devices_per_host`.\")\n        else:\n            devices = [list(range(num_devices_per_host)) for _ in host_ids]\n\n        self.devices = devices\n        self.device_strs = []\n        self.node_ips = []\n        for i in range(self.num_hosts):\n            ip = self.host_info[i][\"NodeManagerAddress\"]\n            
self.device_strs.extend(\n                [device_id_to_str(ip, j) for j in devices[i]])\n            self.node_ips.append(ip)\n\n        found_existing_workers = False\n        if self.namespace:\n            try:\n                ray.get_actor(self.get_host_worker_name(0))\n                found_existing_workers = True\n            except ValueError:\n                pass\n\n        if found_existing_workers:\n            self.service_server = None\n            self.workers = self.connect_to_existing_workers()\n            self.launched = False\n        else:\n            self.service_server, self.workers = self.launch_xla_servers()\n            self.launched = True\n\n        self.to_delete_remote_refs = []\n        self.to_delete_remote_ref_ct = 0\n\n    def get_host_worker_name(self, host_id):\n        if self.namespace:\n            return f\"mesh_{self.mesh_id}_host_{host_id}\"\n        else:\n            return None\n\n    def connect_to_existing_workers(self):\n        workers = []\n        for i in range(self.num_hosts):\n            workers.append(ray.get_actor(self.get_host_worker_name(i)))\n        return workers\n\n    def launch_xla_servers(self):\n        # Launch distributed xla runtime\n        port = None\n        while port in used_port_set:\n            port = np.random.randint(global_config.xla_server_port_start,\n                                     global_config.xla_server_port_end)\n            if check_server_port(ray.util.get_node_ip_address(), port):\n                port = None\n        used_port_set.add(port)\n\n        server_address = f\"{ray.util.get_node_ip_address()}:{port}\"\n        logger.debug(f\"Trying to start XLA gRPC server on port: {port}...\")\n        service_server = xla_client._xla.get_distributed_runtime_service(\n            server_address, self.num_hosts, use_coordination_service=False)\n        logger.debug(f\"Success to start XLA gRPC server on port: {port}...\")\n        time.sleep(0.4)\n\n        # Launch 
workers\n        workers = []\n\n        # retrieve the placement group\n        placement_group = retrieve_placement_group()\n\n        # get the sorted bundle index list\n        device_bundle_idx_list = get_bundle_idx(placement_group, self.node_ips)\n\n        for i in range(self.num_hosts):\n            # Set XLA environment variables\n            env_vars = {\n                \"ALPA_IS_WORKER\":\n                    \"True\",\n                \"NCCL_USE_MULTISTREAM\":\n                    \"False\",\n                \"XLA_PYTHON_CLIENT_MEM_FRACTION\":\n                    str(global_config.xla_client_mem_fraction),\n                \"XLA_FLAGS\": (os.environ.get(\"XLA_FLAGS\", \"\") +\n                              f\" --xla_gpu_autotune_level\"\n                              f\"={global_config.xla_gpu_autotune_level}\"),\n                \"XLA_PYTHON_CLIENT_PREALLOCATE\":\n                    global_config.xla_client_client_preallocate,\n                # \"NCCL_LAUNCH_MODE\": \"PARALLEL\",\n                # \"XLA_FLAGS\": \"--xla_dump_to=hlo --xla_dump_hlo_pass_re=.*\"\n                # \"NCCL_DEBUG\": \"INFO\" if i == 0 else \"VERSION\",\n                # \"NCCL_DEBUG_SUBSYS\": \"ALL\",\n                # \"RAY_IGNORE_UNHANDLED_ERRORS\": \"True\",\n            }\n\n            if global_config.resharding_mode == \"broadcast\":\n                env_vars[\"NCCL_ALGO\"] = \"Ring\"\n                env_vars[\"NCCL_PROTO\"] = \"Simple\"\n\n            if \"XLA_PYTHON_CLIENT_ALLOCATOR\" in os.environ:\n                env_vars[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = os.environ[\n                    \"XLA_PYTHON_CLIENT_ALLOCATOR\"]\n\n            if \"NCCL_DEBUG\" in os.environ:\n                env_vars[\"NCCL_DEBUG\"] = os.environ[\n                    \"NCCL_DEBUG\"] if i == 0 else \"VERSION\"\n\n            if global_config.use_aws_efa:\n                env_vars.update({\n                    \"FI_PROVIDER\": \"efa\",\n                    
\"FI_EFA_USE_DEVICE_RDMA\": \"1\",\n                    \"LD_LIBRARY_PATH\": os.environ.get(\"LD_LIBRARY_PATH\",\n                                                      \"\"),  # For libnccl-net.so\n                    \"NCCL_PROTO\": \"simple\",\n                })\n\n            bundle_index = device_bundle_idx_list[i]\n\n            host_worker_name = self.get_host_worker_name(i)\n\n            # Launch the DaemonMoveWorker\n            cls = ray.remote(num_cpus=0)(DaemonMoveWorker)\n            move_worker = cls.options(\n                placement_group=placement_group,\n                placement_group_bundle_index=bundle_index).remote()\n\n            # Launch the MeshHostWorker\n            cls = ray.remote(num_cpus=0,\n                             num_gpus=self.num_devices_per_host)(MeshHostWorker)\n            worker = cls.options(placement_group=placement_group,\n                                 placement_group_bundle_index=bundle_index,\n                                 name=host_worker_name,\n                                 runtime_env={\n                                     \"env_vars\": env_vars\n                                 }).remote(server_address, self.num_hosts, i,\n                                           self.mesh_id, move_worker,\n                                           global_config.runtime_random_seed,\n                                           global_config)\n            workers.append(worker)\n        return service_server, workers\n\n    @property\n    def host_ips(self):\n        ips = [\n            self.host_info[i][\"NodeManagerAddress\"]\n            for i, _ in enumerate(self.host_ids)\n        ]\n        return ips\n\n    def get_virtual_physical_mesh(self):\n        return VirtualPhysicalMesh(\n            host_ids=self.host_ids,\n            host_info=self.host_info,\n            num_devices_per_host=self.num_devices_per_host,\n            parent=self,\n            devices=self.devices)\n\n    def _split_ids_to_host(self, 
host_local_ids: Sequence[Tuple[int, int]]):\n        if host_local_ids is None:\n            full_local_id = [\n                range(self.num_devices_per_host) for _ in range(self.num_hosts)\n            ]\n            full_id_local_idx = [(i, j)\n                                 for i in range(self.num_hosts)\n                                 for j in range(self.num_devices_per_host)]\n            return tuple(full_local_id), full_id_local_idx\n        per_host_id = [[] for _ in range(self.num_hosts)]\n        host_id_local_idx = []\n        for id_pair in host_local_ids:\n            host_id, device_id = id_pair\n            host_id_local_idx.append((host_id, len(per_host_id[host_id])))\n            per_host_id[host_id].append(device_id)\n        return per_host_id, host_id_local_idx\n\n    ##### Buffer Related Functions #####\n    def get_remote_buffers(\n            self,\n            ary_refs: Union[List[\"RemoteArrayRef\"], \"RemoteArrayRef\"],\n            host_local_ids: Sequence[Sequence[Tuple[int, int]]] = None,\n            batching=False,\n            return_ray_ref=False):\n        \"\"\"\n        Get values of remote buffers.\n\n        Args:\n            host_local_ids: For each RemoteArrayRef, we can fetch a list of\n              buffers from multiple devices on multiple hosts. This variable\n              defines a list of (host_id, local_id) pair for each\n              RemoteArrayRef. If it is None, fetch all remote buffers.\n            batching: Whether batch remote calls by host ids. 
This can reduce\n              ray overhead.\n        \"\"\"\n        return_list = True\n        if not isinstance(ary_refs, Iterable):\n            return_list = False\n            ary_refs = [ary_refs]\n        if host_local_ids is None:\n            host_local_ids = [None] * len(ary_refs)\n        elif not isinstance(host_local_ids, Iterable):\n            assert not return_list\n            host_local_ids = [host_local_ids]\n\n        if batching:\n            # Batch the remote calls by host ids\n            ary_ids = np.array([ref.uuid for ref in ary_refs])\n            per_host_ids = np.empty((self.num_hosts, len(ary_ids)),\n                                    dtype=object)\n            host_id_local_indices = []\n            for arg_id, id_pairs in enumerate(host_local_ids):\n                tmp_ids, tmp_indices = self._split_ids_to_host(id_pairs)\n                host_id_local_indices.append(tmp_indices)\n                for host_id, tmp_per_host in enumerate(tmp_ids):\n                    per_host_ids[host_id][arg_id] = np.array(tmp_per_host)\n\n            # [host_id-> (buf_idx-> (local_device_id->device_buffer))]\n            obj_refs = []\n            for host_id in range(self.num_hosts):\n                obj_refs.append(self.workers[host_id].get_buffers.remote(\n                    ary_ids, per_host_ids[host_id]))\n            per_host_results = ray.get(obj_refs)\n            # [buf_id -> (flatten_id -> device_buffer)]\n            ret = []\n            for ref_idx, id_pairs in enumerate(host_id_local_indices):\n                buffers = []\n                for id_pair in id_pairs:\n                    host_id, local_idx = id_pair\n                    buffers.append(\n                        per_host_results[host_id][ref_idx][local_idx])\n                ret.append(buffers)\n        else:\n            obj_refs = []\n            for ary_ref, id_pairs in zip(ary_refs, host_local_ids):\n                ary_obj_refs = []\n                for id_pair in 
id_pairs:\n                    host_id, local_id = id_pair\n                    ary_obj_refs.append(\n                        self.workers[host_id].get_buffers.remote(\n                            ary_ref.uuid, local_id))\n                obj_refs.append(ary_obj_refs)\n            if return_ray_ref:\n                ret = obj_refs\n            else:\n                ret = [ray.get(refs) for refs in obj_refs]\n        return ret if return_list else ret[0]\n\n    def delete_remote_buffers(self, ary_refs: List[\"RemoteArrayRef\"]):\n        \"\"\"Delete remote buffers.\"\"\"\n        if not self.workers or not ray or not ray_worker or not np.array:\n            return\n\n        # Put delete requests into a buffer\n        for ary_ref in ary_refs:\n            self.to_delete_remote_refs.append(ary_ref.uuid)\n        self.to_delete_remote_ref_ct += len(ary_refs)\n\n        # Execute the delete requests if there are enough requests\n        if (self.to_delete_remote_ref_ct >\n                global_config.delete_remote_arrays_threshold):\n            to_delete_remote_refs = np.array(self.to_delete_remote_refs)\n            try:\n                for host_id in range(self.num_hosts):\n                    self.workers[host_id].delete_buffers.remote(\n                        to_delete_remote_refs)\n            except AttributeError:\n                pass\n            self.to_delete_remote_refs = []\n            self.to_delete_remote_ref_ct = 0\n\n    def block_until_ready_remote_buffers(self,\n                                         ary_refs: List[\"RemoteArrayRef\"]):\n        \"\"\"Block until the remote buffers are ready.\"\"\"\n        tasks = []\n        ary_uuids = np.array([ref.uuid for ref in ary_refs])\n        for worker in self.workers:\n            tasks.append(worker.block_until_ready_buffers.remote(ary_uuids))\n        ray.get(tasks)\n\n    ##### Executable Related Functions #####\n    def shard_args_to_bufs(self, shard_indices: Sequence[Sequence[Index]],\n   
                        donated_invars: Sequence[bool],\n                           batch_invars: Sequence[bool], num_micro_batches: int,\n                           args: Sequence[Any]):\n        ret_bufs = []\n        total_bytes = 0\n        time_start = time.time()\n\n        for arg, indices, donated, is_batch_var in zip(args, shard_indices,\n                                                       donated_invars,\n                                                       batch_invars):\n            tic = time.time()\n            slow_path = False\n\n            if is_batch_var:\n                if (isinstance(arg, DistributedArray) and\n                        arg.skip_shard_args_check is True):\n                    assert num_micro_batches == 1\n                    ret_bufs.append([arg.remote_ref])\n                else:\n                    slow_path = True\n                    if not isinstance(arg, ShapedArray):\n                        arg = np.asarray(arg)\n                    refs = _shard_array(arg, self, indices, num_micro_batches)\n                    ret_bufs.append(refs)\n            else:\n                if (isinstance(arg, DistributedArray) and\n                        arg.device_mesh == self and arg.indices == indices):\n                    # Fast path for DistributedArray\n                    ret_bufs.append(arg.remote_ref)\n                elif isinstance(arg, ReplicatedDistributedArray):\n                    replica = arg.get_replica_on_mesh(self)\n                    assert replica.indices == indices\n                    ret_bufs.append(replica.remote_ref)\n                else:  # Slow path\n                    slow_path = True\n                    if type(arg) not in [ShapedArray, ShapeDtypeStruct]:\n                        arg = xla.canonicalize_dtype(arg)\n                    ref = shard_arg_handlers[type(arg)](arg, self, indices)[0]\n                    ret_bufs.append(ref)\n                    if donated and hasattr(arg, \"delete\"):\n    
                    # shard_arg_handler always creates new buffers,\n                        # so we can delete the old buffers\n                        arg.delete()\n\n            if False and slow_path:  # pylint: disable=condition-evals-to-constant\n                # Print debug info\n                size = np.prod(arg.shape) * arg.dtype.itemsize\n                bandwidth = size / (time.time() - tic)\n                total_bytes += size\n                print(\"Slow path. \"\n                      f\"shape: {arg.shape}, \"\n                      f\"bandwidth: {bandwidth/1024**2:.2f} MB/s \"\n                      f\"total_bytes: {total_bytes/1024**2:.2f} MB \"\n                      f\"total_time: {time.time() - time_start:.2f}\")\n\n        return ret_bufs\n\n    def shard_args_to_arrays(self, avals: Sequence[ShapedArray],\n                             shard_indices: Sequence[Sequence[Index]],\n                             sharding_specs: Sequence[ShardingSpec],\n                             args: Sequence[np.array]):\n        arrays = []\n        for i in range(len(avals)):\n            remote_ref = _shard_array(args[i], self, shard_indices[i])[0]\n            arrays.append(\n                DistributedArray(self, avals[i], sharding_specs[i], remote_ref,\n                                 shard_indices[i]))\n        return arrays\n\n    def get_outputs_handler(self, avals: Sequence[ShapedArray],\n                            sharding_specs: Sequence[ShardingSpec]):\n        indices = [\n            pxla.spec_to_indices(aval.shape, spec)\n            for aval, spec in zip(avals, sharding_specs)\n        ]\n\n        def outs_handler(refs):\n            ret = []\n            for i, aval in enumerate(avals):\n                dis_array = DistributedArray(device_mesh=self,\n                                             aval=aval,\n                                             sharding_spec=sharding_specs[i],\n                                             
remote_ref=refs[i],\n                                             indices=indices[i])\n                ret.append(dis_array)\n            return ret\n\n        return outs_handler\n\n    def delete_remote_executable(self, exec_uuid: int):\n        \"\"\"Delete remote worker executables of a driver executable.\"\"\"\n        if not self.workers or not ray or not ray_worker or not np.array:\n            return\n\n        try:\n            for w in self.workers:\n                w.delete_executable.remote(exec_uuid)\n        except AttributeError:\n            pass\n\n    def set_runtime_random_seed(self, seed: int):\n        for w in self.workers:\n            w.set_runtime_random_seed.remote(seed)\n\n    ##### Profiling and Debugging Related Functions #####\n    def profile_hlo_ops(self,\n                        op_infos: Sequence[Tuple],\n                        cache_filename: str,\n                        single_timeout: Optional[float] = None,\n                        batch_timeout: Optional[float] = None):\n        tasks = []\n        for w in self.workers:\n            tasks.append(\n                w.profile_hlo_ops.remote(op_infos, cache_filename,\n                                         single_timeout))\n        return ray.get(tasks, timeout=batch_timeout)[0]\n\n    def get_remote_timer(self, timer_name: str):\n        return ray.get(self.workers[0].get_timer.remote(timer_name))\n\n    def reset_remote_timer(self, timer_name: str):\n        for worker in self.workers:\n            ray.get(worker.reset_timer.remote(timer_name))\n\n    def get_remote_tracer(self):\n        return ray.get(self.workers[0].get_tracer.remote())\n\n    def get_memory_allocated(self):\n        return max(\n            ray.get([w.get_memory_allocated.remote() for w in self.workers]))\n\n    def get_max_memory_allocated(self):\n        return max(\n            ray.get([w.get_max_memory_allocated.remote() for w in self.workers\n                    ]))\n\n    def 
get_available_memory(self):\n        return min(\n            ray.get([w.get_available_memory.remote() for w in self.workers]))\n\n    def reset_memory_stats(self):\n        for worker in self.workers:\n            ray.get(worker.reset_memory_stats.remote())\n\n    ##### Other Functions #####\n    def sync_workers(self, sync_all_devices=False):\n        ray.get([w.sync.remote(sync_all_devices) for w in self.workers])\n\n    def sync_move_workers(self):\n        ray.get([w.sync_move_worker.remote() for w in self.workers])\n\n    def shutdown(self, forced=False):\n        self.operation_executables.clear()\n        if not self.launched:\n            return\n        if not forced:\n            ray.get([w.shutdown.remote() for w in self.workers])\n        for worker in self.workers:\n            ray.kill(worker)\n        self.workers = None\n        # shutdown grpc server\n        if self.service_server:\n            self.service_server.shutdown()\n            self.service_server = None\n        self.launched = False\n\n\n########################################\n# Distributed Array and Buffers\n########################################\nclass RemoteArrayRef:\n    \"\"\"\n    A reference to all device buffers of a logical array.\n\n    In Alpa, each pipeshard stage runs in SPMD(single program, multiple device).\n    Hence, buffers of the same logical array are allocated, used and freed\n    together, and thus we use one reference for all these buffers.\n    \"\"\"\n\n    def __init__(self, device_mesh: PhysicalDeviceMesh, uuid: int = None):\n        self.device_mesh = device_mesh\n        self.uuid = (uuid if uuid is not None else next_array_uuids()[0])\n        self.is_deleted_on_workers = False\n\n    def set_deleted_on_workers(self):\n        \"\"\"\n        Set the array as deleted on workers.\n        For some arrays (e.g., donated tensor), if we know the workers has\n        already deleted them, then we do not need to do the remote call\n        
\"delete_remote_buffers\" again.\n        \"\"\"\n        self.is_deleted_on_workers = True\n\n    def __repr__(self):\n        return (f\"RemoteBufferRef(uuid = {self.uuid}, \"\n                f\"loc = Mesh ({self.device_mesh.mesh_id}))\")\n\n    def __del__(self):\n        if not self.is_deleted_on_workers:\n            self.device_mesh.delete_remote_buffers((self,))\n\n\n# The global buffer counter\nremote_buffer_counter = 0\n\n\ndef next_array_uuids(number=1):\n    \"\"\"Return the next uuid of a remote buffer.\"\"\"\n    global remote_buffer_counter\n    ret = np.arange(remote_buffer_counter, remote_buffer_counter + number)\n    remote_buffer_counter = (remote_buffer_counter + number) % (1 << 60)\n    return ret\n\n\ndef create_remote_array_refs(device_mesh, number=1):\n    \"\"\"Create a list of remote array refs.\"\"\"\n    ary_uuids = next_array_uuids(number)\n    ary_refs = [RemoteArrayRef(device_mesh, uuid) for uuid in ary_uuids]\n    return ary_refs, ary_uuids\n\n\nclass DistributedArray:\n    \"\"\"A distributed array on a PhysicalDeviceMesh.\n\n    End users can interact with this array as if they are working with\n    a normal numpy array.\n\n    Internally, it stores a pointer to all remote buffers.\n    The buffers are stored distributedly on remote workers' device memory.\n    When users require the value of the array. 
These buffers will be gathered\n    to the driver.\n    \"\"\"\n\n    def __init__(self,\n                 device_mesh: PhysicalDeviceMesh,\n                 aval: ShapedArray,\n                 sharding_spec: ShardingSpec,\n                 remote_ref: RemoteArrayRef,\n                 indices: Optional[Sequence[Index]] = None):\n        self.device_mesh = device_mesh\n        self.aval = aval\n        self.sharding_spec = sharding_spec\n        self.remote_ref = remote_ref\n\n        if indices is None:\n            indices = pxla.spec_to_indices(self.aval.shape, self.sharding_spec)\n        self.indices = indices\n\n        self.shape = self.aval.shape\n        self.dtype = self.aval.dtype\n        self._npy_value = None\n        self._fetched_np_buffers = None\n        self._fetched_np_buffers_ref = None\n        self.skip_shard_args_check = False\n\n    @property\n    def size(self):\n        return np.prod(self.shape)\n\n    def prefetch(self):\n        # TODO (yinmin): Move this function out of DistributedArray\n        #  and batch different requests. 
Also need to add another\n        #  function to `ray.wait` for the remote references.\n        self._fetched_np_buffers_ref = self.device_mesh.get_remote_buffers(\n            (self.remote_ref,), (self.one_replica_host_local_ids,), False,\n            True)[0]\n\n    def block_until_ready(self):\n        \"\"\"Block until all remote buffers of this array are ready.\"\"\"\n        self.device_mesh.block_until_ready_remote_buffers([self.remote_ref])\n\n    def delete(self):\n        self.remote_ref = None\n        self._npy_value = None\n\n    def flush(self):\n        self._npy_value = None\n\n    async def to_np_async(self):\n        if self._npy_value is None:\n            npy_value = np.empty(self.aval.shape, self.aval.dtype)\n            if not self._fetched_np_buffers:\n                if not self._fetched_np_buffers_ref:\n                    self.prefetch()\n                fetched_np_buffers = await asyncio.gather(\n                    *self._fetched_np_buffers_ref)\n            else:\n                fetched_np_buffers = self._fetched_np_buffers\n            for ct, i in enumerate(self.one_replica_buffer_ids):\n                npy_value[self.indices[i]] = fetched_np_buffers[ct]\n            self._npy_value = npy_value\n        return self._npy_value\n\n    ##### distributed save/load #####\n    def save(self, ckpt_dir: str, local_cache_dir: Union[str, None] = None):\n        \"\"\"\n            Save one replica of the array to `ckpt_dir` distributedly.\n\n            Args:\n                ckpt_dir: The directory where all the shards of\n                this array will be saved.\n                local_cache_dir: If not None, `ckpt_dir` should be a shared\n                filesystem path, and this function will return as soon as the\n                shards have been saved to this local directory.\n                DaemonMoveWorkers will move these shards into `ckpt_dir`\n                in the background.\n\n        \"\"\"\n        one_replica_indices = [\n   
         self.indices[i] for i in self.one_replica_buffer_ids\n        ]\n        device_ids_per_host = {}\n        indices_per_host = {}\n        for buf_id, indice in zip(self.one_replica_host_local_ids,\n                                  one_replica_indices):\n            host_id, device_id = buf_id\n            if indices_per_host.get(host_id) is None:\n                indices_per_host[host_id] = [indice]\n                device_ids_per_host[host_id] = [device_id]\n            else:\n                indices_per_host[host_id].append(indice)\n                device_ids_per_host[host_id].append(device_id)\n        for host_id, indices in indices_per_host.items():\n            if len(indices) > 0:\n                self.device_mesh.workers[host_id].save_array.remote(\n                    ckpt_dir, local_cache_dir, self.remote_ref.uuid,\n                    np.array(device_ids_per_host[host_id]), indices, self.shape)\n\n    @classmethod\n    def load(cls, path: str, aval: ShapedArray, device_mesh: PhysicalDeviceMesh,\n             sharding_spec: ShardingSpec):\n        \"\"\"\n            Load the data from `path` distributedly with `aval` and\n            return a new DistributedArray\n        \"\"\"\n        # pylint: disable=import-outside-toplevel\n        ary_ref = RemoteArrayRef(device_mesh)\n        indices = pxla.spec_to_indices(aval.shape, sharding_spec)\n\n        indices_per_host = {}\n        device_ids_per_host = {}\n        for buf_idx, indice in enumerate(indices):\n            host_id, device_id = divmod(buf_idx,\n                                        device_mesh.num_devices_per_host)\n            if indices_per_host.get(host_id) is None:\n                indices_per_host[host_id] = [indice]\n                device_ids_per_host[host_id] = [device_id]\n            else:\n                indices_per_host[host_id].append(indice)\n                device_ids_per_host[host_id].append(device_id)\n        for host_id, worker in 
enumerate(device_mesh.workers):\n            worker.load_array.remote(path, ary_ref.uuid,\n                                     device_ids_per_host[host_id],\n                                     indices_per_host[host_id])\n        return DistributedArray(device_mesh, aval, sharding_spec, ary_ref,\n                                indices)\n\n    @property\n    def one_replica_buffer_ids(self):\n        \"\"\"Indices of buffers containing one complete copy of the array data.\"\"\"\n        return self.device_mesh._compute_one_replica_ids(\n            self.indices, self.aval.shape, self.sharding_spec)[0]\n\n    @property\n    def one_replica_host_local_ids(self):\n        return self.device_mesh._compute_one_replica_ids(\n            self.indices, self.aval.shape, self.sharding_spec)[1]\n\n    @property\n    def _value(self):\n        if self._npy_value is None:\n            npy_value = np.empty(self.aval.shape, self.aval.dtype)\n            if not self._fetched_np_buffers:\n                if not self._fetched_np_buffers_ref:\n                    fetched_np_buffers = self.device_mesh.get_remote_buffers(\n                        (self.remote_ref,),\n                        (self.one_replica_host_local_ids,))[0]\n                else:\n                    fetched_np_buffers = ray.get(self._fetched_np_buffers_ref)\n            else:\n                fetched_np_buffers = self._fetched_np_buffers\n            for ct, i in enumerate(self.one_replica_buffer_ids):\n                npy_value[self.indices[i]] = fetched_np_buffers[ct]\n            self._npy_value = npy_value\n        return self._npy_value\n\n    def __array__(self, dtype=None, context=None):\n        # pylint: disable=unused-argument\n        return np.asarray(self._value, dtype=dtype)\n\n    def __float__(self):\n        return self._value.__float__()\n\n    # TODO(lmzheng): copy more functions from DeviceArray\n    #   (jax/_src/device_array.py)\n\n    def __str__(self):\n        return 
(f\"DistributedArray(sharding_spec={self.sharding_spec}, \"\n                f\"value={self._value})\")\n\n    def __del__(self):\n        self.delete()\n\n\ncore.pytype_aval_mappings[DistributedArray] = attrgetter(\"aval\")\nxla.pytype_aval_mappings[DistributedArray] = attrgetter(\"aval\")\nxla.canonicalize_dtype_handlers[DistributedArray] = lambda x: x\n\n\nclass ReplicatedDistributedArray:\n    \"\"\"A distributed array that is replicated on multiple meshes.\n\n    This class is used for arrays that need to be replicated on\n    multiple physical meshes (e.g., optimizer's step).\n    \"\"\"\n\n    def __init__(self, device_meshes: Sequence[PhysicalDeviceMesh],\n                 arrays: Sequence[DistributedArray]):\n        self._mesh_array_map = {}\n        self._array_mesh_map = {}\n        for mesh, array in zip(device_meshes, arrays):\n            self._mesh_array_map[mesh] = array\n            self._array_mesh_map[array] = mesh\n        self.aval = self.replica.aval\n\n    def is_replicated_on_mesh(self, mesh: PhysicalDeviceMesh):\n        \"\"\"Whether this distributed array is on a given mesh.\"\"\"\n        if mesh in self._mesh_array_map:\n            return True\n        return False\n\n    def get_replica_on_mesh(self, mesh: PhysicalDeviceMesh):\n        if not self.is_replicated_on_mesh(mesh):\n            raise RuntimeError(\"No replica found on this mesh.\")\n        return self._mesh_array_map[mesh]\n\n    def add_replica(self, mesh: PhysicalDeviceMesh, array: DistributedArray):\n        assert isinstance(array, DistributedArray)\n        assert isinstance(mesh, PhysicalDeviceMesh)\n        if array in self._array_mesh_map:\n            raise RuntimeError(\"Replica exists.\")\n        if mesh in self._mesh_array_map:\n            raise RuntimeError(\"Mesh exists.\")\n        self._mesh_array_map.update({mesh: array})\n        self._array_mesh_map.update({array: mesh})\n\n    @property\n    def replica(self):\n        return 
list(self._mesh_array_map.values())[0]\n\n    @property\n    def _value(self):\n        return self.replica._value\n\n    def __array__(self, dtype=None, context=None):\n        # pylint: disable=unused-argument\n        return np.asarray(self._value, dtype=dtype)\n\n    def __str__(self):\n        return str(self._value)\n\n\ncore.pytype_aval_mappings[ReplicatedDistributedArray] = attrgetter(\"aval\")\nxla.pytype_aval_mappings[ReplicatedDistributedArray] = attrgetter(\"aval\")\nxla.canonicalize_dtype_handlers[ReplicatedDistributedArray] = lambda x: x\n\n\ndef prefetch(dis_arrays: Sequence[Union[ShardedDeviceArray, DistributedArray,\n                                        ReplicatedDistributedArray]]):\n    \"\"\"Prefetch a pytree of DistributedArray in a batch.\n\n    If you want to get a lot of DistributedArrays from remote workers,\n    calling this batched prefetch can make later access faster.\n    \"\"\"\n    group_by_mesh = defaultdict(list)\n    for array in tree_leaves(dis_arrays):\n        if isinstance(array, ShardedDeviceArray):\n            array.copy_to_host_async()\n        elif isinstance(array, DistributedArray):\n            group_by_mesh[array.device_mesh].append(array)\n        elif isinstance(array, ReplicatedDistributedArray):\n            array = array.replica\n            group_by_mesh[array.device_mesh].append(array)\n        else:\n            raise ValueError(f\"Unhandled array type: {array}\")\n\n    for device_mesh, arrays in group_by_mesh.items():\n        buf_refs = []\n        host_local_ids = []\n        for array in arrays:\n            buf_refs.append(array.remote_ref)\n            host_local_ids.append(array.one_replica_host_local_ids)\n\n        np_arrays = device_mesh.get_remote_buffers(buf_refs,\n                                                   host_local_ids,\n                                                   batching=True)\n\n        for array, np_value in zip(arrays, np_arrays):\n            array._fetched_np_buffers = 
np_value  # pylint: disable=protected-access\n\n\n########################################\n##### Physical Mesh Group #####\n########################################\nclass VirtualPhysicalMesh:\n    \"\"\"\n    A virtual physical mesh used for pipeline parallel compilation.\n\n    VirtualPhysicalMesh is used during compile time. We don't allocate actual\n    workers for it. When compilation is finished, we instantiate it as a\n    PhysicalDeviceMesh and launch workers.\n\n    A VirtualPhysicalMesh can also be sliced into multiple VirtualPhysicalMesh.\n    After slicing, each sliced VirtualPhysicalMesh can be instantiated as a\n    PhysicalDeviceMesh. These sliced PhysicalDeviceMesh together can form a\n    PhysicalDeviceMeshGroup for pipeline parallelism.\n    \"\"\"\n\n    def __init__(self,\n                 host_ids: Sequence[int],\n                 host_info: Sequence[dict],\n                 num_devices_per_host,\n                 parent: \"VirtualPhysicalMesh\" = None,\n                 devices: Sequence[Sequence[int]] = None):\n        # host_ids are the indices of hosts in the global DeviceCluster\n        self.host_ids = host_ids\n        self.host_info = host_info\n        self.num_devices_per_host = num_devices_per_host\n        self.parent = parent\n\n        self.launched_physical_mesh = None\n        self.launched_physical_mesh_group = None\n\n        if devices is not None:\n            if len(devices) != len(host_ids):\n                raise RuntimeError(\n                    \"Please specify the gpu IDs used on each host.\")\n            if not all(len(ids) == num_devices_per_host for ids in devices):\n                raise RuntimeError(\n                    \"Device IDs specified for each host does not align \"\n                    \"with `num_devices_per_host`.\")\n        else:\n            devices = [list(range(num_devices_per_host)) for _ in host_ids]\n\n        self.devices = devices\n        # Depending on gpu_ids, generate device strs and 
ask Ray to allocate.\n        self.device_strs = []\n        for i in range(self.num_hosts):\n            ip = self.host_info[i][\"NodeManagerAddress\"]\n            self.device_strs.extend(\n                [device_id_to_str(ip, j) for j in devices[i]])\n\n    @property\n    def shape(self):\n        return (len(self.host_ids), self.num_devices_per_host)\n\n    @property\n    def num_devices(self):\n        \"\"\"Return the total number of GPUs on this mesh.\"\"\"\n        return len(self.host_ids) * self.num_devices_per_host\n\n    @property\n    def num_hosts(self):\n        \"\"\"Return the number of hosts in the mesh.\"\"\"\n        return len(self.host_ids)\n\n    def slice_1d(self, dim: int, indices: Sequence[int]):\n        \"\"\"\n        Slice a mesh given the slicing config.\n\n        Args:\n            dim: which dimension to slice from, 0 is host or 1 is the gpu\n            indices: indices to include along this dimension.\n\n        Returns:\n            mesh (PhysicalDeviceMesh)\n        \"\"\"\n        if dim == 0:\n            # slicing along the host dimension\n            host_ids = [self.host_ids[x] for x in indices]\n            host_info = [self.host_info[x] for x in host_ids]\n            return VirtualPhysicalMesh(\n                host_ids=host_ids,\n                host_info=host_info,\n                num_devices_per_host=self.num_devices_per_host,\n                parent=self)\n        else:\n            # slicing along the device dimension\n\n            # Check the validity of device_indices\n            for i in range(len(indices)):\n                for x in indices[i]:\n                    assert x in self.devices[i]\n\n            return VirtualPhysicalMesh(host_ids=self.host_ids,\n                                       host_info=self.host_info,\n                                       num_devices_per_host=len(indices[0]),\n                                       parent=self,\n                                       
devices=indices)\n\n    def slice_2d(self, host_indices, device_indices):\n        host_ids = [self.host_ids[x] for x in host_indices]\n        host_info = [self.host_info[x] for x in host_indices]\n\n        # Check the validity of device_indices\n        for i in range(len(device_indices)):\n            for x in device_indices[i]:\n                assert x in self.devices[i]\n\n        return VirtualPhysicalMesh(host_ids=host_ids,\n                                   host_info=host_info,\n                                   num_devices_per_host=len(device_indices[0]),\n                                   parent=self,\n                                   devices=device_indices)\n\n    def slice_profiling_submeshes(self, submesh_num_hosts,\n                                  submesh_num_devices_per_host):\n        num_hosts = len(self.host_ids)\n        num_devices_per_host = self.num_devices_per_host\n        num_host_submeshes = num_hosts // submesh_num_hosts\n        num_device_submeshes = (num_devices_per_host //\n                                submesh_num_devices_per_host)\n        all_submeshes = []\n        for i in range(num_host_submeshes):\n            for j in range(num_device_submeshes):\n                host_indices = range(i * submesh_num_hosts,\n                                     (i + 1) * submesh_num_hosts)\n                device_indices = [\n                    range(j * submesh_num_devices_per_host,\n                          (j + 1) * submesh_num_devices_per_host)\n                    for _ in host_indices\n                ]\n                all_submeshes.append(self.slice_2d(host_indices,\n                                                   device_indices))\n        return all_submeshes\n\n    def get_logical_mesh(self,\n                         mesh_shape: Optional[Sequence[int]] = None,\n                         mesh_alpha: Optional[float] = None,\n                         mesh_beta: Optional[float] = None):\n        \"\"\"\n        Return a 
logical mesh and parameters of the alpha-beta communication\n        cost model. The logical view is used for auto-sharding.\n        \"\"\"\n        if mesh_shape is None:\n            mesh_shape = (self.num_hosts, self.num_devices_per_host)\n\n        id_mesh = np.arange(self.num_devices).reshape(mesh_shape)\n        mesh_alpha = mesh_alpha or (1, 1)\n        mesh_beta = mesh_beta or (1, 0.1)\n        return LogicalDeviceMesh(None, id_mesh, mesh_alpha, mesh_beta)\n\n    def get_physical_mesh(self, mesh_id: int = 0):\n        \"\"\"Launch a physical mesh (which will request resources from Ray).\"\"\"\n        assert self.launched_physical_mesh is None, \\\n            \"Physical mesh can only be launched once.\"\n\n        self.launched_physical_mesh = DistributedPhysicalDeviceMesh(\n            host_ids=self.host_ids,\n            host_info=self.host_info,\n            num_devices_per_host=self.num_devices_per_host,\n            parent=self,\n            devices=self.devices,\n            mesh_id=mesh_id)\n        return self.launched_physical_mesh\n\n    def get_physical_mesh_group(self, sliced_virtual_meshes):\n        \"\"\"Launch a physical mesh group (which will request resources from\n        Ray).\"\"\"\n        assert self.launched_physical_mesh_group is None, \\\n            \"Physical mesh group can only be launched once.\"\n\n        # Launch physical meshes in parallel\n        physical_meshes = [None] * len(sliced_virtual_meshes)\n\n        def launch_func(i):\n            physical_meshes[i] = sliced_virtual_meshes[i].get_physical_mesh(i)\n\n        threads = []\n        for i in range(len(sliced_virtual_meshes)):\n            t = threading.Thread(target=launch_func, args=(i,))\n            t.start()\n            threads.append(t)\n        for i in range(len(sliced_virtual_meshes)):\n            threads[i].join()\n\n        self.launched_physical_mesh_group = (PhysicalDeviceMeshGroup(\n            physical_meshes, self))\n        return 
self.launched_physical_mesh_group\n\n\nclass PhysicalDeviceMeshGroup:\n    \"\"\"A list of physical devices that forms a pipeline.\"\"\"\n\n    def __init__(self, meshes: Sequence[DistributedPhysicalDeviceMesh],\n                 parent: VirtualPhysicalMesh):\n        self.meshes = list(meshes)\n        self.parent = parent\n        self.collective_groups: List[List[Any]] = [\n            [None for _ in range(len(self))] for _ in range(len(self))\n        ]\n\n    def __getitem__(self, index):\n        return self.meshes[index]\n\n    def __len__(self):\n        return len(self.meshes)\n\n    def index(self, *args, **kwargs):\n        return self.meshes.index(*args, **kwargs)\n\n    def establish_nccl_group(self,\n                             src_mesh_id: int,\n                             dst_mesh_id: int,\n                             instantiate=True):\n        \"\"\"Establish NCCL group between two meshes.\"\"\"\n        # pylint: disable=import-outside-toplevel\n        from alpa.pipeline_parallel.cross_mesh_resharding import CollectiveGroup\n\n        assert src_mesh_id < dst_mesh_id\n        if self.collective_groups[src_mesh_id][dst_mesh_id] is not None:\n            # Already established\n            return\n        src_mesh = self.meshes[src_mesh_id]\n        dst_mesh = self.meshes[dst_mesh_id]\n        device_strs = OrderedSet(src_mesh.device_strs + dst_mesh.device_strs)\n        cg = CollectiveGroup(device_strs, src_mesh, dst_mesh)\n        self.collective_groups[src_mesh_id][dst_mesh_id] = cg\n        self.collective_groups[dst_mesh_id][src_mesh_id] = cg\n        if instantiate:\n            self._instantiate_nccl_group(cg)\n\n    def instantiate_nccl_group(self, src_mesh_id: int, dst_mesh_id: int):\n        cg = self.collective_groups[src_mesh_id][dst_mesh_id]\n        self._instantiate_nccl_group(cg)\n\n    def shard_args_to_arrays(self, placement_specs: PlacementSpec,\n                             args: Sequence[Any]):\n        rets = []\n\n        
for info, arg in zip(placement_specs, args):\n            aval = info.aval\n            if len(info.mesh_ids) == 1:\n                mesh = self.meshes[info.mesh_ids[0]]\n                spec = info.sharding_specs[0]\n                indices = pxla.spec_to_indices(aval.shape, spec)\n                rets.append(\n                    mesh.shard_args_to_arrays((aval,), (indices,), (spec,),\n                                              (arg,))[0])\n            else:\n                meshes, arrays = [], []\n                for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):\n                    mesh = self.meshes[mesh_id]\n                    meshes.append(mesh)\n                    indices = pxla.spec_to_indices(aval.shape, spec)\n                    arrays.append(\n                        mesh.shard_args_to_arrays((aval,), (indices,), (spec,),\n                                                  (arg,))[0])\n                rets.append(ReplicatedDistributedArray(meshes, arrays))\n\n        return rets\n\n    def set_runtime_random_seed(self, seed: int):\n        for m in self.meshes:\n            m.set_runtime_random_seed(seed)\n\n    def sync_workers(self):\n        \"\"\"Sync device activities on all workers.\"\"\"\n        all_workers = [w for mesh in self.meshes for w in mesh.workers]\n        ray.get([w.sync.remote() for w in all_workers])\n\n    def sync_move_workers(self):\n        \"\"\"Sync moveworkers on all meshes.\"\"\"\n        for mesh in self.meshes:\n            mesh.sync_move_workers()\n\n    def get_memory_allocated(self):\n        \"\"\"Get the current size of allocated memory.\"\"\"\n        calls = []\n        for mesh in self.meshes:\n            for worker in mesh.workers:\n                calls.append(worker.get_memory_allocated.remote())\n        return max(ray.get(calls))\n\n    def get_max_memory_allocated(self):\n        \"\"\"Get the maximal size of memory allocated so far.\"\"\"\n        calls = []\n        for mesh in 
self.meshes:\n            for worker in mesh.workers:\n                calls.append(worker.get_max_memory_allocated.remote())\n        return max(ray.get(calls))\n\n    def get_max_memory_allocated_per_mesh(self):\n        \"\"\"Get the maximal size of memory allocated for each mesh so far.\"\"\"\n        return [mesh.get_max_memory_allocated() for mesh in self.meshes]\n\n    def reset_memory_stats(self):\n        for mesh in self.meshes:\n            mesh.reset_memory_stats()\n\n    def destroy_collective_groups(self):\n        for i in range(len(self)):\n            for j in range(len(self)):\n                if i < j and self.collective_groups[i][j] is not None:\n                    self.collective_groups[i][j].destroy()\n\n    def shutdown(self):\n        self.destroy_collective_groups()\n        for mesh in self.meshes:\n            mesh.shutdown()\n\n    def exception_shutdown(self):\n        \"\"\"In this shutdown, some actors might have died.\"\"\"\n        # recycle collective group info\n        for i in range(len(self)):\n            for j in range(len(self)):\n                if i < j and self.collective_groups[i][j]:\n                    group_name = self.collective_groups[i][j].group_name\n                    # TODO(Hao): move this part of recycling to\n                    #   ray.util.collective instead of here.\n                    name = \"info_\" + group_name\n                    try:\n                        store = ray.get_actor(name)\n                        ray.kill(store)\n                    except ValueError:\n                        pass\n        # TODO(Hao): recycle the NCCLUniqueID named actor. Their name is MD5\n        #  hashed. 
each of them will take 1 CPU.\n        # recycle info actors\n        for mesh in self.meshes:\n            mesh.shutdown(forced=True)\n\n    @staticmethod\n    def _instantiate_nccl_group(cg):\n        if global_config.eagerly_create_communicators:\n            cg.instantiate_now()\n        else:\n            cg.instantiate()\n\n\n########################################\n# Device Cluster\n########################################\nclass DeviceCluster:\n    \"\"\"\n    A ray cluster with GPU devices.\n\n    This is the top interface for alpa to interact with ray cluster's resources.\n    \"\"\"\n\n    def __init__(self,\n                 num_nodes: int = None,\n                 num_devices_per_node: int = None,\n                 namespace: Optional[str] = None):\n        # pylint: disable=import-outside-toplevel\n        ray_global_node = ray_worker._global_node\n        try:\n            self.head_info = ray_global_node.address_info\n        except AttributeError as ae:\n            raise RuntimeError(\n                \"Cannot access ray global node. 
Did you call ray.init?\") \\\n                from ae\n\n        # Gather host ids\n        all_host_info = []\n        all_host_ips = []\n\n        for node in ray.nodes():\n            for key in node[\"Resources\"]:\n                if (is_ray_node_resource(key) and\n                        global_config.ray_accelerator_name\n                        in node[\"Resources\"]):\n                    all_host_info.append(node)\n                    all_host_ips.append(key.split(\"node:\")[-1])\n\n        # Gather device info\n        all_host_num_devices = []\n        for host_info in all_host_info:\n            number = host_info[\"Resources\"][global_config.ray_accelerator_name]\n            assert number.is_integer()\n            all_host_num_devices.append(int(number))\n\n        # adjust the resource allocations\n        # if `num_nodes` is set, use it.\n        # otherwise, use the number of nodes in cluster\n        if num_nodes:\n            num_hosts = min(num_nodes, len(all_host_info))\n        else:\n            num_hosts = len(all_host_info)\n\n        # if `devices_per_node` is set, use it.\n        if num_devices_per_node:\n            # verify that the number of devices per node is valid\n            num_valid = sum(num_device >= num_devices_per_node\n                            for num_device in all_host_num_devices)\n            if num_valid < num_nodes:\n                raise RuntimeError(\"The number of devices per node is invalid. 
\"\n                                   f\"There are only {num_valid} valid nodes.\")\n            # NOTE: for simplicity, we assume `num_devices_per_node` are equal.\n            self.host_num_devices = [num_devices_per_node] * num_hosts\n        else:\n            self.host_num_devices = all_host_num_devices\n\n        # Create placement group\n        self.namespace = namespace\n        if namespace:\n            pg_name = namespace + \"_pg\"\n            try:\n                pg = ray.util.get_placement_group(pg_name)\n            except ValueError:\n                pg = None\n        else:\n            pg_name = pg = None\n\n        if pg:\n            self.placement_group = pg\n        else:\n            self.placement_group = create_placement_group(\n                num_hosts, self.host_num_devices, pg_name)\n\n        # Update the Device Cluster info from placement group\n        if num_devices_per_node or num_nodes:\n            # map: host ip to host info\n            host_ip2info = dict(zip(all_host_ips, all_host_info))\n\n            # get bundle's ip address\n            ips = get_bundle2ip(self.placement_group)\n            bundle_specs = self.placement_group.bundle_specs\n\n            # filter out the bundle index with device (GPUs)\n            device_bundle_idx_list = [\n                i for i, bundle_spec in enumerate(bundle_specs)\n                if bundle_spec.get(\"GPU\", 0) > 0\n            ]\n\n            # filter nodes according to the placement group\n            self.host_info = [host_ip2info[ip] for ip in ips]\n            self.host_ips = [\n                ips[bundle_idx] for bundle_idx in device_bundle_idx_list\n            ]\n        else:\n            self.host_info = all_host_info\n            self.host_ips = all_host_ips\n\n    def delete_placement_group(self):\n        \"\"\"remove the placement group for the current device cluster.\"\"\"\n        remove_placement_group(self.placement_group)\n        self.placement_group = 
None\n\n    @property\n    def num_cpus(self):\n        return sum(\n            map(lambda info: int(info[\"Resources\"][\"CPU\"]), self.host_info))\n\n    @property\n    def num_devices(self):\n        return sum(self.host_num_devices)\n\n    @property\n    def num_hosts(self):\n        return len(self.host_info)\n\n    def get_physical_mesh(self,\n                          host_ids: Sequence[int] = None,\n                          num_devices_per_host: int = None):\n        \"\"\"\n        Slice a subset of hosts and devices to form a physical device mesh.\n\n        Args:\n            host_ids: The index of host nodes.\n                \"None\" means using all hosts\n            num_devices_per_host: The number of devices per host.\n                \"None\" means using all devices\n\n        Return:\n            A physical multi-host device mesh\n        \"\"\"\n        host_ids = host_ids or np.arange(len(self.host_info))\n        host_info = [self.host_info[x] for x in host_ids]\n\n        num_devices_per_host = num_devices_per_host or self.host_num_devices[\n            host_ids[0]]\n        for host_id in host_ids:\n            assert self.host_num_devices[host_id] >= num_devices_per_host\n\n        return DistributedPhysicalDeviceMesh(\n            host_ids=host_ids,\n            host_info=host_info,\n            num_devices_per_host=num_devices_per_host,\n            parent=self,\n            namespace=self.namespace)\n\n    def get_virtual_physical_mesh(self,\n                                  host_ids: Sequence[int] = None,\n                                  num_devices_per_host: int = None):\n        \"\"\"\n        Slice a subset of hosts and devices to form a virtual physical mesh.\n\n        The only difference between a virtual and a physical mesh is that a\n        virtual mesh does not request cluster resources.\n        \"\"\"\n        host_ids = host_ids or np.arange(len(self.host_info))\n        host_info = [self.host_info[x] for x in 
host_ids]\n\n        num_devices_per_host = num_devices_per_host or self.host_num_devices[\n            host_ids[0]]\n        for host_id in host_ids:\n            assert self.host_num_devices[host_id] >= num_devices_per_host\n\n        return VirtualPhysicalMesh(host_ids=host_ids,\n                                   host_info=host_info,\n                                   num_devices_per_host=num_devices_per_host,\n                                   parent=self)\n\n    def profile_all(self, *args, **kwargs):\n        \"\"\"Profile computation and communication cost for all submesh shapes of\n        this cluster.\"\"\"\n        return mesh_profiling.profile_all(self, *args, **kwargs)\n\n\n# Global runtime objects\nglobal_cluster: DeviceCluster = None\nglobal_physical_mesh: PhysicalDeviceMesh = None\nglobal_virtual_physical_mesh: VirtualPhysicalMesh = None\n\n\ndef init_global_cluster(cluster: str,\n                        cluster_address: Optional[str] = None,\n                        num_nodes: Optional[int] = None,\n                        num_devices_per_node: Optional[int] = None,\n                        namespace: Optional[str] = None):\n    global global_cluster, global_physical_mesh, global_virtual_physical_mesh\n\n    if cluster == \"local\":\n        global_physical_mesh = LocalPhysicalDeviceMesh()\n    elif cluster == \"ray\":\n        if not ray.is_initialized():\n            ray_addr = cluster_address if cluster_address else \"auto\"\n            ray.init(address=ray_addr,\n                     ignore_reinit_error=True,\n                     namespace=namespace)\n        update_jax_platform(\"cpu\")\n        global_cluster = DeviceCluster(num_nodes, num_devices_per_node)\n        global_virtual_physical_mesh = (\n            global_cluster.get_virtual_physical_mesh())\n\n\ndef shutdown_global_cluster():\n    global global_cluster, global_physical_mesh, global_virtual_physical_mesh\n\n    if global_physical_mesh:\n        
global_physical_mesh.shutdown()\n        global_physical_mesh = None\n\n    if global_virtual_physical_mesh:\n        if global_virtual_physical_mesh.launched_physical_mesh_group:\n            global_virtual_physical_mesh.launched_physical_mesh_group.shutdown()\n        global_virtual_physical_mesh = None\n\n    global_cluster.delete_placement_group()\n    global_cluster = None\n    update_jax_platform(\"gpu\")\n\n\ndef set_global_cluster(cluster: DeviceCluster):\n    global global_cluster\n    global_cluster = cluster\n\n\ndef get_global_cluster():\n    return global_cluster\n\n\ndef set_global_physical_mesh(mesh: PhysicalDeviceMesh):\n    global global_physical_mesh\n    global_physical_mesh = mesh\n\n\ndef get_global_physical_mesh(create_if_not_exist=False):\n    global global_physical_mesh\n\n    if global_physical_mesh is None and create_if_not_exist:\n        if global_cluster is None:\n            # ray is not initialized, use local devices\n            mesh = LocalPhysicalDeviceMesh()\n        else:\n            mesh = global_cluster.get_physical_mesh()\n        global_physical_mesh = mesh\n\n    return global_physical_mesh\n\n\ndef set_global_virtual_physical_mesh(mesh: VirtualPhysicalMesh):\n    global global_virtual_physical_mesh\n    global_virtual_physical_mesh = mesh\n\n\ndef get_global_virtual_physical_mesh():\n    return global_virtual_physical_mesh\n\n\ndef set_seed(seed: int):\n    global_config.runtime_random_seed = seed\n\n    if global_physical_mesh:\n        global_physical_mesh.set_runtime_random_seed(seed)\n    if (global_virtual_physical_mesh and\n            global_virtual_physical_mesh.launched_physical_mesh_group):\n        global_virtual_physical_mesh.launched_physical_mesh_group.\\\n            set_runtime_random_seed(seed)\n\n\ndef get_global_num_devices():\n    if global_virtual_physical_mesh:\n        return global_virtual_physical_mesh.num_devices\n    if global_physical_mesh:\n        return global_physical_mesh.num_devices\n\n    
raise RuntimeError(\"Please call alpa.init first\")\n\n\ndef create_and_record_cross_mesh_collective_communicators(\n        meshes: Sequence[DistributedPhysicalDeviceMesh], key):\n    workers = []\n    device_strs = []\n    for mesh in meshes:\n        workers.extend(mesh.workers)\n        device_strs.extend(mesh.device_strs)\n    world_size = len(workers)\n    backend = \"nccl\"\n    group_name = \",\".join(device_strs)\n    refs = []\n    for rank, worker in enumerate(workers):\n        ref = worker.create_and_set_cross_mesh_communicators.remote(\n            world_size, rank, backend, group_name, key)\n        refs.append(ref)\n    return refs\n\n\n########################################\n# Register ShardArg Handler\n########################################\ndef _device_mesh_put(device_mesh, shards, num_batch, batch_dim):\n    ary_refs, ary_uuids = create_remote_array_refs(device_mesh, num_batch)\n    shard_step = device_mesh.num_devices_per_host\n    for host_id in range(device_mesh.num_hosts):\n        device_mesh.workers[host_id].put_buffers.remote(\n            ary_uuids, shards[host_id * shard_step:(host_id + 1) * shard_step],\n            num_batch, batch_dim)\n    return ary_refs\n\n\ndef _device_mesh_put_dummy(array, device_mesh, indices, num_batch):\n    ary_refs, ary_uuids = create_remote_array_refs(device_mesh, num_batch)\n    step = device_mesh.num_devices_per_host * num_batch\n    for host_id in range(device_mesh.num_hosts):\n        device_mesh.workers[host_id].shard_and_put_non_zero_buffer.remote(\n            ary_uuids, array.shape, array.dtype,\n            indices[host_id * step:(host_id + 1) * step], num_batch)\n    return ary_refs\n\n\ndef _shard_abstract_array(array,\n                          device_mesh,\n                          indices,\n                          num_batch=1,\n                          batch_dim=0):\n    # pylint: disable=unused-argument\n    assert global_config.use_dummy_value_for_benchmarking is True\n    return 
_device_mesh_put_dummy(array, device_mesh, indices, num_batch)\n\n\ndef _shard_array(array, device_mesh, indices, num_batch=1, batch_dim=0):\n    if global_config.use_dummy_value_for_benchmarking:\n        return _device_mesh_put_dummy(array, device_mesh, indices, num_batch)\n    else:\n        # Create shards according to indices for a numpy array\n        if array.shape == ():\n            # need a special branch because np.ascontiguousarray does not\n            # correctly preserve the shapes of rank-0 arrays.\n            datas = [np.asarray(array)] * len(indices)\n        else:\n            datas = [np.ascontiguousarray(array[i]) for i in indices]\n        if num_batch > 1:\n            concate_datas = []\n            for device_id in range(device_mesh.num_devices):\n                mb = datas[device_id * num_batch:(device_id + 1) * num_batch]\n                concate_datas.append(np.concatenate(mb, axis=batch_dim))\n            datas = concate_datas\n        return _device_mesh_put(device_mesh, datas, num_batch, batch_dim)\n\n\ndef _shard_device_array(array, device_mesh, indices, num_batch=1, batch_dim=0):\n    if global_config.use_dummy_value_for_benchmarking:\n        return _device_mesh_put_dummy(array, device_mesh, indices, num_batch)\n    else:\n        return _shard_array(np.asarray(array), device_mesh, indices, num_batch,\n                            batch_dim)\n\n\ndef _shard_distributed_array(array,\n                             device_mesh,\n                             indices,\n                             num_batch=1,\n                             batch_dim=0):\n    # Slow path: gather values to host and reshard\n    return shard_arg_handlers[type(array._value)](array._value, device_mesh,\n                                                  indices, num_batch, batch_dim)\n\n\nshard_arg_handlers = {}  # Shard an argument to a distributed array\nfor a in array_types:\n    shard_arg_handlers[a] = _shard_array\nshard_arg_handlers[ShapedArray] = 
_shard_abstract_array\nshard_arg_handlers[ShapeDtypeStruct] = _shard_abstract_array\nshard_arg_handlers[xla._DeviceArray] = _shard_device_array\nshard_arg_handlers[xla._CppDeviceArray] = _shard_device_array\nshard_arg_handlers[DistributedArray] = _shard_distributed_array\nshard_arg_handlers[ShardedDeviceArray] = _shard_distributed_array\n"
  },
  {
    "path": "alpa/follow_parallel.py",
    "content": "\"\"\"Follow the parallelization strategy of another function.\"\"\"\nimport logging\n\nfrom jax.core import ClosedJaxpr\nfrom jax.interpreters import partial_eval as pe\nfrom jax.tree_util import tree_leaves\n\nfrom alpa.global_env import global_config\nfrom alpa.mesh_executable import (NormalMeshDriverExecutable,\n                                  GradAccMeshDriverExecutable)\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.pipeline_parallel.compile_executable import (\n    compile_pipeshard_executable)\nfrom alpa.pipeline_parallel.layer_construction import (ManualLayerOption,\n                                                       FollowLayerOption)\nfrom alpa.pipeline_parallel.stage_construction import UniformStageOption\nfrom alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,\n                                               AutoShardingOption)\nfrom alpa.util import (jaxpr_to_hlo, undefined_sharding_spec_proto)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef compile_follow_parallel_executable(fun, in_tree, out_tree_thunk,\n                                       static_argnums, donated_invars,\n                                       batch_invars, src_func,\n                                       num_micro_batches, input_placement_specs,\n                                       pipeline_schedule, layer_option, *avals):\n\n    def is_leave(x):\n        return isinstance(x, PlacementSpec) or x is None\n\n    input_placement_specs = tree_leaves(input_placement_specs, is_leave)\n\n    executable = src_func.get_last_executable()\n    if (not isinstance(executable, NormalMeshDriverExecutable) and\n            global_config.backend == \"tpu\"):\n        raise NotImplementedError(f\"{type(executable)} is not supported in tpu\")\n    if isinstance(executable,\n                  (NormalMeshDriverExecutable, GradAccMeshDriverExecutable)):\n        if num_micro_batches != 1 and num_micro_batches is not 
None:\n            logger.warning(\"num_micro_batches is ignored in FollowParallel\")\n\n        # Trace to get jaxpr and HloModule\n        jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, avals)\n        closed_jaxpr = ClosedJaxpr(jaxpr, consts)\n        out_tree = out_tree_thunk()\n\n        name = f\"{fun.__name__}_follow_shard_parallel\"\n        hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars)\n\n        # Get input sharding specs\n        sharding_protos = []\n        for spec in input_placement_specs:\n            if spec is None:\n                sharding_protos.append(undefined_sharding_spec_proto())\n            else:\n                assert len(spec.mesh_ids) == 1\n                sharding_protos.append(spec.sharding_specs[0].sharding_proto())\n\n        # Run sharding propagation\n        physical_mesh = executable.physical_mesh\n        hlo.set_input_shardings(sharding_protos)\n        hlo, stage_plan = run_auto_sharding_pass(\n            hlo, physical_mesh.get_logical_mesh(), \"single\", 1,\n            AutoShardingOption(enable_auto_sharding=False))\n\n        return NormalMeshDriverExecutable(physical_mesh, hlo, stage_plan, avals,\n                                          out_avals, [False] * len(avals),\n                                          static_argnums, in_tree, out_tree)\n    else:\n        num_micro_batches = num_micro_batches or 1\n\n        if layer_option == \"manual\":\n            layer_option = ManualLayerOption()\n        elif layer_option == \"follow\":\n            layer_option = FollowLayerOption(input_placement_specs,\n                                             len(executable.mesh_group))\n        else:\n            raise ValueError(f\"Invalid layer option: {layer_option}\")\n\n        input_shardings = [x.sharding_specs[0] for x in input_placement_specs]\n        # TODO(lmzheng): handle ReplicatedDistributedArray, tied embedding\n        mesh = executable.mesh_group.parent\n\n        return 
compile_pipeshard_executable(\n            fun, in_tree, out_tree_thunk, static_argnums, donated_invars,\n            batch_invars, mesh, num_micro_batches, pipeline_schedule,\n            AutoShardingOption(enable_auto_sharding=False), layer_option,\n            UniformStageOption(), input_shardings, None, None, *avals)\n"
  },
  {
    "path": "alpa/global_env.py",
    "content": "\"\"\"All global configurations for this project.\"\"\"\nimport os\n\n\nclass GlobalConfig:\n    \"\"\"The global configuration of alpa.\"\"\"\n\n    def __init__(self):\n        ########## Options of device mesh ##########\n        self.backend = \"gpu\"\n        self.has_cuda = os.system(\"nvidia-smi > /dev/null 2>&1\") == 0\n\n        # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html\n        self.xla_client_mem_fraction = float(\n            os.environ.get(\"XLA_PYTHON_CLIENT_MEM_FRACTION\", 0.9))\n        self.xla_client_client_preallocate = os.environ.get(\n            \"XLA_PYTHON_CLIENT_PREALLOCATE\", \"true\")\n        # The threshold to tigger a batched deletion on workers.\n        self.delete_remote_arrays_threshold = 50\n\n        # Random seed used for compilation\n        self.compile_random_seed = 42\n        # Random seed used for runtime\n        self.runtime_random_seed = 42\n\n        # XLA server port range\n        self.xla_server_port_start = int(\n            os.environ.get(\"XLA_SERVER_PORT_START\", \"20000\").lower())\n        self.xla_server_port_end = int(\n            os.environ.get(\"XLA_SERVER_PORT_END\", \"25000\").lower())\n        # XLA gpu kernel auto-tuning level\n        self.xla_gpu_autotune_level = 4\n\n        # Whether to use AWS EFA network interface\n        self.use_aws_efa = os.environ.get(\"ALPA_USE_AWS_EFA\",\n                                          \"\").lower() in [\"true\", \"1\"]\n\n        ########## Options of shard_parallel ##########\n        # Whether to sync before and after the executable for accurate internal\n        # timer\n        self.shard_parallel_sync_for_timer = False\n\n        ########## Options of pipeline_parallel ##########\n        # Whether to debug with pipeshard runtime. 
If turned on, no physical\n        # resource is required until launching PipeshardExecutable.\n        self.debug_with_pipeshard_runtime = False\n        # Whether to use the whole cluster for stage profiling. If not, only\n        # use the given mesh.\n        self.profile_with_whole_ray_cluster = True\n        # Stage construction profiling time threshold.\n        self.profile_timeout = 500\n        # Stage construction profiling retry threshold.\n        # Some communication patterns may meet deadlock, so it needs retry.\n        self.profile_maximum_retry = 2\n        # Whether to forcibly set stage construction's submesh choices\n        self.overwrite_submesh_choices = None\n        self.always_donate_micro_batch_vars = True\n\n        ########## Options of pipeline runtime ##########\n        # Whether to sync before and after the executable for accurate internal\n        # timer\n        self.pipeline_sync_for_timer = False\n        # Whether to use distributed compilation in pipeline parallel for\n        # each stage. Disabling it helps debug.\n        self.pipeline_distributed_compile = True\n        self.eagerly_create_communicators = True\n        self.pipeline_check_alive = False\n        # Whether to use single-byte signal tensor for send/recv.\n        # This is a debug option.\n        self.pipeline_use_signal_send_recv = False\n        # Whether to use the scatter-gather/local-all-gather optimization.\n        self.use_local_allgather = True\n        # Cross mesh resharding mode. Possible choices: {\"send_recv\",\n        # \"broadcast\"}\n        self.resharding_mode = \"send_recv\"\n        # Which nccl to use. 
Possible choices: {\"cupy\",\n        # \"xla_extension\"}\n        self.nccl_mode = \"cupy\"\n        self.enable_overlapping = False\n        # Cross mesh resharding load balancing mode.\n        # Possible choices: {\"normal\", \"no_loadbalance\",\n        # \"loadbalance_size\", \"loadbalance_order\"}\n        self.resharding_loadbalance_mode = \"normal\"\n        self.loadbalance_order_algo = \"greedy\"\n\n        ########## Options of benchmark ##########\n        # If true, the system is allowed to use dummy values during\n        # tensor creation and copy to reduce the initialization and copy time.\n        # This will produce wrong results but is acceptable for\n        # data-independent benchmarks.\n        self.use_dummy_value_for_benchmarking = False\n\n        ########## Options of monkey patch ##########\n        self.flax_always_use_fp16_embedding = False\n\n        ########## Options of logging ##########\n        self.print_compilation_time = False\n        self.print_auto_layer_stats = False\n\n        # Whether to collect activity trace\n        self.collect_trace = False\n\n    @property\n    def ray_accelerator_name(self):\n        backend_to_ray = {\"gpu\": \"GPU\"}\n        return backend_to_ray[self.backend]\n\n    def update_worker_config(self, cfg: \"GlobalConfig\"):\n        \"\"\"Update the worker config based on the host one\"\"\"\n        self.backend = cfg.backend\n        # Random seed used for compilation\n        self.compile_random_seed = cfg.compile_random_seed\n        # Random seed used for runtime\n        self.runtime_random_seed = cfg.runtime_random_seed\n        # XLA server port range\n        self.xla_server_port_start = cfg.xla_server_port_start\n        self.xla_server_port_end = cfg.xla_server_port_end\n        # XLA gpu kernel auto-tuning level\n        self.xla_gpu_autotune_level = cfg.xla_gpu_autotune_level\n        # Whether to use AWS EFA network interface\n        self.use_aws_efa = cfg.use_aws_efa\n        
########## Options of pipeline runtime ##########\n        # Whether to sync before and after the executable for accurate internal\n        # timer\n        self.pipeline_sync_for_timer = cfg.pipeline_sync_for_timer\n        # Whether to use single-byte signal tensor for send/recv.\n        # This is a debug option.\n        self.pipeline_use_signal_send_recv = cfg.pipeline_use_signal_send_recv\n        # Whether to use the scatter-gather/local-all-gather optimization.\n        self.use_local_allgather = cfg.use_local_allgather\n        # Cross mesh resharding mode. Possible choices: {\"send_recv\",\n        # \"broadcast\"}\n        self.resharding_mode = cfg.resharding_mode\n        self.nccl_mode = cfg.nccl_mode\n        self.enable_overlapping = cfg.enable_overlapping\n        self.collect_trace = cfg.collect_trace\n\n\nglobal_config = GlobalConfig()\n\n# Other environment setup\nis_worker = os.environ.get(\"ALPA_IS_WORKER\", \"False\") == \"True\"\n\nos.environ[\"XLA_FLAGS\"] = (os.environ.get(\"XLA_FLAGS\", \"\") +\n                           \" --xla_gpu_enable_async_all_reduce=false\" +\n                           \" --xla_gpu_force_compilation_parallelism=8\")\n"
  },
  {
    "path": "alpa/mesh_executable.py",
    "content": "# pylint: disable=arguments-differ\n\"\"\"A mesh executable encapsulates all compiled binary and meta information of\na distributed executable.\n\nA mesh executable contains one or several XLA executables.\nFor each type of mesh executable, there is a driver part and a worker part.\nThe driver part runs on the user script and the worker parts run on distributed\nworkers. The driver part sends control commands to launch the worker parts on\nworkers.\n\"\"\"\nfrom abc import ABC, abstractmethod\nfrom typing import Sequence, Optional\nimport os\n\nfrom jax import xla\nimport jax.numpy as jnp\nfrom jax._src.api import ShapeDtypeStruct\nfrom jax._src.lib import xla_client as xc, xla_extension as xe\nfrom jax.core import ShapedArray\nfrom jax.interpreters import pxla\nfrom jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef\nimport numpy as np\nimport ray\nfrom alpa.util import XlaPassContext\n\nfrom alpa.device_mesh import (LocalPhysicalDeviceMesh,\n                              DistributedPhysicalDeviceMesh, RemoteArrayRef,\n                              next_array_uuids)\nfrom alpa.global_env import global_config\nfrom alpa.parallel_plan import (PlacementSpec, StagePlan, ClusterInfo,\n                                ParallelPlan)\nfrom alpa.shard_parallel.auto_sharding import (AutoShardingOption,\n                                               get_input_output_sharding_specs,\n                                               make_replicated_spec,\n                                               run_backend_compilation,\n                                               run_spmd_partitioner_pass)\nfrom alpa.timer import timers\nfrom alpa.util import (compile_allocate_zero_buffers, get_compile_options,\n                       get_index_select_computation, get_shard_shape,\n                       get_microbatch_sharding_spec, profile_xla_executable)\nfrom alpa.wrapped_hlo import HloStatus, WrappedHlo\n\n\nclass MeshDriverExecutable(ABC):\n 
   \"\"\"The base class of the driver part of a mesh executable.\"\"\"\n\n    @abstractmethod\n    def launch_on_driver(self, *args, **kwargs):\n        \"\"\"Launch the executable on the driver.\n\n        Args:\n            args: The original arguments of the parallelized function.\n            kwargs: The additional arguments to control execution options.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_input_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for input arguments.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the input pytree.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_output_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for outputs.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the output pytree.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_parallel_plan(self):\n        \"\"\"Get the overall parallel plan.\"\"\"\n        raise NotImplementedError()\n\n    def preshard_dynamic_args(self, *args):\n        \"\"\"Pre-shard the input arguments.\"\"\"\n        raise NotImplementedError()\n\n    def profile_with_dummy_inputs(self, **kwargs):\n        \"\"\"Profile the execution time costs with dummy inputs.\n\n        Args:\n            kwargs: The additional arguments to control execution options.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_execution_time_costs(self):\n        \"\"\"Return the pure execution time costs recorded by an internal\n        timer.\"\"\"\n        return self.physical_mesh.get_remote_timer(self.exec_timer_name).costs\n\n    def get_shard_args_time_costs(self):\n        \"\"\"Return the time costs of sharding input arguments.\"\"\"\n        return timers(self.shard_args_timer_name).costs\n\n    def get_hlo_text(self, status: HloStatus):\n        \"\"\"Return the HLO IR in the 
text format.\"\"\"\n        raise NotImplementedError()\n\n    def get_total_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes.\"\"\"\n        raise NotImplementedError()\n\n    def dump_debug_info(self, folder: str):\n        \"\"\"\n        Dump intermediate representations and other informations for debugging.\n        \"\"\"\n        raise NotImplementedError()\n\n    def sync(self):\n        \"\"\"Sync all workers\"\"\"\n        self.physical_mesh.sync_workers()\n\n    def __del__(self):\n        if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):\n            self.physical_mesh.delete_remote_executable(self.exec_uuid)\n\n\nclass MeshWorkerExecutable(ABC):\n    \"\"\"The base class of the worker part of a mesh executable.\"\"\"\n\n    @abstractmethod\n    def execute_on_worker(self, *arg, **kwargs):\n        \"\"\"Run the executable on the worker.\"\"\"\n        raise NotImplementedError()\n\n    def profile_with_dummy_inputs(self, backend, local_devices):\n        \"\"\"Profile the execution time costs with dummy inputs.\"\"\"\n        raise NotImplementedError()\n\n    def get_hlo_text(self):\n        \"\"\"Return the HLO IR in the text format.\"\"\"\n        raise NotImplementedError()\n\n    def get_total_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes.\"\"\"\n        raise NotImplementedError()\n\n\n# The global executable counter\nmesh_executable_counter = 0\n\n\ndef next_mesh_executable_uuid():\n    \"\"\"Return the next uuid of a mesh executable.\"\"\"\n    global mesh_executable_counter\n    mesh_executable_counter = (mesh_executable_counter + 1) % (1 << 60)\n    return mesh_executable_counter\n\n\ndef get_execution_timer_name(exec_uuid: int):\n    \"\"\"Return the name of the timer used for recording pure execution time.\"\"\"\n    return f\"exec-{exec_uuid}\"\n\n\ndef get_sync_func_driver(physical_mesh):\n    \"\"\"Get the sync function on the driver.\"\"\"\n\n    
def sync_func_driver():\n        assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n        physical_mesh.devices[0].synchronize_all_activity()\n\n    return sync_func_driver\n\n\ndef get_sync_func_worker(worker):\n    \"\"\"Get the sync function on the workers\"\"\"\n\n    def sync_func_worker():\n        worker.local_devices[0].synchronize_all_activity()\n\n    return sync_func_worker\n\n\ndef wrap_to_placement_spec_tree(physical_mesh, avals, sharding_specs, pytree):\n    \"\"\"Wrap avals and sharding specs to a pytree of placement specs.\"\"\"\n    placement_specs = [\n        PlacementSpec(aval, (physical_mesh.mesh_id,), (sharding_spec,))\n        for aval, sharding_spec in zip(avals, sharding_specs)\n    ]\n    return tree_unflatten(pytree, placement_specs)\n\n\nclass NormalMeshDriverExecutable(MeshDriverExecutable):\n    \"\"\"The driver part of a normal mesh executable.\"\"\"\n\n    def __init__(self,\n                 physical_mesh: \"PhysicalDeviceMesh\",\n                 hlo: WrappedHlo,\n                 stage_plan: StagePlan,\n                 avals: Sequence[ShapedArray],\n                 out_avals: Sequence[ShapedArray],\n                 donated_invars: Sequence[bool],\n                 static_argnums: Optional[Sequence[int]] = None,\n                 in_tree: Optional[PyTreeDef] = None,\n                 out_tree: Optional[PyTreeDef] = None,\n                 flop_count: Optional[int] = None):\n        self.physical_mesh = physical_mesh\n        self.hlo = hlo\n        self.avals = avals\n        self.out_avals = out_avals\n        self.donated_invars = donated_invars\n        self.static_argnums = static_argnums\n        self.in_tree = in_tree\n        self.out_tree = out_tree\n        self.flop_count = flop_count\n        self.stage_plan = stage_plan\n        self.auto_sharding_option = stage_plan.auto_sharding_option\n        self.auto_sharding_objective = stage_plan.auto_sharding_objective\n\n        # Send the executable to workers\n  
      self.fully_optimized_hlo_text = None\n        self.exec_uuid = next_mesh_executable_uuid()\n        self._set_executable(physical_mesh, hlo, stage_plan)\n\n        if hlo.is_sharding_annotated():\n            hlo = run_spmd_partitioner_pass(hlo, physical_mesh.num_devices)\n        # Read sharding specs\n        self.input_sharding_specs, self.output_sharding_specs = (\n            get_input_output_sharding_specs(hlo.get_module(), avals, out_avals,\n                                            physical_mesh.num_devices,\n                                            stage_plan.logical_mesh_shape))\n\n        # Cache results for input and output sharding\n        self.input_indices = [\n            pxla.spec_to_indices(aval.shape, spec)\n            for aval, spec in zip(avals, self.input_sharding_specs)\n        ]\n        self.outs_handler = physical_mesh.get_outputs_handler(\n            out_avals, self.output_sharding_specs)\n\n        # Set up timers\n        self.exec_timer_name = get_execution_timer_name(self.exec_uuid)\n        self.shard_args_timer_name = self.exec_timer_name + \"-shard-args\"\n        self.sync_func = get_sync_func_driver(physical_mesh)\n\n    def _set_executable(self, physical_mesh, hlo, stage_plan):\n        \"\"\"Put the executable on workers.\"\"\"\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            for w in physical_mesh.workers:\n                w.put_executable.remote(self.exec_uuid,\n                                        NormalMeshWorkerExecutable, hlo,\n                                        stage_plan, self.donated_invars)\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n\n            if physical_mesh.devices[0] is None:\n                # A fake physical mesh for generating HLO module only\n                self.compiled = run_backend_compilation(\n                    physical_mesh.backend,\n                    hlo,\n                    stage_plan,\n        
            physical_mesh.num_devices,\n                    bypass_device_assignment_check=True)\n            else:\n                self.compiled = run_backend_compilation(\n                    physical_mesh.backend, hlo, stage_plan,\n                    physical_mesh.num_devices)\n            self.fully_optimized_hlo_text = self.compiled.hlo_modules(\n            )[0].to_string()\n\n    def launch_on_driver(self, *args, **kwargs):\n        \"\"\"Launch the executable on the driver.\"\"\"\n        physical_mesh = self.physical_mesh\n        num_hosts = physical_mesh.num_hosts\n        num_outs = len(self.out_avals)\n\n        timers(self.shard_args_timer_name).start()\n        input_bufs = physical_mesh.shard_args_to_bufs(self.input_indices,\n                                                      self.donated_invars,\n                                                      (False,) * len(args),\n                                                      None, args)\n        timers(self.shard_args_timer_name).stop()\n\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            input_uuids = np.array([ref.uuid for ref in input_bufs])\n            output_uuids = next_array_uuids(num_outs)\n\n            if \"sync_before\" not in kwargs:\n                kwargs[\"sync_before\"] = kwargs[\"sync_after\"] = (\n                    global_config.shard_parallel_sync_for_timer)\n\n            # Execute the SPMD binary\n            for i in range(num_hosts):\n                physical_mesh.workers[i].run_executable.remote(\n                    self.exec_uuid, input_uuids, output_uuids, **kwargs)\n\n            # Gather output buffers\n            output_bufs = np.array(\n                [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids])\n\n            # Mark donated input buffers as already deleted on workers.\n            for ary_ref, is_donated in zip(input_bufs, self.donated_invars):\n                if is_donated:\n                    
ary_ref.set_deleted_on_workers()\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            sync_func = (self.sync_func if\n                         global_config.shard_parallel_sync_for_timer else None)\n\n            timers(self.exec_timer_name).start(sync_func)\n            output_bufs = self.compiled.execute_sharded_on_local_devices(\n                input_bufs)\n            timers(self.exec_timer_name).stop(sync_func)\n\n        return self.outs_handler(output_bufs)\n\n    def get_input_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for input arguments.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the input pytree.\n        \"\"\"\n        return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,\n                                           self.input_sharding_specs,\n                                           self.in_tree)\n\n    def get_output_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for outputs.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the output pytree.\n        \"\"\"\n        return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,\n                                           self.output_sharding_specs,\n                                           self.out_tree)\n\n    def get_parallel_plan(self):\n        \"\"\"Get the overall parallel plan.\"\"\"\n        cluster_info = ClusterInfo(self.physical_mesh.num_hosts,\n                                   self.physical_mesh.num_devices_per_host)\n        return ParallelPlan(cluster_info, None, self.auto_sharding_option, None,\n                            tree_leaves(self.get_input_placement_specs()))\n\n    def preshard_dynamic_args(self, *args):\n        \"\"\"Pre-shard the input arguments.\"\"\"\n        input_bufs = self.physical_mesh.shard_args_to_bufs(\n            
self.input_indices, self.donated_invars, (False,) * len(args), None,\n            args)\n        outs_handler = self.physical_mesh.get_outputs_handler(\n            self.avals, self.input_sharding_specs)\n        return outs_handler(input_bufs)\n\n    def __call__(self, *args):\n        \"\"\"Fast call without signature matching.\"\"\"\n        if self.static_argnums:\n            dyn_args = [\n                args[i]\n                for i in range(len(args))\n                if i not in self.static_argnums\n            ]\n        else:\n            dyn_args = args\n        args_flat, _ = tree_flatten(dyn_args)\n        out = self.launch_on_driver(*args_flat)\n        return tree_unflatten(self.out_tree, out)\n\n    def profile_with_dummy_inputs(self, **kwargs):\n        \"\"\"Profile the execution time costs with dummy inputs.\"\"\"\n        if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):\n            tasks = []\n            for worker in self.physical_mesh.workers:\n                tasks.append(\n                    worker.profile_executable_with_dummy_inputs.remote(\n                        self.exec_uuid, **kwargs))\n            costs = ray.get(tasks)\n            for cost_vec in costs:\n                if np.inf in cost_vec:\n                    return [np.inf] * len(cost_vec)\n            costs = np.mean(costs, axis=0)\n        else:\n            assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh)\n            costs = profile_xla_executable(self.compiled,\n                                           self.physical_mesh.backend,\n                                           self.physical_mesh.devices)\n        return costs\n\n    def get_total_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes.\"\"\"\n        if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):\n            return (ray.get(self.physical_mesh.workers[0].\n                            get_exec_total_allocation_size.remote(\n  
                              self.exec_uuid)))\n        else:\n            assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh)\n            return self.compiled.total_allocation_size()\n\n    def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):\n        \"\"\"Return the HLO IR in the text format.\"\"\"\n        if status == HloStatus.FULLY_OPTIMIZED:\n            if self.fully_optimized_hlo_text is not None:\n                return self.fully_optimized_hlo_text\n            assert isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh)\n            self.fully_optimized_hlo_text = ray.get(\n                self.physical_mesh.workers[0].get_exec_hlo_text.remote(\n                    self.exec_uuid))\n            return self.fully_optimized_hlo_text\n        else:\n            raise ValueError(f\"Invalid status: {status}\")\n\n    def dump_debug_info(self, folder: str):\n        \"\"\"\n        Dump intermediate representations and other informations for debugging.\n        \"\"\"\n        os.makedirs(folder, exist_ok=True)\n        name = self.hlo.name\n        name = name[:name.index(\"shard_parallel\") - 1]\n        prefix = os.path.join(folder, name)\n        with open(f\"{prefix}.hlo\", \"w\") as f:\n            f.write(self.get_hlo_text())\n        with open(f\"{prefix}.mem_usage.txt\", \"w\") as f:\n            f.write(f\"total_allocation_size: \"\n                    f\"{self.get_total_allocation_size()/(1024**3):.3f} GB\\n\")\n        with open(f\"{prefix}_input_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_input_placement_specs()))\n        with open(f\"{prefix}_output_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_output_placement_specs()))\n\n\ndef delete_donated_buffers(buffer_dict, uuids, donated_invars):\n    \"\"\"Delete the donated buffers from the local buffer dictionary.\"\"\"\n    for uuid, is_donated in zip(uuids, donated_invars):\n        if is_donated:\n           
 del buffer_dict[uuid]\n\n\nclass NormalMeshWorkerExecutable(MeshWorkerExecutable):\n    \"\"\"The worker part of a normal mesh executable.\"\"\"\n\n    def __init__(self, worker: \"MeshHostWorker\", uuid: int, hlo: WrappedHlo,\n                 stage_plan: StagePlan, donated_invars: Sequence[bool]):\n        num_devices = np.prod(stage_plan.logical_mesh_shape)\n        assert num_devices == len(worker.backend.devices())\n\n        self.compiled = run_backend_compilation(worker.backend, hlo, stage_plan,\n                                                num_devices)\n        self.donated_invars = donated_invars\n        self.worker = worker\n\n        # Set up timers\n        self.timer_name = get_execution_timer_name(uuid)\n        self.sync_func = get_sync_func_worker(worker)\n\n    def execute_on_worker(self, input_uuids: Sequence[int],\n                          output_uuids: Sequence[int], sync_before: bool,\n                          sync_after: bool):\n        \"\"\"Run the executable on the worker.\"\"\"\n        buffer_dict = self.worker.buffers\n\n        # Get input buffers from uuids\n        # Sequence[Sequence[DeviceBuffer]], shape(num_args, num_devices)\n        input_bufs = [buffer_dict[x] for x in input_uuids]\n\n        if global_config.enable_overlapping:\n            xe.computation_wait_events(input_uuids, self.worker.backend)\n            xe.set_idx_to_uuid(output_uuids)\n        # Execute the executable\n        timers(self.timer_name).start(self.sync_func if sync_before else None)\n        try:\n            output_bufs = self.compiled.execute_sharded_on_local_devices(\n                input_bufs)\n        except RuntimeError:\n            ray.actor.exit_actor()\n        timers(self.timer_name).stop(self.sync_func if sync_after else None)\n\n        # Store output buffers\n        for i in range(len(output_uuids)):\n            buffer_dict[output_uuids[i]] = output_bufs[i]\n\n        # Delete donated input buffers\n        
delete_donated_buffers(buffer_dict, input_uuids, self.donated_invars)\n\n    def profile_with_dummy_inputs(self, backend, local_devices):\n        \"\"\"Profile the time cost of this executable with dummy inputs.\"\"\"\n        return profile_xla_executable(self.compiled, backend, local_devices)\n\n    def get_hlo_text(self):\n        return self.compiled.hlo_modules()[0].to_string()\n\n    def get_total_allocation_size(self):\n        return self.compiled.total_allocation_size()\n\n    def __del__(self):\n        self.compiled.delete()\n\n\ndef get_grad_sync_channel_ids(hlo_module: xe.HloModule) -> str:\n    \"\"\"Return the channel ids of all-reduces that are used for gradient\n    synchronization.\n\n    The return value is a string containing all channel ids separated by\n    periods. (e.g., \".0.12.\" means channel id 0 and 12)\n    \"\"\"\n    return xe.get_grad_sync_channel_ids(hlo_module)\n\n\nclass GradAccMeshDriverExecutable(MeshDriverExecutable):\n    \"\"\"The driver part of a gradient accumulation mesh executable.\"\"\"\n\n    def __init__(self,\n                 physical_mesh: \"PhysicalDeviceMesh\",\n                 accumulate_grad: WrappedHlo,\n                 apply_grad: WrappedHlo,\n                 stage_plan: StagePlan,\n                 avals: Sequence[ShapedArray],\n                 out_avals: Sequence[ShapedArray],\n                 grad_avals: Sequence[ShapedArray],\n                 donated_invars: Sequence[bool],\n                 batch_invars: Sequence[bool],\n                 accumulate_grad_invar_indices: Sequence[int],\n                 apply_grad_invar_indices: Sequence[int],\n                 num_micro_batches: int,\n                 in_tree: Optional[PyTreeDef] = None,\n                 out_tree: Optional[PyTreeDef] = None,\n                 flop_count: Optional[int] = None):\n        self.physical_mesh = physical_mesh\n        self.accumulate_grad_hlo = accumulate_grad\n        self.apply_grad_hlo = apply_grad\n        self.avals 
= avals\n        self.out_avals = out_avals\n        self.grad_avals = grad_avals\n        self.donated_invars = donated_invars\n        self.batch_invars = batch_invars\n        self.accumulate_grad_invar_indices = accumulate_grad_invar_indices\n        self.apply_grad_invar_indices = apply_grad_invar_indices\n        self.num_micro_batches = num_micro_batches\n        self.in_tree = in_tree\n        self.out_tree = out_tree\n        self.flop_count = flop_count\n        self.stage_plan = stage_plan\n        self.auto_sharding_option = stage_plan.auto_sharding_option\n        self.auto_sharding_objective = stage_plan.auto_sharding_objective\n\n        # Read sharding specs\n        logical_mesh_shape = stage_plan.logical_mesh_shape\n        accumulate_grad_in_avals = [\n            avals[i] for i in accumulate_grad_invar_indices\n        ] + grad_avals\n        apply_grad_in_avals = \\\n            [avals[i] for i in apply_grad_invar_indices] + grad_avals\n        accumulate_grad_input_sharding_specs, grad_sharding_specs = (\n            get_input_output_sharding_specs(accumulate_grad.get_module(),\n                                            accumulate_grad_in_avals,\n                                            grad_avals,\n                                            physical_mesh.num_devices,\n                                            logical_mesh_shape))\n        apply_grad_input_sharding_specs, output_sharding_specs = (\n            get_input_output_sharding_specs(apply_grad.get_module(),\n                                            apply_grad_in_avals, out_avals,\n                                            physical_mesh.num_devices,\n                                            logical_mesh_shape))\n        self.output_sharding_specs = output_sharding_specs\n        num_grads = len(grad_avals)\n        assert accumulate_grad_input_sharding_specs[\n            -num_grads:] == grad_sharding_specs\n\n        global_arg_sharding_specs = [None] * len(avals)\n    
    for i, idx in enumerate(accumulate_grad_invar_indices):\n            global_arg_sharding_specs[\n                idx] = accumulate_grad_input_sharding_specs[i]\n        for i, idx in enumerate(apply_grad_invar_indices):\n            if global_arg_sharding_specs[idx] is None:\n                global_arg_sharding_specs[\n                    idx] = apply_grad_input_sharding_specs[i]\n            else:\n                assert global_arg_sharding_specs[\n                    idx] == apply_grad_input_sharding_specs[i]\n        ## Fill in \"Replicated\" for remaining undefined args\n        for i, spec in enumerate(global_arg_sharding_specs):\n            if spec is None:\n                global_arg_sharding_specs[i] = (make_replicated_spec(\n                    avals[i], logical_mesh_shape))\n\n        # Cache results for input and output sharding\n        global_batch_arg_indices = [\n            i for i in range(len(avals)) if batch_invars[i]\n        ]\n        global_arg_shard_indices = []\n        for i, aval in enumerate(avals):\n            if batch_invars[i] and isinstance(self.physical_mesh,\n                                              DistributedPhysicalDeviceMesh):\n                # The handling of micro batches is different for\n                # distributed device mesh.\n                batch_dim = 0\n                new_shape = (num_micro_batches *\n                             aval.shape[0],) + aval.shape[1:]\n                new_spec = get_microbatch_sharding_spec(\n                    global_arg_sharding_specs[i], batch_dim, num_micro_batches)\n                global_arg_shard_indices.append(\n                    pxla.spec_to_indices(new_shape, new_spec))\n            else:\n                global_arg_shard_indices.append(\n                    pxla.spec_to_indices(aval.shape,\n                                         global_arg_sharding_specs[i]))\n\n        accumulate_grad_batch_arg_indices = [\n            i for i, j in 
enumerate(accumulate_grad_invar_indices)\n            if batch_invars[j]\n        ]\n        grad_shard_shapes = [\n            get_shard_shape(aval, spec)\n            for aval, spec in zip(grad_avals, grad_sharding_specs)\n        ]\n        grad_shard_dtypes = [aval.dtype for aval in grad_avals]\n        self.global_arg_sharding_specs = global_arg_sharding_specs\n        self.global_batch_arg_indices = global_batch_arg_indices\n        self.global_arg_shard_indices = global_arg_shard_indices\n        self.outs_handler = physical_mesh.get_outputs_handler(\n            out_avals, output_sharding_specs)\n\n        # Send the executable to workers\n        self.exec_uuid = next_mesh_executable_uuid()\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            for w in physical_mesh.workers:\n                w.put_executable.remote(\n                    self.exec_uuid, GradAccMeshWorkerExecutable,\n                    accumulate_grad, apply_grad, accumulate_grad_invar_indices,\n                    apply_grad_invar_indices, accumulate_grad_batch_arg_indices,\n                    grad_shard_shapes, grad_shard_dtypes, stage_plan,\n                    donated_invars, batch_invars, num_grads, num_micro_batches)\n            # The following members will be fetched from the workers later\n            self.fully_optimized_hlo_text = None\n            self.grad_sync_channel_ids = None\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            backend = physical_mesh.backend\n\n            self.accumulate_grad = run_backend_compilation(\n                backend, accumulate_grad, stage_plan, physical_mesh.num_devices)\n            self.apply_grad = run_backend_compilation(backend, apply_grad,\n                                                      stage_plan,\n                                                      physical_mesh.num_devices)\n            self.allocate_zero_buffers = compile_allocate_zero_buffers(\n     
           backend, physical_mesh.num_devices, grad_shard_shapes,\n                grad_shard_dtypes)\n            self.accumulate_grad_batch_arg_indices = (\n                accumulate_grad_batch_arg_indices)\n\n            self.fully_optimized_hlo_text = (\n                self.accumulate_grad.hlo_modules()[0].to_string() +\n                self.apply_grad.hlo_modules()[0].to_string())\n            self.grad_sync_channel_ids = get_grad_sync_channel_ids(\n                self.accumulate_grad.hlo_modules()[0])\n            self.skip_allreduce_env_name = (\n                self.accumulate_grad.hlo_modules()[0].name +\n                \"XLA_SKIP_NCCL_COLLECTIVE_IDS\")\n\n        # Set up timers\n        self.exec_timer_name = get_execution_timer_name(self.exec_uuid)\n        self.shard_args_timer_name = self.exec_timer_name + \"-shard-args\"\n        self.sync_func = get_sync_func_driver(physical_mesh)\n\n    def launch_on_driver(self, *args):\n        \"\"\"Launch the executable on the driver.\"\"\"\n        num_micro_batches = self.num_micro_batches\n        grad_avals = self.grad_avals\n        num_grads = len(grad_avals)\n        physical_mesh = self.physical_mesh\n        num_hosts = physical_mesh.num_hosts\n        num_outs = len(self.out_avals)\n\n        timers(self.shard_args_timer_name).start()\n        input_bufs = physical_mesh.shard_args_to_bufs(\n            self.global_arg_shard_indices, self.donated_invars,\n            self.batch_invars, num_micro_batches, args)\n\n        first_batch_bufs = input_bufs\n        next_batches_bufs = []\n        for i in self.global_batch_arg_indices:\n            micro_batches = input_bufs[i]\n            first_batch_bufs[i] = micro_batches[0]\n            next_batches_bufs.extend(micro_batches[1:])\n        timers(self.shard_args_timer_name).stop()\n\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            first_batch_uuids = np.array([ref.uuid for ref in first_batch_bufs])\n\n            if 
next_batches_bufs:\n                next_batches_uuids = np.array(\n                    [ref.uuid for ref in next_batches_bufs])\n            else:\n                next_batches_uuids = (None,) * num_hosts\n\n            output_uuids = next_array_uuids(num_outs)\n\n            # Execute SPMD binary\n            for i in range(num_hosts):\n                physical_mesh.workers[i].run_executable.remote(\n                    self.exec_uuid, first_batch_uuids, next_batches_uuids,\n                    output_uuids, global_config.shard_parallel_sync_for_timer,\n                    global_config.shard_parallel_sync_for_timer)\n\n            # Gather output buffers\n            output_bufs = np.array(\n                [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids])\n\n            # Mark donated input buffers as already deleted on workers.\n            for ary_ref, is_donated in zip(first_batch_bufs,\n                                           self.donated_invars):\n                if is_donated:\n                    ary_ref.set_deleted_on_workers()\n\n            # Mark micro batch buffers as already deleted on workers.\n            for ary_ref in next_batches_bufs:\n                ary_ref.set_deleted_on_workers()\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            sync_func = (self.sync_func if\n                         global_config.shard_parallel_sync_for_timer else None)\n\n            # Prepare gradient buffers\n            timers(self.exec_timer_name).start(sync_func)\n            grad_bufs = (\n                self.allocate_zero_buffers.execute_sharded_on_local_devices([]))\n\n            # Call accumulate_grad multiple times\n            tmp_input_bufs = ([\n                first_batch_bufs[i] for i in self.accumulate_grad_invar_indices\n            ] + grad_bufs)\n            os.environ[\n                self.skip_allreduce_env_name] = self.grad_sync_channel_ids\n            for i in 
range(num_micro_batches):\n                if i != 0:\n                    # Feed in the data of the next batch\n                    tmp_input_bufs[-num_grads:] = grad_bufs\n                    for j, idx in enumerate(\n                            self.accumulate_grad_batch_arg_indices):\n                        tmp_input_bufs[idx] = next_batches_bufs[\n                            j * (num_micro_batches - 1) + (i - 1)]\n                if i == num_micro_batches - 1:\n                    os.environ[self.skip_allreduce_env_name] = \"\"\n                grad_bufs = (self.accumulate_grad.\n                             execute_sharded_on_local_devices(tmp_input_bufs))\n\n            # Call apply_grad\n            tmp_input_bufs = (\n                [first_batch_bufs[i] for i in self.apply_grad_invar_indices] +\n                grad_bufs)\n            output_bufs = self.apply_grad.execute_sharded_on_local_devices(\n                tmp_input_bufs)\n            timers(self.exec_timer_name).stop(sync_func)\n\n        # Wrap output buffers as ShardedArray\n        return self.outs_handler(output_bufs)\n\n    def get_input_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for input arguments.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the input pytree.\n        \"\"\"\n        return wrap_to_placement_spec_tree(self.physical_mesh, self.avals,\n                                           self.global_arg_sharding_specs,\n                                           self.in_tree)\n\n    def get_output_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for outputs.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the output pytree.\n        \"\"\"\n        return wrap_to_placement_spec_tree(self.physical_mesh, self.out_avals,\n                                           self.output_sharding_specs,\n                              
             self.out_tree)\n\n    def get_parallel_plan(self):\n        \"\"\"Get the overall parallel plan.\"\"\"\n        cluster_info = ClusterInfo(self.physical_mesh.num_hosts,\n                                   self.physical_mesh.num_devices_per_host)\n        return ParallelPlan(cluster_info, self.num_micro_batches,\n                            self.auto_sharding_option, None,\n                            tree_leaves(self.get_input_placement_specs()))\n\n    def get_total_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes.\"\"\"\n        if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):\n            return ray.get(self.physical_mesh.workers[0].\n                           get_exec_total_allocation_size.remote(\n                               self.exec_uuid))\n        else:\n            assert isinstance(self.physical_mesh, LocalPhysicalDeviceMesh)\n            return max(self.accumulate_grad.total_allocation_size(),\n                       self.apply_grad.total_allocation_size())\n\n    def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):\n        \"\"\"Return the HLO IR in the text format.\"\"\"\n        if status == HloStatus.FULLY_OPTIMIZED:\n            if self.fully_optimized_hlo_text is not None:\n                return self.fully_optimized_hlo_text\n            assert isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh)\n            self.fully_optimized_hlo_text = ray.get(\n                self.physical_mesh.workers[0].get_exec_hlo_text.remote(\n                    self.exec_uuid))\n            self.grad_sync_channel_ids = ray.get(\n                self.physical_mesh.workers[0].get_exec_grad_sync_channel_ids.\n                remote(self.exec_uuid))\n            return self.fully_optimized_hlo_text\n        else:\n            raise ValueError(f\"Invalid status: {status}\")\n\n    def dump_debug_info(self, folder: str):\n        \"\"\"\n        Dump intermediate 
representations and other information for debugging.\n        \"\"\"\n        os.makedirs(folder, exist_ok=True)\n        name = self.accumulate_grad_hlo.name\n        name = name[:name.index(\"shard_parallel\") - 1]\n        prefix = os.path.join(folder, name)\n        with open(f\"{prefix}.hlo\", \"w\") as f:\n            f.write(self.get_hlo_text())\n        with open(f\"{prefix}.grad_sync_channel_ids.txt\", \"w\") as f:\n            f.write(str(self.grad_sync_channel_ids) + \"\\n\")\n        with open(f\"{prefix}.mem_usage.txt\", \"w\") as f:\n            f.write(f\"total_allocation_size: \"\n                    f\"{self.get_total_allocation_size()/(1024**3):.3f} GB\\n\")\n        with open(f\"{prefix}_input_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_input_placement_specs()))\n        with open(f\"{prefix}_output_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_output_placement_specs()))\n\n\nclass GradAccMeshWorkerExecutable(MeshWorkerExecutable):\n    \"\"\"The worker part of a gradient accumulation mesh executable.\"\"\"\n\n    def __init__(self, worker: \"MeshHostWorker\", uuid: int,\n                 accumulate_grad: WrappedHlo, apply_grad: WrappedHlo,\n                 accumulate_grad_invar_indices: Sequence[int],\n                 apply_grad_invar_indices: Sequence[int],\n                 accumulate_grad_batch_arg_indices: Sequence[int],\n                 grad_shard_shapes: Sequence[Sequence[int]],\n                 grad_shard_dtypes: Sequence[jnp.dtype], stage_plan: StagePlan,\n                 donated_invars: Sequence[bool], batch_invars: Sequence[bool],\n                 num_grads: int, num_micro_batches: int):\n        num_devices = np.prod(stage_plan.logical_mesh_shape)\n        assert num_devices == len(worker.backend.devices())\n\n        self.accumulate_grad = run_backend_compilation(worker.backend,\n                                                       accumulate_grad,\n                           
                            stage_plan, num_devices)\n        self.apply_grad = run_backend_compilation(worker.backend, apply_grad,\n                                                  stage_plan, num_devices)\n        self.allocate_zero_buffers = compile_allocate_zero_buffers(\n            worker.backend, num_devices, grad_shard_shapes, grad_shard_dtypes)\n        self.accumulate_grad_invar_indices = accumulate_grad_invar_indices\n        self.apply_grad_invar_indices = apply_grad_invar_indices\n        self.accumulate_grad_batch_arg_indices = (\n            accumulate_grad_batch_arg_indices)\n        self.donated_invars = donated_invars\n        self.batch_invars = batch_invars\n        self.num_grads = num_grads\n        self.num_micro_batches = num_micro_batches\n        self.buffer_dict = worker.buffers\n        self.grad_sync_channel_ids = get_grad_sync_channel_ids(\n            self.accumulate_grad.hlo_modules()[0])\n        self.skip_allreduce_env_name = (\n            self.accumulate_grad.hlo_modules()[0].name +\n            \"XLA_SKIP_NCCL_COLLECTIVE_IDS\")\n\n        # Set up timers\n        self.timer_name = get_execution_timer_name(uuid)\n        self.sync_func = get_sync_func_worker(worker)\n\n    def execute_on_worker(self, first_batch_uuids: Sequence[int],\n                          next_batches_uuids: Sequence[int],\n                          output_uuids: Sequence[int], sync_before: bool,\n                          sync_after: bool):\n        \"\"\"Run the executable on the worker.\"\"\"\n        buffer_dict = self.buffer_dict\n        num_micro_batches = self.num_micro_batches\n\n        tmp_input_bufs = [\n            buffer_dict[first_batch_uuids[i]]\n            for i in self.accumulate_grad_invar_indices\n        ]\n\n        # Prepare gradient buffers\n        timers(self.timer_name).start(self.sync_func if sync_before else None)\n        grad_bufs = self.allocate_zero_buffers.execute_sharded_on_local_devices(\n            [])\n\n        # 
Call accumulate_grad multiple times\n        tmp_input_bufs = tmp_input_bufs + grad_bufs\n        os.environ[self.skip_allreduce_env_name] = self.grad_sync_channel_ids\n        for i in range(num_micro_batches):\n            if i != 0:\n                # Feed in the data of the next batch\n                tmp_input_bufs[-self.num_grads:] = grad_bufs\n                for j, idx in enumerate(self.accumulate_grad_batch_arg_indices):\n                    tmp_input_bufs[idx] = buffer_dict[next_batches_uuids[\n                        j * (num_micro_batches - 1) + (i - 1)]]\n            if i == num_micro_batches - 1:\n                os.environ[self.skip_allreduce_env_name] = \"\"\n            grad_bufs = self.accumulate_grad.execute_sharded_on_local_devices(\n                tmp_input_bufs)\n\n        # Call apply_grad\n        tmp_input_bufs = [\n            buffer_dict[first_batch_uuids[i]]\n            for i in self.apply_grad_invar_indices\n        ] + grad_bufs\n        output_bufs = self.apply_grad.execute_sharded_on_local_devices(\n            tmp_input_bufs)\n        timers(self.timer_name).stop(self.sync_func if sync_after else None)\n\n        # Store output buffers\n        for i in range(len(output_uuids)):\n            buffer_dict[output_uuids[i]] = output_bufs[i]\n\n        # Delete donated input buffers\n        delete_donated_buffers(buffer_dict, first_batch_uuids,\n                               self.donated_invars)\n\n        # Delete micro batch buffers\n        if next_batches_uuids is not None and \\\n                next_batches_uuids[0] is not None:\n            for i in range(len(next_batches_uuids)):\n                del buffer_dict[next_batches_uuids[i]]\n\n    def get_hlo_text(self):\n        return (self.accumulate_grad.hlo_modules()[0].to_string() +\n                self.apply_grad.hlo_modules()[0].to_string())\n\n    def get_total_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes.\"\"\"\n        return 
max(self.accumulate_grad.total_allocation_size(),\n                   self.apply_grad.total_allocation_size())\n\n    def __del__(self):\n        self.accumulate_grad.delete()\n        self.apply_grad.delete()\n        self.allocate_zero_buffers.delete()\n\n\nclass PartialGradAccMeshDriverExecutable(NormalMeshDriverExecutable):\n    \"\"\"\n    The driver part of a mesh executable that can optionally skip\n    the gradient synchronization step.\n\n    This executable is used for computation stages in pipeline,\n    such as forward, backward and apply_grad\n    \"\"\"\n\n    def __init__(self, physical_mesh: \"PhysicalDeviceMesh\", hlo: WrappedHlo,\n                 stage_plan: StagePlan, avals: Sequence[ShapedArray],\n                 out_avals: Sequence[ShapedArray],\n                 donated_invars: Sequence[bool]):\n        super().__init__(physical_mesh, hlo, stage_plan, avals, out_avals,\n                         donated_invars)\n\n    def _set_executable(self, physical_mesh, hlo, stage_plan):\n        \"\"\"Put the executable on workers.\"\"\"\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            for w in physical_mesh.workers:\n                w.put_executable.remote(self.exec_uuid,\n                                        PartialGradAccMeshWorkerExecutable, hlo,\n                                        stage_plan, self.donated_invars)\n            self.hlo_text = None  # will be fetched from the workers later\n            self.grad_sync_channel_ids = None\n            self.skip_allreduce_env_name = None\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            self.compiled = run_backend_compilation(physical_mesh.backend, hlo,\n                                                    stage_plan,\n                                                    physical_mesh.num_devices)\n            self.hlo_text = self.compiled.hlo_modules()[0].to_string()\n            self.grad_sync_channel_ids = 
get_grad_sync_channel_ids(\n                self.compiled.hlo_modules()[0])\n            self.skip_allreduce_env_name = (\n                self.compiled.hlo_modules()[0].name +\n                \"XLA_SKIP_NCCL_COLLECTIVE_IDS\")\n\n    def launch_on_driver(self, *args, **kwargs):\n        \"\"\"Launch the executable on the driver.\"\"\"\n        assert \"skip_grad_sync\" in kwargs, (\n            'Partial grad acc mesh executable missing kwargs \"skip_grad_sync\"')\n        skip_grad_sync = kwargs[\"skip_grad_sync\"]\n        os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids\n                                                    if skip_grad_sync else \"\")\n        return super().launch_on_driver(*args, **kwargs)\n\n\nclass PartialGradAccMeshWorkerExecutable(NormalMeshWorkerExecutable):\n    \"\"\"\n    The worker part of a mesh executable that can optionally skip\n    the gradient synchronization step.\n\n    This executable is used for computation stages in pipeline,\n    such as forward, backward and apply_grad\n    \"\"\"\n\n    def __init__(self, worker: \"MeshHostWorker\", uuid: int, hlo: WrappedHlo,\n                 stage_plan: StagePlan, donated_invars: Sequence[bool]):\n        super().__init__(worker, uuid, hlo, stage_plan, donated_invars)\n        self.grad_sync_channel_ids = get_grad_sync_channel_ids(\n            self.compiled.hlo_modules()[0])\n        self.skip_allreduce_env_name = (self.compiled.hlo_modules()[0].name +\n                                        \"XLA_SKIP_NCCL_COLLECTIVE_IDS\")\n\n    # pylint: disable=arguments-differ\n    def execute_on_worker(self, input_uuids: Sequence[int],\n                          output_uuids: Sequence[int], sync_before: bool,\n                          sync_after: bool, skip_grad_sync: bool):\n        \"\"\"Run the executable on the worker.\"\"\"\n        os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids\n                                                    if 
skip_grad_sync else \"\")\n        return super().execute_on_worker(input_uuids, output_uuids, sync_before,\n                                         sync_after)\n\n    def profile_with_dummy_inputs(self, backend, local_devices, skip_grad_sync):\n        \"\"\"Profile the time cost of this executable with dummy inputs.\"\"\"\n        os.environ[self.skip_allreduce_env_name] = (self.grad_sync_channel_ids\n                                                    if skip_grad_sync else \"\")\n        return profile_xla_executable(self.compiled, backend, local_devices)\n\n\nclass AllocZeroBufferDriverExecutable(MeshDriverExecutable):\n    \"\"\"The driver part of a buffer-allocation executable.\"\"\"\n\n    def __init__(self, physical_mesh: \"PhysicalDeviceMesh\",\n                 grad_vars: Sequence[ShapedArray],\n                 grad_sharding_specs: Sequence[pxla.ShardingSpec]):\n        self.physical_mesh = physical_mesh\n        grad_avals = [var.aval for var in grad_vars]\n        grad_shard_shapes = [\n            get_shard_shape(aval, spec)\n            for aval, spec in zip(grad_avals, grad_sharding_specs)\n        ]\n        grad_shard_dtypes = [aval.dtype for aval in grad_avals]\n        self.out_avals = grad_avals\n        self.outs_handler = physical_mesh.get_outputs_handler(\n            grad_avals, grad_sharding_specs)\n\n        self.exec_uuid = next_mesh_executable_uuid()\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            for w in physical_mesh.workers:\n                w.put_executable.remote(self.exec_uuid,\n                                        AllocZeroBufferWorkerExecutable,\n                                        grad_shard_shapes, grad_shard_dtypes)\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            self.allocate_zero_buffers = compile_allocate_zero_buffers(\n                physical_mesh.backend, physical_mesh.devices, grad_shard_shapes,\n                
grad_shard_dtypes)\n\n        self.exec_timer_name = get_execution_timer_name(self.exec_uuid)\n        self.sync_func = get_sync_func_driver(physical_mesh)\n\n    def launch_on_driver(self, *args):\n        \"\"\"Launch the executable on the driver.\"\"\"\n        assert len(args) == 0, (\n            f\"allocate zero buffers does not need args, got {len(args)}\")\n        physical_mesh = self.physical_mesh\n        num_hosts = physical_mesh.num_hosts\n        num_outs = len(self.out_avals)\n\n        if isinstance(physical_mesh, DistributedPhysicalDeviceMesh):\n            # Get output uuids\n            output_uuids = next_array_uuids(num_outs)\n\n            # Execute SPMD binary\n            for i in range(num_hosts):\n                physical_mesh.workers[i].run_executable.remote(\n                    self.exec_uuid, [], output_uuids)\n\n            # Gather outputs\n            output_bufs = np.array(\n                [RemoteArrayRef(physical_mesh, uuid) for uuid in output_uuids])\n        else:\n            assert isinstance(physical_mesh, LocalPhysicalDeviceMesh)\n            timers(self.exec_timer_name).start(self.sync_func)\n            output_bufs = (\n                self.allocate_zero_buffers.execute_sharded_on_local_devices([]))\n            timers(self.exec_timer_name).stop(self.sync_func)\n\n        return self.outs_handler(output_bufs)\n\n\nclass AllocZeroBufferWorkerExecutable(MeshWorkerExecutable):\n    \"\"\"The worker part of a buffer-allocation executable.\"\"\"\n\n    def __init__(self, worker: \"MeshHostWorker\", uuid: int,\n                 grad_shard_shapes: Sequence[Sequence[int]],\n                 grad_shard_dtypes: Sequence[jnp.dtype]):\n        num_devices = len(worker.backend.devices())\n        self.allocate_zero_buffers = compile_allocate_zero_buffers(\n            worker.backend, num_devices, grad_shard_shapes, grad_shard_dtypes)\n        self.worker = worker\n\n        self.timer_name = get_execution_timer_name(uuid)\n        
self.sync_func = get_sync_func_worker(worker)\n\n    def execute_on_worker(self, input_uuids: Sequence[int],\n                          output_uuids: Sequence[int], sync_before: bool,\n                          sync_after: bool):\n        \"\"\"Run the executable on the worker.\"\"\"\n        # pylint: disable=unused-argument\n        buffer_dict = self.worker.buffers\n\n        # Execute\n        if global_config.enable_overlapping:\n            xe.set_idx_to_uuid(output_uuids)\n        timers(self.timer_name).start(self.sync_func if sync_before else None)\n        output_bufs = (\n            self.allocate_zero_buffers.execute_sharded_on_local_devices([]))\n        timers(self.timer_name).stop(self.sync_func if sync_after else None)\n        for i in range(len(output_uuids)):\n            buffer_dict[output_uuids[i]] = output_bufs[i]\n\n    def __del__(self):\n        self.allocate_zero_buffers.delete()\n\n\nclass UtilMeshWorkerExecutable(MeshWorkerExecutable):\n    \"\"\"Worker executable that runs a manually generated function. 
It is lighter\n    than NormalMeshWorkerExecutable as it does not have a StagePlan.\n\n    Currently, it is used for concatenate(will be deprecated after we move it\n    to apply_grad) and allgather.\n    \"\"\"\n\n    def __init__(self, worker, uuid, hlo: WrappedHlo):\n        num_devices = len(worker.backend.devices())\n        compile_options = get_compile_options(\n            num_replicas=1,\n            num_partitions=num_devices,\n            device_assignment=np.arange(num_devices).reshape((1, -1)),\n            use_spmd_partitioning=False,\n            parameter_is_tupled_arguments=False,\n            build_random_seed=global_config.compile_random_seed)\n        xla_computation = hlo.get_computation()\n\n        with XlaPassContext({\n                \"done-event::enable\": global_config.enable_overlapping,\n        }):\n            self.exec = worker.backend.compile(xla_computation, compile_options)\n\n        self.worker = worker\n        self.timer_name = get_execution_timer_name(uuid)\n        self.sync_func = get_sync_func_worker(worker)\n\n    def execute_on_worker(self, input_uuids: Sequence[int],\n                          output_uuids: Sequence[int], sync_before: bool,\n                          sync_after: bool):\n        \"\"\"Run the executable on the worker.\"\"\"\n        buffer_dict = self.worker.buffers\n\n        # Get input\n        input_bufs = [buffer_dict[x] for x in input_uuids]\n\n        if global_config.enable_overlapping:\n            xe.computation_wait_events(input_uuids, self.worker.backend)\n            xe.set_idx_to_uuid(output_uuids)\n\n        # Execute\n        timers(self.timer_name).start(self.sync_func if sync_before else None)\n        output_bufs = self.exec.execute_sharded_on_local_devices(input_bufs)\n        timers(self.timer_name).stop(self.sync_func if sync_after else None)\n\n        for i in range(len(output_uuids)):\n            buffer_dict[output_uuids[i]] = output_bufs[i]\n\n    def __del__(self):\n        
self.exec.delete()\n\n\ndef get_index_select_mesh_executable(avals, sharding_specs, index, dim,\n                                     device_mesh, donate_avals):\n    if type(index) not in [ShapedArray, ShapeDtypeStruct]:\n        index = xla.canonicalize_dtype(index)\n    index_shape = xc.shape_from_pyval(index)\n    key = hash((\"index_select\", tuple(avals), tuple(sharding_specs),\n                tuple(donate_avals), dim, index_shape))\n    if key in device_mesh.operation_executables:\n        return device_mesh.operation_executables[key]\n    index_aval = ShapedArray(index.shape, index.dtype)\n    assert len(avals) == len(sharding_specs) == len(donate_avals)\n    hlo = get_index_select_computation(sharding_specs, dim, avals, index_shape)\n    hlo = run_spmd_partitioner_pass(hlo, device_mesh.num_devices)\n\n    as_option = AutoShardingOption()\n    strategy_config = StagePlan(global_config.compile_random_seed,\n                                device_mesh.shape, 1 << 60,\n                                as_option.all_reduce_threshold,\n                                AutoShardingOption(), None, -1)\n    out_tree = tree_flatten(avals)[1]\n    executable = NormalMeshDriverExecutable(device_mesh,\n                                            hlo,\n                                            strategy_config,\n                                            [*avals, index_aval],\n                                            avals, [*donate_avals, False],\n                                            out_tree=out_tree)\n    device_mesh.operation_executables[key] = executable\n    return executable\n"
  },
  {
    "path": "alpa/mesh_profiling.py",
    "content": "\"\"\"Profiling communication cost for device meshes.\"\"\"\nfrom collections import defaultdict\nimport math\nimport os\nimport pickle\nimport time\n\nimport numpy as np\nfrom jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe\nimport ray\n\nfrom alpa.util import (GB, print_used_time, XlaPassContext, to_str_round,\n                       run_with_timeout)\n\nops = xc.ops\n\n\nclass MeshProfilingResult:\n    \"\"\"Store the profiling result for a physical mesh.\"\"\"\n\n    def __init__(self):\n        # Cost dictionary for communication primitives.\n        # Dict[Tuple(group, dtype) -> List[Tuple(size, time)]]\n        # The elements in the list is sorted according to the size (ascending).\n        self.all_gather_cost_dict = defaultdict(list)\n        self.all_reduce_cost_dict = defaultdict(list)\n        self.all_to_all_cost_dict = defaultdict(list)\n        self.reduce_scatter_cost_dict = defaultdict(list)\n        self.available_memory_per_device = None\n\n        # Cost dictionary for computation primitives.\n        # Reuse the same data structure.\n        # Dict[Tuple(None, dtype)] -> List[Tuple(flop_count, time)]\n        self.dot_cost_dict = defaultdict(list)\n        self.conv_cost_dict = []\n\n        # Cost dictionary for specific operators\n        # Dict[op_info] -> double\n        self.op_cost_dict = []\n\n    def update(self, new_mesh_result):\n        raise NotImplementedError\n\n    def make_monotonic(self):\n        \"\"\"Make the bandwidth monotonically increase along with the\n        communication size.\"\"\"\n        for cost_dict in [\n                self.all_gather_cost_dict, self.all_reduce_cost_dict,\n                self.all_to_all_cost_dict, self.reduce_scatter_cost_dict,\n                self.dot_cost_dict\n        ]:\n            new_cost_dict = {}\n\n            for key, value in cost_dict.items():\n                sizes = np.array([x[0] for x in value])\n                times = 
np.array([x[1] for x in value])\n\n                # make bandwidth monotonically increasing\n                bandwidth = sizes / times\n                for i in range(1, len(bandwidth)):\n                    bandwidth[i] = max(bandwidth[i], bandwidth[i - 1])\n\n                new_times = np.empty_like(times)\n                for i in range(len(times)):\n                    if sizes[i] == 0 or bandwidth[i] == 0:\n                        new_times[i] = value[i][1]\n                    else:\n                        new_times[i] = sizes[i] / bandwidth[i]\n\n                new_value = [\n                    (value[i][0], new_times[i]) for i in range(len(value))\n                ]\n                new_cost_dict[key] = new_value\n\n            cost_dict.update(new_cost_dict)\n\n    def sort_cost_lists(self):\n        \"\"\"Sort the items in the list from smallest to largest. This is the\n        format required by the HLO cost model in c++.\"\"\"\n        for cost_dict in [\n                self.all_gather_cost_dict, self.all_reduce_cost_dict,\n                self.all_to_all_cost_dict, self.reduce_scatter_cost_dict,\n                self.dot_cost_dict\n        ]:\n            new_cost_dict = {}\n\n            for key, value in cost_dict.items():\n                sizes = [x[0] for x in value]\n                indices = np.argsort(sizes, kind=\"stable\")\n                new_cost_dict[key] = [value[i] for i in indices]\n\n            cost_dict.update(new_cost_dict)\n\n    def estimate_all_gather(self, group, size, dtype):\n        ret = (\n            self._estimate_internal(group, size, dtype,\n                                    self.all_gather_cost_dict) -\n            self._estimate_internal(group, 0, dtype, self.all_gather_cost_dict))\n        return ret\n\n    def estimate_all_reduce(self, group, size, dtype):\n        ret = (\n            self._estimate_internal(group, size, dtype,\n                                    self.all_reduce_cost_dict) -\n            
self._estimate_internal(group, 0, dtype, self.all_reduce_cost_dict))\n        return ret\n\n    @staticmethod\n    def _estimate_internal(group, size, dtype, cost_dict):\n        key = (group, dtype)\n        cost_list = cost_dict[key]\n        assert cost_list, f\"Cannot find records for {(group, dtype)}\"\n\n        if size > cost_list[-1][0]:\n            i = len(cost_list) - 2\n        elif size < cost_list[0][0]:\n            i = 0\n        else:\n            for i in range(len(cost_list) - 1):\n                if cost_list[i][0] <= size <= cost_list[i + 1][0]:\n                    break\n\n        left_size = cost_list[i][0]\n        left_cost = cost_list[i][1]\n        right_size = cost_list[i + 1][0]\n        right_cost = cost_list[i + 1][1]\n\n        return (size - left_size) / (right_size - left_size) * (\n            right_cost - left_cost) + left_cost\n\n    def __str__(self):\n        ret = \"=== dot_cost_dict ===\\n\"\n        for key, value in self.dot_cost_dict.items():\n            sizes = np.array([x[0] for x in value])\n            times = np.array([x[1] for x in value])\n            tflops = sizes / times / 1e12\n            ret += f\"Key: {key}\\nTFLOPS: {to_str_round(tflops, 2)}\\n\\n\"\n\n        ret += \"=== all_reduce_cost_dict ===\\n\"\n        for key, value in self.all_reduce_cost_dict.items():\n            num_devices = len(key[0][0])\n            sizes = np.array([x[0] for x in value])\n            times = np.array([x[1] for x in value])\n            comm_bytes = 2 * (num_devices -\n                              1) / num_devices * sizes * to_np_dtype(\n                                  key[1]).itemsize\n            bandwidth = comm_bytes / times / GB\n            ret += f\"Key: {key}\\nBandwidth: {to_str_round(bandwidth, 2)}\\n\\n\"\n\n        ret += \"=== all_to_all_cost_dict ===\\n\"\n        for key, value in self.all_to_all_cost_dict.items():\n            num_devices = len(key[0][0])\n            sizes = np.array([x[0] for x in 
value])\n            times = np.array([x[1] for x in value])\n            comm_bytes = ((num_devices - 1) / (num_devices**2) * sizes *\n                          to_np_dtype(key[1]).itemsize)\n            bandwidth = comm_bytes / times / GB\n            ret += f\"Key: {key}\\nBandwidth: {to_str_round(bandwidth, 2)}\\n\\n\"\n        return ret\n\n\nclass ProfilingResultDatabase:\n    \"\"\"A database that stores profiling results for multiple device mesh\n    shapes.\"\"\"\n\n    def __init__(self, data=None):\n        self.data = data or {}\n\n    def query(self, cluster_key, mesh_shape):\n        key = (cluster_key, mesh_shape)\n        return self.data[key]\n\n    def update_one_mesh(self, cluster_key, mesh_shape, mesh_result):\n        key = (cluster_key, mesh_shape)\n        if key not in self.data:\n            self.data[key] = mesh_result\n        else:\n            self.data[key].update(mesh_result)\n\n    def update(self, new_database):\n        for ((cluster_key, mesh_shape),\n             mesh_result) in new_database.data.items():\n            self.update_one_mesh(cluster_key, mesh_shape, mesh_result)\n\n    def insert_dummy_mesh_result(self, cluster_key, mesh_shape):\n        \"\"\"Insert dummy results for a mesh.\"\"\"\n        key = (cluster_key, mesh_shape)\n        assert key not in self.data\n\n        # Copy data from mesh shape (1, 1)\n        src_key = (cluster_key, (1, 1))\n        assert src_key in self.data\n        self.data[key] = self.data[src_key]\n\n    def save(self, filename):\n        with open(filename, \"wb\") as f:\n            pickle.dump(self.data, f)\n\n    def load(self, filename):\n        with open(filename, \"rb\") as f:\n            new_data = pickle.load(f)\n        self.update(ProfilingResultDatabase(new_data))\n\n    def __str__(self):\n        ret = \"\"\n        for (cluster_key, mesh_shape), value in self.data.items():\n            ret += f\"cluster_key: {cluster_key}, mesh_shape: {mesh_shape}\\n\"\n            ret += 
str(value)\n        return ret\n\n\ndef _op_parameter(builder, num, shape, dtype):\n    shape = xc.Shape.array_shape(dtype, shape)\n    name = \"\"\n    replicated = []\n    return ops.Parameter(builder, num,\n                         shape.with_major_to_minor_layout_if_absent(), name,\n                         replicated)\n\n\ndef _create_channel_id(backend):\n    channel_id = backend.create_channel_handle()\n    channel_id.type = xe.ChannelHandle_ChannelType.DEVICE_TO_DEVICE\n    channel_id.handle = 1\n    return channel_id\n\n\ndef _op_all_gather(operand, replica_groups, channel_id):\n    replica_groups_protos = xc.make_replica_groups(replica_groups)\n    ret = ops.AllGather(operand, 0, len(replica_groups[0]),\n                        replica_groups_protos, channel_id, None, True)\n    return ret\n\n\ndef _op_all_reduce(operand, dtype, reduce_op, replica_groups, channel_id):\n    replica_groups_protos = xc.make_replica_groups(replica_groups)\n    if reduce_op == \"add\":\n        rc = xc.XlaBuilder(\"reduce_\" + reduce_op)\n        x = _op_parameter(rc, 0, (), dtype)\n        y = _op_parameter(rc, 1, (), dtype)\n        z = ops.Add(x, y)\n        rc = rc.build(z)\n    else:\n        raise NotImplementedError\n\n    ret = ops.AllReduce(operand, rc, replica_groups_protos, channel_id, None,\n                        True)\n    return ret\n\n\ndef _op_all_to_all(operand, replica_groups, channel_id):\n    replica_groups_protos = xc.make_replica_groups(replica_groups)\n    ret = ops.AllToAll(operand, 0, 0, len(replica_groups[0]),\n                       replica_groups_protos, channel_id, None, True)\n    return ret\n\n\ndef _op_reduce_scatter(operand, dtype, reduce_op, replica_groups, channel_id):\n    replica_groups_protos = xc.make_replica_groups(replica_groups)\n    if reduce_op == \"add\":\n        rc = xc.XlaBuilder(\"reduce_\" + reduce_op)\n        x = _op_parameter(rc, 0, (), dtype)\n        y = _op_parameter(rc, 1, (), dtype)\n        z = ops.Add(x, y)\n        
rc = rc.build(z)\n    else:\n        raise NotImplementedError\n\n    ret = ops.ReduceScatter(operand, rc, 0, len(replica_groups[0]),\n                            replica_groups_protos, channel_id, None, True)\n    return ret\n\n\ndef _compile_profiling_executable_while_loop(backend, shapes, op_func,\n                                             num_devices):\n    \"\"\"\n    Compile an xla executable for benchmarking operators.\n    It is a while loop that calls the operator for multiple times.\n    \"\"\"\n\n    in_tuple_shape = xc.Shape.tuple_shape(\n        [xc.Shape.array_shape(np.dtype(np.int32), ())] +\n        [xc.Shape.array_shape(dtype, shape) for shape, dtype in shapes])\n\n    sharding = xc.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    sharding.tile_assignment_dimensions.extend([1])\n    sharding.tile_assignment_devices.extend([0])\n\n    # body\n    body = xc.XlaBuilder(\"body\")\n    in_tuple = ops.Parameter(body, 0, in_tuple_shape)\n    counter = ops.GetTupleElement(in_tuple, 0)\n    counter = ops.Sub(counter, ops.Constant(body, np.int32(1)))\n\n    operands = [\n        ops.GetTupleElement(in_tuple, i + 1) for i in range(len(shapes))\n    ]\n    body.set_sharding(sharding)\n    op_func(operands)\n    body.clear_sharding()\n    ops.Tuple(body, [counter] + operands)\n    body_computation = body.build()\n\n    # condition\n    cond = xc.XlaBuilder(\"condition\")\n    in_tuple = ops.Parameter(cond, 0, in_tuple_shape)\n    counter = ops.GetTupleElement(in_tuple, 0)\n    ops.Gt(counter, ops.Constant(cond, np.int32(0)))\n    cond_computation = cond.Build()\n\n    # while loop\n    loop = xc.XlaBuilder(\"loop\")\n    counter = _op_parameter(loop, 0, (), np.dtype(np.int32))\n    operands = [\n        _op_parameter(loop, i + 1, shape, dtype)\n        for i, (shape, dtype) in enumerate(shapes)\n    ]\n    while_init = ops.Tuple(loop, [counter] + operands)\n    ops.While(cond_computation, body_computation, while_init)\n    for i in 
range(len(shapes) + 1):\n        loop.setup_alias((i,), i, ())\n    loop_computation = loop.Build()\n\n    compile_options = xb.get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n    )\n    shapes = [(1, np.int32)] + shapes\n    return shapes, backend.compile(loop_computation, compile_options)\n\n\ndef _compile_profiling_executable_once(backend, shapes, op_func, num_devices):\n    \"\"\"\n    Compile an xla executable for benchmarking operators.\n    It runs the op only once.\n    \"\"\"\n\n    sharding = xc.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    sharding.tile_assignment_dimensions.extend([1])\n    sharding.tile_assignment_devices.extend([0])\n\n    body = xc.XlaBuilder(\"body\")\n    operands = [\n        _op_parameter(body, i, shape, dtype)\n        for i, (shape, dtype) in enumerate(shapes)\n    ]\n    body.set_sharding(sharding)\n    op_func(operands)\n    body.clear_sharding()\n    ops.Tuple(body, operands)\n    for i in range(len(shapes)):\n        body.setup_alias((i,), i, ())\n    body_computation = body.Build()\n\n    compile_options = xb.get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n    )\n    return shapes, backend.compile(body_computation, compile_options)\n\n\ndef bound(value, minimum, maximum):\n    return max(min(value, maximum), minimum)\n\n\ndef to_np_dtype(dtype_str: str):\n    \"\"\"Convert a string type to np dtype\"\"\"\n    if dtype_str == \"f32\":\n        return np.dtype(\"float32\")\n    elif dtype_str == \"f16\":\n        return np.dtype(\"float16\")\n    else:\n        return np.dtype(dtype_str)\n\n\ndef rank_0_print(host_id, msg):\n    \"\"\"Print message on rank 0.\"\"\"\n    if host_id == 0:\n        print(msg, flush=True)\n\n\n# A set 
containing all replica group patterns with nccl communicator created.\ncommunicator_set = set()\n\n\ndef profile_one_hlo_op(backend, local_devices, host_id, num_devices, op_info):\n    \"\"\"Profile one HLO operator.\"\"\"\n    dot_fp16_work = 100e12\n    dot_fp32_work = 50e12\n    comm_work = 1 << 32\n    replica_groups = None\n\n    if op_info[0] == \"dot\":\n        n, m, k, dtype_str = op_info[1]\n        dtype = to_np_dtype(dtype_str)\n        shapes = [((n, k), dtype), ((k, m), dtype), ((n, m), dtype)]\n\n        def op_func(operands):\n            lhs, rhs, _ = operands\n            dim_numbers = (((1,), (0,)), ((), ()))\n            dim_numbers = xc.make_dot_dimension_numbers(dim_numbers)\n            out = ops.DotGeneral(lhs, rhs, dim_numbers)\n            operands[-1] = out\n\n        flop_ct = max(2 * n * m * k, 1)\n        if dtype_str == \"f16\":\n            work = dot_fp16_work\n        elif dtype_str == \"f32\":\n            work = dot_fp32_work\n        else:\n            raise ValueError(f\"Invalid type: {dtype_str}\")\n        number = bound(int(work / flop_ct), 10, 1 << 12)\n    elif op_info[0] == \"all-gather\":\n        replica_groups, dtype, size = op_info[1]\n        dtype = to_np_dtype(dtype)\n        size = size // len(replica_groups[0]) * len(replica_groups[0])\n        shapes = [((size // len(replica_groups[0]),), dtype), ((size,), dtype)]\n\n        def op_func(operands):\n            if shapes[0][0][0] == 0:\n                return\n            channel_id = _create_channel_id(backend)\n            out = _op_all_gather(operands[0], replica_groups, channel_id)\n            operands[-1] = out\n\n        number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10,\n                       1 << 13)\n    elif op_info[0] == \"all-reduce\":\n        replica_groups, dtype, size = op_info[1]\n        dtype = to_np_dtype(dtype)\n        shapes = [((size,), dtype), ((size,), dtype)]\n\n        def op_func(operands):\n            channel_id = 
_create_channel_id(backend)\n            out = _op_all_reduce(operands[0], dtype, \"add\", replica_groups,\n                                 channel_id)\n            operands[-1] = out\n\n        number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10,\n                       1 << 13)\n    elif op_info[0] == \"all-to-all\":\n        replica_groups, dtype, size = op_info[1]\n        dtype = to_np_dtype(dtype)\n        size = size // (len(replica_groups[0])**2) * (len(replica_groups[0])**2)\n        shapes = [((size // len(replica_groups[0]),), dtype),\n                  ((size // len(replica_groups[0]),), dtype)]\n\n        def op_func(operands):\n            if shapes[0][0][0] // len(replica_groups[0]) == 0:\n                return\n            channel_id = _create_channel_id(backend)\n            out = _op_all_to_all(operands[0], replica_groups, channel_id)\n            operands[-1] = out\n\n        number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10,\n                       1 << 13)\n    elif op_info[0] == \"reduce-scatter\":\n        replica_groups, dtype, size = op_info[1]\n        dtype = to_np_dtype(dtype)\n        size = size // len(replica_groups[0]) * len(replica_groups[0])\n        shapes = [((size,), dtype), ((size // len(replica_groups[0]),), dtype)]\n\n        def op_func(operands):\n            if shapes[1][0][0] == 0:\n                return\n            channel_id = _create_channel_id(backend)\n            out = _op_reduce_scatter(operands[0], dtype, \"add\", replica_groups,\n                                     channel_id)\n            operands[-1] = out\n\n        number = bound(int(comm_work / max(size * dtype.itemsize, 1)), 10,\n                       1 << 13)\n    elif op_info[0] == \"create-communicator\":\n        replica_groups, = op_info[1]\n        dtype = to_np_dtype(\"f32\")\n        shapes = [((1024,), dtype), ((1024,), dtype)]\n\n        def op_func(operands):\n            channel_id = 
_create_channel_id(backend)\n            out = _op_all_reduce(operands[0], dtype, \"add\", replica_groups,\n                                 channel_id)\n            operands[-1] = out\n    elif op_info[0] == \"barrier\":\n        replica_groups = (tuple(i for i in range(num_devices)),)\n        dtype = to_np_dtype(\"f32\")\n        shapes = [((1,), dtype), ((1,), dtype)]\n\n        def op_func(operands):\n            channel_id = _create_channel_id(backend)\n            out = _op_all_reduce(operands[0], dtype, \"add\", replica_groups,\n                                 channel_id)\n            operands[-1] = out\n    else:\n        raise NotImplementedError(f\"Invalid op: {op_info[0]}\")\n\n    if op_info[0] in [\"create-communicator\", \"barrier\"]:\n        rank_0_print(host_id, f\"{op_info[0]}\")\n\n        # Compile\n        all_shapes, compiled = _compile_profiling_executable_once(\n            backend, shapes, op_func, num_devices)\n\n        # Run\n        device_inputs = []\n        for shape, dtype in all_shapes:\n            device_inputs.append([\n                backend.buffer_from_pyval(np.ones(shape, dtype),\n                                          local_devices[k])\n                for k in range(len(local_devices))\n            ])\n\n        for d in local_devices:\n            d.synchronize_all_activity()\n        device_inputs = compiled.execute_sharded_on_local_devices(device_inputs)\n        for d in local_devices:\n            d.synchronize_all_activity()\n        return 0\n    else:\n        # Create the nccl communicator\n        # This step is a workaround for some nccl/xla deadlock\n        if replica_groups and replica_groups not in communicator_set:\n            tmp_op_info = (\"create-communicator\", (op_info[1][0],))\n            profile_one_hlo_op(backend, local_devices, host_id, num_devices,\n                               tmp_op_info)\n            communicator_set.add(replica_groups)\n\n        warmup = max(number // 10, 2)\n\n     
   rank_0_print(\n            host_id, f\"Profiling {op_info}, number: {number}, \"\n            f\"timestamp: {time.time():.0f}.\")\n\n        # Compile\n        all_shapes, compiled = _compile_profiling_executable_while_loop(\n            backend, shapes, op_func, num_devices)\n\n        # Warm up\n        device_inputs = []\n        for j, (shape, dtype) in enumerate(all_shapes):\n            if j == 0:\n                device_inputs.append([\n                    backend.buffer_from_pyval(np.int32(warmup),\n                                              local_devices[k])\n                    for k in range(len(local_devices))\n                ])\n            else:\n                np_array = np.ones(shape, dtype)\n                device_inputs.append([\n                    backend.buffer_from_pyval(np_array, local_devices[k])\n                    for k in range(len(local_devices))\n                ])\n\n        for d in local_devices:\n            d.synchronize_all_activity()\n        device_inputs = compiled.execute_sharded_on_local_devices(device_inputs)\n        for d in local_devices:\n            d.synchronize_all_activity()\n\n        # Run profiling\n        device_inputs[0] = [\n            backend.buffer_from_pyval(np.int32(number), local_devices[k])\n            for k in range(len(local_devices))\n        ]\n\n        for d in local_devices:\n            d.synchronize_all_activity()\n        tic = time.time()\n        compiled.execute_sharded_on_local_devices(device_inputs)\n        for d in local_devices:\n            d.synchronize_all_activity()\n        toc = time.time()\n\n        # Return\n        mean_time = (toc - tic) / number\n        return mean_time\n\n\ndef profile_hlo_ops(op_infos, backend, local_devices, host_id, num_devices,\n                    cache_filename, single_timeout):\n    \"\"\"Profile a list of HLO operators on a worker.\"\"\"\n    results = []\n    save_every = 15\n    barrier_every = 5\n\n    if 
os.path.exists(cache_filename):\n        rank_0_print(host_id,\n                     f\"Load cached hlo op cost dict from {cache_filename}...\")\n        with open(cache_filename, \"rb\") as cf:\n            cache_dict = pickle.load(cf)\n    else:\n        cache_dict = {}\n\n    old_cache_len = len(cache_dict)\n\n    try:\n        for i, op_info in enumerate(op_infos):\n            if op_info in cache_dict:\n                rank_0_print(host_id, f\"Hit cache {op_info} ...\")\n                results.append(cache_dict[op_info])\n                continue\n\n            if i % barrier_every == 0:\n                # Run barrier to reduce hanging/deadlock issues\n                run_with_timeout(profile_one_hlo_op,\n                                 (backend, local_devices, host_id, num_devices,\n                                  (\"barrier\",)),\n                                 timeout=single_timeout)\n\n            # Profile one op\n            mean_time = run_with_timeout(\n                profile_one_hlo_op,\n                (backend, local_devices, host_id, num_devices, op_info),\n                timeout=single_timeout)\n            cache_dict[op_info] = mean_time\n            results.append(mean_time)\n\n            if host_id == 0 and (i + 1) % save_every == 0:\n                old_cache_len = len(cache_dict)\n                rank_0_print(host_id, \"Save cache...\")\n                with open(cache_filename, \"wb\") as cf:\n                    pickle.dump(cache_dict, cf)\n    except TimeoutError:\n        print(f\"Worker {host_id} timeout error\", flush=True)\n        return None\n    except RuntimeError:\n        print(f\"Worker {host_id} runtime error\", flush=True)\n        return None\n\n    if host_id == 0 and len(cache_dict) > old_cache_len:\n        rank_0_print(host_id, \"Save cache...\")\n        with open(cache_filename, \"wb\") as cf:\n            pickle.dump(cache_dict, cf)\n\n    return np.array(results)\n\n\ndef profile_dot(dot_range, 
device_cluster, cache_filename):\n    \"\"\"Profile the compute cost of dot.\"\"\"\n    physical_mesh = device_cluster.get_physical_mesh(host_ids=[0],\n                                                     num_devices_per_host=1)\n\n    # Profile dot\n    op_infos = []\n    for dtype in [\"f16\", \"f32\"]:\n        for n in dot_range:\n            op_infos.append((\"dot\", (n, n, n, dtype)))\n    results = physical_mesh.profile_hlo_ops(op_infos, cache_filename)\n\n    dot_cost_dict = defaultdict(list)\n    for i in range(len(op_infos)):\n        n, m, k, dtype = op_infos[i][1]\n        flop_count = 2 * n * m * k\n        dot_cost_dict[((), dtype)].append((flop_count, results[i]))\n        print(f\"Matmul: {(n, m, k, dtype)}, \"\n              f\"TFLOPS: {flop_count / results[i]/ 1e12:.2f}\")\n\n    physical_mesh.shutdown()\n    time.sleep(2)\n    return dot_cost_dict\n\n\ndef enumerate_all_collective_spec(num_hosts, num_devices_per_host,\n                                  max_comm_size_intra_node,\n                                  max_comm_size_inter_node):\n    \"\"\"Enumerate all possible collective groups.\"\"\"\n    # Enumerate all possible logical meshes\n    logical_mesh_shapes = []\n    num_devices = num_hosts * num_devices_per_host\n    for i in range(1, num_devices + 1):\n        if num_devices % i == 0:\n            logical_mesh_shapes.append((num_devices // i, i))\n\n    # Enumerate all replica groups\n    all_specs = set()\n    for logical_mesh_shape in logical_mesh_shapes:\n        # dim 0\n        replica_groups = []\n        tmp_group = []\n        for i in range(logical_mesh_shape[0]):\n            tmp_group.append(\n                tuple(i * logical_mesh_shape[1] + j\n                      for j in range(logical_mesh_shape[1])))\n        replica_groups.append(tuple(tmp_group))\n\n        # dim 1\n        tmp_group = []\n        for j in range(logical_mesh_shape[1]):\n            tmp_group.append(\n                tuple(i * logical_mesh_shape[1] + 
j\n                      for i in range(logical_mesh_shape[0])))\n        replica_groups.append(tuple(tmp_group))\n\n        for replica_group in replica_groups:\n            for dtype in [\"f32\", \"f16\"]:\n                # Debug filter\n                #if replica_group != (tuple(range(32)),) or dtype != \"f32\":\n                #    continue\n\n                if (max(replica_group[0]) - min(replica_group[0]) <\n                        num_devices_per_host):\n                    max_comm_size = max_comm_size_intra_node\n                else:\n                    max_comm_size = max_comm_size_inter_node\n\n                max_num_elem_log_2 = math.ceil(\n                    math.log2(\n                        (1 << max_comm_size) / to_np_dtype(dtype).itemsize))\n\n                all_specs.add((tuple(replica_group), dtype, 0))\n                for i in range(0, max_num_elem_log_2 + 1):\n                    all_specs.add((tuple(replica_group), dtype, 1 << i))\n\n    all_specs = list(all_specs)\n    all_specs.sort(key=lambda k:\n                   (k[0][0][0] - k[0][0][-1], to_np_dtype(k[1]).itemsize, k[2]))\n    return list(all_specs)\n\n\ndef profile_all(device_cluster,\n                cluster_key,\n                max_comm_size_intra_node,\n                max_comm_size_inter_node,\n                max_fail_retry,\n                cache_filename,\n                dot_range=(0, 1024),\n                mesh_size_choices=None):\n    \"\"\"Profile costs for all dot and communication primitives.\"\"\"\n    #  pylint: disable=import-outside-toplevel\n    from alpa.pipeline_parallel.stage_construction import get_submesh_choices\n    print_used_time(None)\n\n    ##### Profile compute cost\n    dot_cost_dict = profile_dot(dot_range, device_cluster, cache_filename)\n    print_used_time(\"Profile dot\")\n\n    ##### Profile communication cost\n    virtual_mesh = device_cluster.get_virtual_physical_mesh()\n    if mesh_size_choices is None:\n        submesh_choices = 
list(\n            reversed(\n                get_submesh_choices(virtual_mesh.num_hosts,\n                                    virtual_mesh.num_devices_per_host, \"all\")))\n    else:\n        submesh_choices = list(\n            reversed(\n                get_submesh_choices(virtual_mesh.num_hosts,\n                                    virtual_mesh.num_devices_per_host, \"manual\",\n                                    mesh_size_choices)))\n    # Load failed batch keys\n    failed_batch_keys_filename = \"tmp/failed_batch_keys.pkl\"\n    if os.path.exists(failed_batch_keys_filename):\n        with open(failed_batch_keys_filename, \"rb\") as fbkf:\n            failed_batch_keys = pickle.load(fbkf)\n    else:\n        failed_batch_keys = set()\n\n    prof_database = ProfilingResultDatabase()\n    for _, (num_hosts, num_devices_per_host) in enumerate(submesh_choices):\n        print(f\"Mesh shape: {(num_hosts, num_devices_per_host)}\")\n\n        # Slice a mesh\n        tmp_mesh = virtual_mesh.slice_2d(tuple(range(num_hosts)),\n                                         (tuple(range(num_devices_per_host)),) *\n                                         num_hosts)\n        all_specs = enumerate_all_collective_spec(num_hosts,\n                                                  num_devices_per_host,\n                                                  max_comm_size_intra_node,\n                                                  max_comm_size_inter_node)\n\n        op_infos = []\n        for op_type in [\n                \"all-reduce\", \"all-gather\", \"all-to-all\", \"reduce-scatter\"\n        ]:\n            for spec in all_specs:\n                op_infos.append((op_type, spec))\n\n        physical_mesh = tmp_mesh.get_physical_mesh()\n        available_memory_per_device = physical_mesh.get_available_memory()\n\n        def get_op_info_key(op_info):  # return (op_type, replica_group)\n            return (op_info[0], op_info[1][0])\n\n        # Profile operators in batch to 
resolve some deadlock issues\n        results = []\n        s = 0\n        fail_ct = 0\n        while s < len(op_infos):\n            # Decide batch size\n            batch_key = get_op_info_key(op_infos[s])\n            batch_size = 1\n            while (s + batch_size < len(op_infos) and\n                   get_op_info_key(op_infos[s + batch_size]) == batch_key):\n                batch_size += 1\n\n            print(f\"Batch size: {batch_size}, key: {batch_key}\")\n\n            # Profile a batch\n            if batch_key in failed_batch_keys:\n                # This batch is skipped due to too many errors\n                batch_result = [np.inf] * batch_size\n            else:\n                try:\n                    batch_result = physical_mesh.profile_hlo_ops(\n                        op_infos[s:s + batch_size],\n                        cache_filename,\n                        single_timeout=bound(fail_ct * 100, 100, 400),\n                        batch_timeout=batch_size * 100)\n                except ray.exceptions.RayError:\n                    batch_result = None\n\n            if batch_result is not None:\n                results.extend(batch_result)\n                s += batch_size\n                fail_ct = 0\n            else:\n                op_infos[s:s + batch_size] = reversed(op_infos[s:s +\n                                                               batch_size])\n                fail_ct += 1\n\n                if fail_ct > max_fail_retry:\n                    # Skip this batch if there are too many errors\n                    print(f\"Failed key: {batch_key}\")\n                    failed_batch_keys.add(batch_key)\n                    with open(failed_batch_keys_filename, \"wb\") as fbkf:\n                        pickle.dump(failed_batch_keys, fbkf)\n\n                print(f\"Reboot physical mesh. 
fail_ct: {fail_ct}\")\n                physical_mesh.shutdown(forced=True)\n                physical_mesh = None\n                while physical_mesh is None:\n                    try:\n                        time.sleep(10)\n                        tmp_mesh.launched_physical_mesh = None\n                        physical_mesh = tmp_mesh.get_physical_mesh()\n                    except ray.exceptions.RayError:\n                        ray.shutdown()\n                        ray.init(address=\"auto\")\n                        physical_mesh = None\n\n        # Parse results\n        all_gather_cost_dict = defaultdict(list)\n        all_reduce_cost_dict = defaultdict(list)\n        all_to_all_cost_dict = defaultdict(list)\n        reduce_scatter_cost_dict = defaultdict(list)\n        for i in range(len(op_infos)):\n            op_type, (replica_groups, dtype, size) = op_infos[i]\n            array_size = size * to_np_dtype(dtype).itemsize\n            num_devices = len(replica_groups[0])\n\n            if op_type == \"all-gather\":\n                communication_size = array_size * (num_devices -\n                                                   1) / num_devices\n                all_gather_cost_dict[(replica_groups, dtype)].append(\n                    (size, results[i]))\n            elif op_type == \"all-reduce\":\n                communication_size = 2 * array_size * (num_devices -\n                                                       1) / num_devices\n                all_reduce_cost_dict[(replica_groups, dtype)].append(\n                    (size, results[i]))\n            elif op_type == \"all-to-all\":\n                communication_size = array_size * (\n                    num_devices - 1) / num_devices / num_devices\n                all_to_all_cost_dict[(replica_groups, dtype)].append(\n                    (size, results[i]))\n            elif op_type == \"reduce-scatter\":\n                communication_size = array_size * (num_devices -\n                  
                                 1) / num_devices\n                reduce_scatter_cost_dict[(replica_groups, dtype)].append(\n                    (size, results[i]))\n            else:\n                raise ValueError(f\"Invalid op: {op_type}\")\n\n            bandwidth = communication_size / results[i]\n            print(f\"Op: {op_infos[i]}, Bandwidth: {bandwidth / GB:.2f} GB/s\")\n\n        physical_mesh.shutdown()\n\n        mesh_result = MeshProfilingResult()\n        mesh_result.dot_cost_dict = dot_cost_dict\n        mesh_result.all_gather_cost_dict = all_gather_cost_dict\n        mesh_result.all_reduce_cost_dict = all_reduce_cost_dict\n        mesh_result.all_to_all_cost_dict = all_to_all_cost_dict\n        mesh_result.reduce_scatter_cost_dict = reduce_scatter_cost_dict\n        mesh_result.available_memory_per_device = available_memory_per_device\n        mesh_result.sort_cost_lists()\n        mesh_result.make_monotonic()\n        prof_database.update_one_mesh(cluster_key,\n                                      (num_hosts, num_devices_per_host),\n                                      mesh_result)\n\n    print_used_time(\"Profile communication\")\n    return prof_database\n\n\ndef estimate_hlo_module_cost(hlo_module,\n                             profiling_results,\n                             num_micro_batches=1,\n                             grad_sync_channel_ids=\"\"):\n    \"\"\"Estimate the cost of an HLO module with the HLO instruction level cost\n    model.\"\"\"\n    with XlaPassContext({\n            \"gpu_cost_model::profiling_results\": profiling_results,\n            \"gpu_cost_model::num_micro_batches\": num_micro_batches,\n            \"gpu_cost_model::grad_sync_channel_ids\": grad_sync_channel_ids,\n            \"gpu_cost_model::verbose\": 0,\n    }):\n        return xe.estimate_hlo_module_cost(hlo_module)\n"
  },
  {
    "path": "alpa/model/__init__.py",
    "content": ""
  },
  {
    "path": "alpa/model/bert_model.py",
    "content": "# flake8: noqa\n\"\"\"Model definition of BERT.\nCopied from https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py\"\"\"\nfrom functools import partial\nfrom typing import Callable\n\nimport numpy as np\n\nfrom flax import linen as nn\nfrom flax.linen.partitioning import remat\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\n\nfrom alpa.model.model_util import (FlaxBaseModelOutput,\n                                   FlaxBaseModelOutputWithPooling,\n                                   FlaxBertForPreTrainingOutput,\n                                   FlaxMaskedLMOutput,\n                                   FlaxSequenceClassifierOutput, TrainState)\nfrom alpa.model.model_util import TrainState\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\n\n\nclass BertConfig:\n\n    def __init__(self,\n                 vocab_size=30522,\n                 hidden_size=768,\n                 num_hidden_layers=12,\n                 num_attention_heads=12,\n                 intermediate_size=3072,\n                 hidden_act=\"gelu\",\n                 hidden_dropout_prob=0.1,\n                 attention_probs_dropout_prob=0.1,\n                 max_position_embeddings=512,\n                 type_vocab_size=2,\n                 initializer_range=0.02,\n                 layer_norm_eps=1e-12,\n                 gradient_checkpointing=False,\n                 position_embedding_type=\"absolute\",\n                 use_cache=True,\n                 classifier_dropout=None,\n                 num_labels=None,\n                 tie_word_embeddings=True,\n                 add_manual_pipeline_markers=False,\n                 pipeline_mp_size=0,\n                 **kwargs):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_act = 
hidden_act\n        self.intermediate_size = intermediate_size\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.attention_probs_dropout_prob = attention_probs_dropout_prob\n        self.max_position_embeddings = max_position_embeddings\n        self.type_vocab_size = type_vocab_size\n        self.initializer_range = initializer_range\n        self.layer_norm_eps = layer_norm_eps\n        self.gradient_checkpointing = gradient_checkpointing\n        self.position_embedding_type = position_embedding_type\n        self.use_cache = use_cache\n        self.classifier_dropout = classifier_dropout\n        self.num_labels = num_labels\n        self.tie_word_embeddings = tie_word_embeddings\n        self.add_manual_pipeline_markers = add_manual_pipeline_markers\n        self.pipeline_mp_size = pipeline_mp_size\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.swish,\n    \"swish\": nn.swish,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\nclass FlaxBertEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n\n        if self.config.gradient_checkpointing:\n            trans_func = remat\n        else:\n            trans_func = lambda x: x\n\n        self.word_embeddings = trans_func(nn.Embed)(\n            self.config.vocab_size,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(\n                stddev=self.config.initializer_range),\n            dtype=self.dtype,\n        )\n        self.position_embeddings = trans_func(nn.Embed)(\n            self.config.max_position_embeddings,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(\n                stddev=self.config.initializer_range),\n            dtype=self.dtype,\n     
   )\n\n        if self.config.type_vocab_size > 0:\n            self.token_type_embeddings = trans_func(nn.Embed)(\n                self.config.type_vocab_size,\n                self.config.hidden_size,\n                embedding_init=jax.nn.initializers.normal(\n                    stddev=self.config.initializer_range),\n                dtype=self.dtype,\n            )\n        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                      dtype=self.dtype)\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n\n    def __call__(self,\n                 input_ids,\n                 token_type_ids,\n                 position_ids,\n                 attention_mask,\n                 deterministic: bool = True):\n        # Embed\n        inputs_embeds = self.word_embeddings(input_ids.astype(\"i4\"))\n        position_embeds = self.position_embeddings(position_ids.astype(\"i4\"))\n\n        if self.config.type_vocab_size > 0:\n            token_type_embeddings = self.token_type_embeddings(\n                token_type_ids.astype(\"i4\"))\n        else:\n            token_type_embeddings = 0.0\n\n        # Sum all embeddings\n        hidden_states = inputs_embeds + position_embeds + token_type_embeddings\n        hidden_states = self.LayerNorm(hidden_states)\n        hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n        return hidden_states\n\n\nclass FlaxBertSelfAttention(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        if self.config.hidden_size % self.config.num_attention_heads != 0:\n            raise ValueError(\n                f\"`hidden_size`: {self.config.hidden_size} has to be a multiple of `num_attention_heads`: {self.config.num_attention_heads}\"\n            )\n\n        self.qvk_combined = nn.Dense(\n            self.config.hidden_size * 3,\n            dtype=self.dtype,\n           
 kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range),\n        )\n\n    def __call__(self,\n                 hidden_states,\n                 attention_mask,\n                 deterministic: bool = True,\n                 output_attentions: bool = False):\n        head_dim = self.config.hidden_size // self.config.num_attention_heads\n\n        qvk_combined_states = self.qvk_combined(hidden_states)\n        qvk_combined_states = qvk_combined_states.reshape(\n            qvk_combined_states.shape[:2] + (-1, 3))\n        query_states, value_states, key_states = jnp.split(qvk_combined_states,\n                                                           3,\n                                                           axis=3)\n\n        query_states = query_states.reshape(hidden_states.shape[:2] +\n                                            (self.config.num_attention_heads,\n                                             head_dim))\n        value_states = value_states.reshape(hidden_states.shape[:2] +\n                                            (self.config.num_attention_heads,\n                                             head_dim))\n        key_states = key_states.reshape(hidden_states.shape[:2] +\n                                        (self.config.num_attention_heads,\n                                         head_dim))\n\n        # Convert the boolean attention mask to an attention bias.\n        if attention_mask is not None:\n            # attention mask in the form of attention bias\n            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))\n            attention_bias = lax.select(\n                attention_mask > 0,\n                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),\n                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),\n            )\n        else:\n            attention_bias = None\n\n        dropout_rng = None\n\n        if not deterministic and 
self.config.attention_probs_dropout_prob > 0.0:\n            dropout_rng = self.make_rng(\"dropout\")\n\n        attn_weights = nn.attention.dot_product_attention_weights(\n            query_states,\n            key_states,\n            bias=attention_bias,\n            dropout_rng=dropout_rng,\n            dropout_rate=self.config.attention_probs_dropout_prob,\n            broadcast_dropout=False,\n            deterministic=deterministic,\n            dtype=self.dtype,\n            precision=None,\n        )\n\n        attn_output = jnp.einsum(\"...hqk,...khd->...qhd\", attn_weights,\n                                 value_states)\n        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))\n\n        outputs = (attn_output,\n                   attn_weights) if output_attentions else (attn_output,)\n        return outputs\n\n\nclass FlaxBertSelfOutput(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.dense = nn.Dense(\n            self.config.hidden_size,\n            kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range),\n            dtype=self.dtype,\n        )\n        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                      dtype=self.dtype)\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n\n    def __call__(self, hidden_states, input_tensor, deterministic: bool = True):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n        hidden_states = self.LayerNorm(hidden_states + input_tensor)\n        return hidden_states\n\n\nclass FlaxBertAttention(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.self = FlaxBertSelfAttention(self.config, dtype=self.dtype)\n        self.output = FlaxBertSelfOutput(self.config, 
dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 attention_mask,\n                 deterministic: bool = True,\n                 output_attentions: bool = False):\n        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)\n        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable\n        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)\n        attn_outputs = self.self(hidden_states,\n                                 attention_mask,\n                                 deterministic=deterministic,\n                                 output_attentions=output_attentions)\n        attn_output = attn_outputs[0]\n        hidden_states = self.output(attn_output,\n                                    hidden_states,\n                                    deterministic=deterministic)\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attn_outputs[1],)\n\n        return outputs\n\n\nclass FlaxBertIntermediate(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.dense = nn.Dense(\n            self.config.intermediate_size,\n            kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range),\n            dtype=self.dtype,\n        )\n        self.activation = ACT2FN[self.config.hidden_act]\n\n    def __call__(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.activation(hidden_states)\n        return hidden_states\n\n\nclass FlaxBertOutput(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.dense = nn.Dense(\n            self.config.hidden_size,\n            kernel_init=jax.nn.initializers.normal(\n                
self.config.initializer_range),\n            dtype=self.dtype,\n        )\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                      dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 attention_output,\n                 deterministic: bool = True):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n        hidden_states = self.LayerNorm(hidden_states + attention_output)\n        return hidden_states\n\n\nclass FlaxBertLayer(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.attention = FlaxBertAttention(self.config, dtype=self.dtype)\n        self.intermediate = FlaxBertIntermediate(self.config, dtype=self.dtype)\n        self.output = FlaxBertOutput(self.config, dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 attention_mask,\n                 deterministic: bool = True,\n                 output_attentions: bool = False):\n        attention_outputs = self.attention(hidden_states,\n                                           attention_mask,\n                                           deterministic=deterministic,\n                                           output_attentions=output_attentions)\n        attention_output = attention_outputs[0]\n\n        hidden_states = self.intermediate(attention_output)\n        hidden_states = self.output(hidden_states,\n                                    attention_output,\n                                    deterministic=deterministic)\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attention_outputs[1],)\n        return outputs\n\n\nclass FlaxBertLayerCollection(nn.Module):\n    config: 
BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        if self.config.gradient_checkpointing:\n            trans_func = partial(remat, static_argnums=(2, 3))\n        else:\n            trans_func = lambda x: x\n\n        # Mixed rematerialization\n        #layers = []\n        #for i in range(self.config.num_hidden_layers):\n        #    if i % 2 == 0:\n        #        layer = trans_func(FlaxBertLayer)(self.config,\n        #                                  name=str(i),\n        #                                  dtype=self.dtype)\n        #    else:\n        #        layer = FlaxBertLayer(self.config,\n        #                              name=str(i),\n        #                              dtype=self.dtype)\n        #    layers.append(layer)\n        #self.layers = layers\n\n        self.layers = [\n            trans_func(FlaxBertLayer)(self.config,\n                                      name=str(i),\n                                      dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n\n        for i, layer in enumerate(self.layers):\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = layer(hidden_states, attention_mask, deterministic,\n                                  output_attentions)\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_attentions += (layer_outputs[1],)\n\n            if self.config.add_manual_pipeline_markers:\n                layers_per_stage = 
self.config.num_hidden_layers // self.config.pipeline_mp_size\n                assert self.config.num_hidden_layers % self.config.pipeline_mp_size == 0\n                if i % layers_per_stage == layers_per_stage - 1 and i != len(\n                        self.layers) - 1:\n                    mark_pipeline_boundary()\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return FlaxBaseModelOutput(last_hidden_state=hidden_states,\n                                   hidden_states=all_hidden_states,\n                                   attentions=all_attentions)\n\n\nclass FlaxBertEncoder(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype)\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        return self.layer(\n            hidden_states,\n            attention_mask,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\nclass FlaxBertPooler(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.dense = nn.Dense(\n            self.config.hidden_size,\n            kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range),\n            dtype=self.dtype,\n        )\n\n    def __call__(self, hidden_states):\n        cls_hidden_state = hidden_states[:, 0]\n        cls_hidden_state = 
self.dense(cls_hidden_state)\n        return nn.tanh(cls_hidden_state)\n\n\nclass FlaxBertPredictionHeadTransform(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)\n        self.activation = ACT2FN[self.config.hidden_act]\n        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                      dtype=self.dtype)\n\n    def __call__(self, hidden_states):\n        hidden_states = self.dense(hidden_states)\n        hidden_states = self.activation(hidden_states)\n        return self.LayerNorm(hidden_states)\n\n\nclass FlaxBertLMPredictionHead(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros\n\n    def setup(self):\n        self.transform = FlaxBertPredictionHeadTransform(self.config,\n                                                         dtype=self.dtype)\n        if self.config.tie_word_embeddings:\n            self.decoder = None\n        else:\n            self.decoder = nn.Dense(self.config.vocab_size,\n                                    dtype=self.dtype,\n                                    use_bias=False)\n        self.bias = self.param(\"bias\", self.bias_init,\n                               (self.config.vocab_size,))\n\n    def __call__(self, hidden_states, shared_embedding=None):\n        hidden_states = self.transform(hidden_states)\n\n        if shared_embedding is not None:\n            assert self.decoder is None\n            hidden_states = hidden_states @ shared_embedding.T\n        else:\n            assert self.decoder is not None\n            hidden_states = self.decoder(hidden_states)\n\n        hidden_states += self.bias\n        return hidden_states\n\n\nclass FlaxBertOnlyMLMHead(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        
self.predictions = FlaxBertLMPredictionHead(self.config,\n                                                    dtype=self.dtype)\n\n    def __call__(self, hidden_states, shared_embedding=None):\n        hidden_states = self.predictions(hidden_states,\n                                         shared_embedding=shared_embedding)\n        return hidden_states\n\n\nclass FlaxBertOnlyNSPHead(nn.Module):\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.seq_relationship = nn.Dense(2, dtype=self.dtype)\n\n    def __call__(self, pooled_output):\n        return self.seq_relationship(pooled_output)\n\n\nclass FlaxBertPreTrainingHeads(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.predictions = FlaxBertLMPredictionHead(self.config,\n                                                    dtype=self.dtype)\n        self.seq_relationship = nn.Dense(2, dtype=self.dtype)\n\n    def __call__(self, hidden_states, pooled_output, shared_embedding=None):\n        prediction_scores = self.predictions(hidden_states,\n                                             shared_embedding=shared_embedding)\n        seq_relationship_score = self.seq_relationship(pooled_output)\n        return prediction_scores, seq_relationship_score\n\n\nclass FlaxBertModule(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n    add_pooling_layer: bool = True\n\n    def setup(self):\n        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)\n        self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype)\n        if self.add_pooling_layer:\n            self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = 
False,\n        return_dict: bool = True,\n    ):\n        hidden_states = self.embeddings(input_ids,\n                                        token_type_ids,\n                                        position_ids,\n                                        attention_mask,\n                                        deterministic=deterministic)\n        outputs = self.encoder(\n            hidden_states,\n            attention_mask,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = outputs[0]\n        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None\n\n        if not return_dict:\n            # if pooled is None, don't return it\n            if pooled is None:\n                return (hidden_states,) + outputs[1:]\n            return (hidden_states, pooled) + outputs[1:]\n\n        return FlaxBaseModelOutputWithPooling(\n            last_hidden_state=hidden_states,\n            pooler_output=pooled,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass FlaxBertForPreTrainingModule(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)\n        self.cls = FlaxBertPreTrainingHeads(config=self.config,\n                                            dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n\n        # Model\n        outputs = self.bert(\n            input_ids,\n            attention_mask,\n            token_type_ids,\n            
position_ids,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        if self.config.tie_word_embeddings:\n            shared_embedding = self.bert.variables[\"params\"][\"embeddings\"][\n                \"word_embeddings\"][\"embedding\"]\n        else:\n            shared_embedding = None\n\n        hidden_states = outputs[0]\n        pooled_output = outputs[1]\n\n        prediction_scores, seq_relationship_score = self.cls(\n            hidden_states, pooled_output, shared_embedding=shared_embedding)\n\n        if not return_dict:\n            return (prediction_scores, seq_relationship_score) + outputs[2:]\n\n        return FlaxBertForPreTrainingOutput(\n            prediction_logits=prediction_scores,\n            seq_relationship_logits=seq_relationship_score,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass FlaxBertForMaskedLMModule(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.bert = FlaxBertModule(config=self.config,\n                                   add_pooling_layer=False,\n                                   dtype=self.dtype)\n        self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        # Model\n        outputs = self.bert(\n            input_ids,\n            attention_mask,\n            token_type_ids,\n            position_ids,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n        hidden_states = outputs[0]\n        if self.config.tie_word_embeddings:\n            shared_embedding = self.bert.variables[\"params\"][\"embeddings\"][\n                \"word_embeddings\"][\"embedding\"]\n        else:\n            shared_embedding = None\n\n        # Compute the prediction scores\n        logits = self.cls(hidden_states, shared_embedding=shared_embedding)\n\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return FlaxMaskedLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass FlaxBertForSequenceClassificationModule(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.bert = FlaxBertModule(\n            config=self.config,\n            dtype=self.dtype,\n        )\n        classifier_dropout = (self.config.classifier_dropout\n                              if self.config.classifier_dropout is not None else\n                              self.config.hidden_dropout_prob)\n        self.dropout = nn.Dropout(rate=classifier_dropout)\n        self.classifier = nn.Dense(\n            self.config.num_labels,\n            dtype=self.dtype,\n        )\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        head_mask=None,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        # Model\n        outputs = self.bert(\n            input_ids,\n            attention_mask,\n            token_type_ids,\n            position_ids,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n    
        return_dict=return_dict,\n        )\n\n        pooled_output = outputs[1]\n        pooled_output = self.dropout(pooled_output, deterministic=deterministic)\n        logits = self.classifier(pooled_output)\n\n        if not return_dict:\n            return (logits,) + outputs[2:]\n\n        return FlaxSequenceClassifierOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\ndef test_bert_layer():\n    batch_size = 64\n    seq_len = 64\n    hidden_size = 768\n\n    hidden_states = jnp.ones((batch_size, seq_len, hidden_size),\n                             dtype=jnp.float32)\n    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32)\n\n    # Init model and optimizer\n    model = FlaxBertLayer(BertConfig(hidden_size=hidden_size))\n    rngkey = jax.random.PRNGKey(0)\n    params = model.init(rngkey, hidden_states, attention_mask)\n    optimizer = optim.GradientDescent(1e-2).create(params)\n\n    def train_step(optimizer, batch):\n\n        def loss_func(params):\n            rngs = {\"dropout\": batch[\"rng\"]}\n            out = model.apply(params,\n                              batch[\"hidden_states\"],\n                              batch[\"attention_mask\"],\n                              rngs=rngs)[0]\n            return jnp.mean((out - batch[\"label\"])**2)\n\n        grad = jax.grad(loss_func)(optimizer.target)\n        new_optimizer = optimizer.apply_gradient(grad)\n        return new_optimizer\n\n    # JIT compile\n    #optimizer = train_step(optimizer,\n    #                       {\"hidden_states\": hidden_states,\n    #                        \"attention_mask\": attention_mask,\n    #                        \"label\": label,\n    #                        \"rng\": rngkey})\n\n    jaxpr = jax.make_jaxpr(train_step)(optimizer, {\n        \"hidden_states\": hidden_states,\n        
\"attention_mask\": attention_mask,\n        \"label\": label,\n        \"rng\": rngkey\n    })\n    print(jaxpr)\n\n\ndef test_bert_mlm():\n    batch_size = 64\n    seq_len = 64\n    hidden_size = 128\n    num_attention_heads = 4\n    num_hidden_layers = 2\n    vocab_size = 1024\n\n    @partial(jax.jit, static_argnums=(2,))\n    def train_step(optimizer, batch, apply_func):\n\n        def loss_func(params):\n            rngs = {\"dropout\": batch[\"rng\"]}\n            logits = apply_func(params,\n                                batch[\"input_ids\"],\n                                batch[\"attention_mask\"],\n                                batch[\"token_type_ids\"],\n                                batch[\"position_ids\"],\n                                rngs=rngs)[0]\n            label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n            labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n            loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                            axis=-1)\n            loss = (label_mask * loss).sum() / label_mask.sum()\n            return loss\n\n        grad = jax.grad(loss_func)(optimizer.target)\n        new_optimizer = optimizer.apply_gradient(grad)\n        return new_optimizer\n\n    # Init model and optimizer\n    input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n\n    model = FlaxBertForMaskedLMModule(\n        BertConfig(\n            vocab_size=vocab_size,\n            hidden_size=hidden_size,\n            num_attention_heads=num_attention_heads,\n            intermediate_size=hidden_size * 4,\n            num_hidden_layers=num_hidden_layers,\n        ))\n    rngkey = jax.random.PRNGKey(0)\n    params = 
model.init(rngkey, input_ids, attention_mask, token_type_ids,\n                        position_ids)\n    optimizer = optim.GradientDescent(1e-2).create(params)\n\n    # JIT compile\n    train_step(\n        optimizer, {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"token_type_ids\": token_type_ids,\n            \"position_ids\": position_ids,\n            \"labels\": labels,\n            \"rng\": rngkey\n        }, model.apply)\n\n\nif __name__ == \"__main__\":\n    #test_bert_layer()\n    test_bert_mlm()\n"
  },
  {
    "path": "alpa/model/conformer.py",
    "content": "\"\"\"Conformer.\n\nReference:\nhttps://arxiv.org/pdf/2005.08100.pdf\nhttps://github.com/TensorSpeech/TensorFlowASR/blob/main/tensorflow_asr/models/encoders/conformer.py\n\"\"\"\n\nfrom functools import partial\nfrom typing import Any, Callable\n\nimport numpy as np\n\nimport flax\nfrom flax import linen as nn, optim\nfrom flax.training import train_state\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\n\nfrom alpa.model.model_util import (FlaxBaseModelOutput,\n                                   FlaxBaseModelOutputWithPooling,\n                                   FlaxBertForPreTrainingOutput,\n                                   FlaxMaskedLMOutput)\nfrom alpa import mark_pipeline\n\n\nclass TrainState(train_state.TrainState):\n    batch_stats: Any\n    dynamic_scale: optim.DynamicScale\n\n\nclass ConformerConfig:\n\n    def __init__(self,\n                 vocab_size=30522,\n                 hidden_size=768,\n                 num_hidden_layers=12,\n                 num_attention_heads=12,\n                 intermediate_size=3072,\n                 hidden_act=\"gelu\",\n                 hidden_dropout_prob=0.1,\n                 attention_probs_dropout_prob=0.1,\n                 max_position_embeddings=512,\n                 type_vocab_size=2,\n                 initializer_range=0.02,\n                 layer_norm_eps=1e-12,\n                 gradient_checkpointing=False,\n                 position_embedding_type=\"absolute\",\n                 use_cache=True,\n                 conv_subsample_channel=256,\n                 conv_kernel_size=32,\n                 **kwargs):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        self.num_attention_heads = num_attention_heads\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.attention_probs_dropout_prob = 
attention_probs_dropout_prob\n        self.max_position_embeddings = max_position_embeddings\n        self.type_vocab_size = type_vocab_size\n        self.initializer_range = initializer_range\n        self.layer_norm_eps = layer_norm_eps\n        self.gradient_checkpointing = gradient_checkpointing\n        self.position_embedding_type = position_embedding_type\n        self.use_cache = use_cache\n        self.conv_subsample_channel = conv_subsample_channel\n        self.conv_kernel_size = conv_kernel_size\n\n\nclass ConvSubSample(nn.Module):\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.conv1 = nn.Conv(features=self.config.conv_subsample_channel,\n                             kernel_size=(3, 3),\n                             strides=(2, 2),\n                             dtype=self.dtype)\n        self.conv2 = nn.Conv(features=self.config.conv_subsample_channel,\n                             kernel_size=(3, 3),\n                             strides=(2, 2),\n                             dtype=self.dtype)\n        self.dense = nn.Dense(features=self.config.hidden_size,\n                              dtype=self.dtype)\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n\n    def __call__(self, x, deterministic: bool = True):\n        x = self.conv1(x)\n        x = nn.relu(x)\n        x = self.conv2(x)\n        x = nn.relu(x)\n        x = x.reshape((x.shape[0], x.shape[1], -1))\n        x = self.dense(x)\n        x = self.dropout(x, deterministic=deterministic)\n        return x\n\n\nclass FFNModule(nn.Module):\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n        self.dense_1 = nn.Dense(self.config.intermediate_size, dtype=self.dtype)\n        self.act = nn.swish\n        self.dropout_1 = 
nn.Dropout(rate=self.config.hidden_dropout_prob)\n        self.dense_2 = nn.Dense(self.config.hidden_size, dtype=self.dtype)\n        self.dropout_2 = nn.Dropout(rate=self.config.hidden_dropout_prob)\n\n    def __call__(self, inputs, deterministic: bool = True):\n        outputs = self.layer_norm(inputs)\n        outputs = self.dense_1(outputs)\n        outputs = self.act(outputs)\n        outputs = self.dropout_1(outputs, deterministic=deterministic)\n        outputs = self.dense_2(outputs)\n        outputs = self.dropout_2(outputs, deterministic=deterministic)\n        return 0.5 * outputs + inputs\n\n\nclass ConvModule(nn.Module):\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    @nn.compact\n    def __call__(self, inputs, deterministic: bool = True, train: bool = True):\n        outputs = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                               dtype=self.dtype)(inputs)\n        B, T, E = outputs.shape\n        outputs = outputs.reshape((B, T, 1, E))\n        outputs = nn.Conv(features=self.config.hidden_size * 2,\n                          kernel_size=(1, 1),\n                          strides=(1, 1),\n                          dtype=self.dtype)(outputs)\n        outputs = nn.glu(outputs)\n        outputs = nn.Conv(features=self.config.hidden_size,\n                          kernel_size=(self.config.conv_kernel_size, 1),\n                          strides=(1, 1),\n                          feature_group_count=self.config.hidden_size,\n                          dtype=self.dtype)(outputs)\n        outputs = nn.BatchNorm(use_running_average=not train,\n                               momentum=0.9,\n                               epsilon=1e-5,\n                               dtype=self.dtype)(outputs)\n        outputs = nn.swish(outputs)\n        outputs = nn.Conv(features=self.config.hidden_size,\n                          kernel_size=(1, 1),\n                          strides=(1, 1),\n                          
dtype=self.dtype)(outputs)\n        outputs = outputs.reshape((B, T, E))\n        outputs = nn.Dropout(rate=self.config.hidden_dropout_prob)(\n            outputs, deterministic=deterministic)\n        return outputs + inputs\n\n\nclass MultiHeadSelfAttentionModule(nn.Module):\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n        self.qvk_combined = nn.Dense(\n            self.config.hidden_size * 3,\n            dtype=self.dtype,\n            kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range),\n        )\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n        self.out_dense = nn.Dense(self.config.hidden_size,\n                                  dtype=self.dtype,\n                                  kernel_init=jax.nn.initializers.normal(\n                                      self.config.initializer_range))\n\n        if self.config.hidden_size % self.config.num_attention_heads != 0:\n            raise ValueError(\n                f\"`hidden_size`: {self.config.hidden_size} has to be a multiple of `num_attention_heads`: {self.config.num_attention_heads}\"\n            )\n\n    def __call__(self,\n                 inputs,\n                 pos_encoding,\n                 attention_mask,\n                 deterministic=True):\n        outputs = self.layer_norm(inputs)\n        outputs = outputs + pos_encoding\n\n        head_dim = self.config.hidden_size // self.config.num_attention_heads\n\n        qvk_combined_states = self.qvk_combined(outputs)\n        qvk_combined_states = qvk_combined_states.reshape(\n            qvk_combined_states.shape[:2] + (-1, 3))\n        query_states, value_states, key_states = jnp.split(qvk_combined_states,\n                                                           3,\n                                
                           axis=3)\n        query_states = query_states.reshape(outputs.shape[:2] +\n                                            (self.config.num_attention_heads,\n                                             head_dim))\n        value_states = value_states.reshape(outputs.shape[:2] +\n                                            (self.config.num_attention_heads,\n                                             head_dim))\n        key_states = key_states.reshape(outputs.shape[:2] +\n                                        (self.config.num_attention_heads,\n                                         head_dim))\n\n        # Convert the boolean attention mask to an attention bias.\n        if attention_mask is not None:\n            # attention mask in the form of attention bias\n            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))\n            attention_bias = lax.select(\n                attention_mask > 0,\n                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),\n                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),\n            )\n        else:\n            attention_bias = None\n\n        dropout_rng = None\n        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:\n            dropout_rng = self.make_rng(\"dropout\")\n\n        attn_weights = nn.attention.dot_product_attention_weights(\n            query_states,\n            key_states,\n            bias=attention_bias,\n            dropout_rng=dropout_rng,\n            dropout_rate=self.config.attention_probs_dropout_prob,\n            broadcast_dropout=True,\n            deterministic=deterministic,\n            dtype=self.dtype,\n            precision=None,\n        )\n\n        attn_output = jnp.einsum(\"...hqk,...khd->...qhd\", attn_weights,\n                                 value_states)\n        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))\n\n        outputs = self.out_dense(attn_output)\n        
outputs = self.dropout(outputs, deterministic=deterministic)\n        return outputs + inputs\n\n\nclass ConformerLayer(nn.Module):\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.ffn_1 = FFNModule(config=self.config, dtype=self.dtype)\n        self.mhsa = MultiHeadSelfAttentionModule(config=self.config,\n                                                 dtype=self.dtype)\n        self.conv = ConvModule(config=self.config, dtype=self.dtype)\n        self.ffn_2 = FFNModule(config=self.config, dtype=self.dtype)\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(\n        self,\n        inputs,\n        pos_encoding,\n        attention_mask,\n        deterministic: bool = True,\n        train: bool = True,\n    ):\n        outputs = self.ffn_1(inputs, deterministic=deterministic)\n        outputs = self.mhsa(outputs,\n                            pos_encoding,\n                            attention_mask,\n                            deterministic=deterministic)\n        outputs = self.conv(outputs, deterministic=deterministic, train=train)\n        outputs = self.ffn_2(outputs, deterministic=deterministic)\n        outputs = self.layer_norm(outputs)\n        return outputs\n\n\nclass ConformerForASRModule(nn.Module):\n    \"\"\"\n    Conformer for automatic speech recognition.\n    \"\"\"\n    config: ConformerConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.conv_subsample = ConvSubSample(config=self.config,\n                                            dtype=self.dtype)\n        self.layers = [\n            ConformerLayer(config=self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n        self.decoder = nn.Dense(self.config.vocab_size, dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_frames,\n        
attention_mask,\n        deterministic: bool = True,\n        train: bool = True,\n    ):\n        # Model\n        hidden_states = self.conv_subsample(input_frames)\n        pos_encoding = jnp.ones(\n            (1, hidden_states.shape[1], hidden_states.shape[2]))\n\n        for layer in self.layers:\n            hidden_states = layer(hidden_states,\n                                  pos_encoding,\n                                  attention_mask,\n                                  deterministic=deterministic,\n                                  train=train)\n\n        logits = self.decoder(hidden_states)\n\n        return logits\n"
  },
  {
    "path": "alpa/model/gpt_model.py",
    "content": "# flake8: noqa\n\"\"\"Model definition of GPT. Modified from bert_model.py. \"\"\"\n# TODO(lmzheng): Test this GPT implementation:\n# https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt2/modeling_flax_gpt2.py\n\nfrom functools import partial\nfrom typing import Callable, Optional, Tuple\n\nimport numpy as np\n\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\n\nfrom alpa.model.bert_model import BertConfig, FlaxBertModule, FlaxMaskedLMOutput\nfrom alpa.model.model_util import TrainState\n\n\nclass FlaxGPTForLMModule(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros\n\n    def setup(self):\n        self.transformers = FlaxBertModule(config=self.config,\n                                           add_pooling_layer=False,\n                                           dtype=self.dtype)\n\n        if self.config.tie_word_embeddings:\n            self.decoder = None\n        else:\n            self.decoder = nn.Dense(self.config.vocab_size,\n                                    dtype=self.dtype,\n                                    use_bias=False)\n        self.decoder_bias = self.param(\"bias\", self.bias_init,\n                                       (self.config.vocab_size,))\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        # Model\n        outputs = self.transformers(\n            input_ids,\n            attention_mask,\n            token_type_ids,\n            position_ids,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        
)\n\n        hidden_states = outputs[0]\n        if self.config.tie_word_embeddings:\n            if self.dtype == jnp.float16:\n                shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16\n            else:\n                shared_embedding = self.transformers.variables[\"params\"][\n                    \"embeddings\"][\"word_embeddings\"][\"embedding\"]\n            assert self.decoder is None\n            logits = hidden_states @ shared_embedding.T\n        else:\n            assert self.decoder is not None\n            logits = self.decoder(hidden_states)\n\n        logits += jnp.asarray(self.decoder_bias, self.dtype)\n\n        # Compute the prediction scores\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return FlaxMaskedLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\ndef test_gpt_lm():\n    batch_size = 64\n    seq_len = 64\n    hidden_size = 128\n    num_attention_heads = 4\n    num_hidden_layers = 2\n    vocab_size = 1024\n\n    @partial(jax.jit, static_argnums=(2,))\n    def train_step(optimizer, batch, apply_func):\n\n        def loss_func(params):\n            rngs = {\"dropout\": batch[\"rng\"]}\n            logits = apply_func(params,\n                                batch[\"input_ids\"],\n                                batch[\"attention_mask\"],\n                                batch[\"token_type_ids\"],\n                                batch[\"position_ids\"],\n                                rngs=rngs)[0]\n            label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n            labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n            loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                            axis=-1)\n            loss = (label_mask * loss).sum() / label_mask.sum()\n            return loss\n\n        grad = 
jax.grad(loss_func)(optimizer.target)\n        new_optimizer = optimizer.apply_gradient(grad)\n        return new_optimizer\n\n    # Init model and optimizer\n    input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n\n    model = FlaxGPTForLMModule(\n        BertConfig(\n            vocab_size=vocab_size,\n            hidden_size=hidden_size,\n            num_attention_heads=num_attention_heads,\n            intermediate_size=hidden_size * 4,\n            num_hidden_layers=num_hidden_layers,\n            type_vocab_size=0,\n        ))\n    rngkey = jax.random.PRNGKey(0)\n    params = model.init(rngkey, input_ids, attention_mask, token_type_ids,\n                        position_ids)\n    optimizer = optim.GradientDescent(1e-2).create(params)\n\n    # JIT compile\n    train_step(\n        optimizer, {\n            \"input_ids\": input_ids,\n            \"attention_mask\": attention_mask,\n            \"token_type_ids\": token_type_ids,\n            \"position_ids\": position_ids,\n            \"labels\": labels,\n            \"rng\": rngkey\n        }, model.apply)\n\n\nif __name__ == \"__main__\":\n    test_gpt_lm()\n"
  },
  {
    "path": "alpa/model/model_util.py",
    "content": "# flake8: noqa\nfrom collections import OrderedDict\nfrom dataclasses import fields\nimport functools\nfrom typing import Any, Callable, Optional, Tuple, Optional, Union, Sequence\n\nfrom alpa.api import value_and_grad\nimport flax\nfrom flax.training import train_state, dynamic_scale as dynamic_scale_lib\nfrom flax.training.dynamic_scale import DynamicScaleResult\nfrom flax import struct\nimport numpy as np\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\nimport jaxlib.xla_extension as jax_xla\nimport optax\n\nArray = Any\n\n\ndef is_tensor(x):\n    \"\"\"\n    Tests if ``x`` is a :obj:`torch.Tensor`, :obj:`tf.Tensor`, obj:`jaxlib.xla_extension.DeviceArray` or\n    :obj:`np.ndarray`.\n    \"\"\"\n    #if is_torch_fx_proxy(x):\n    #    return True\n    #if is_torch_available():\n    #    import torch\n\n    #    if isinstance(x, torch.Tensor):\n    #        return True\n    #if is_tf_available():\n    #    import tensorflow as tf\n\n    #    if isinstance(x, tf.Tensor):\n    #        return True\n\n    #if is_flax_available():\n    if True:\n        import jaxlib.xla_extension as jax_xla\n        from jax.core import Tracer\n\n        if isinstance(x, (jax_xla.DeviceArray, Tracer)):\n            return True\n\n    return isinstance(x, np.ndarray)\n\n\nclass ModelOutput(OrderedDict):\n    \"\"\"\n    Base class for all model outputs as dataclass. Has a ``__getitem__`` that allows indexing by integer or slice (like\n    a tuple) or strings (like a dictionary) that will ignore the ``None`` attributes. Otherwise behaves like a regular\n    python dictionary.\n    .. warning::\n        You can't unpack a :obj:`ModelOutput` directly. 
Use the :meth:`~transformers.file_utils.ModelOutput.to_tuple`\n        method to convert it to a tuple before.\n    \"\"\"\n\n    def __post_init__(self):\n        class_fields = fields(self)\n\n        # Safety and consistency checks\n        assert len(class_fields), f\"{self.__class__.__name__} has no fields.\"\n        assert all(\n            field.default is None for field in class_fields[1:]\n        ), f\"{self.__class__.__name__} should not have more than one required field.\"\n\n        first_field = getattr(self, class_fields[0].name)\n        other_fields_are_none = all(\n            getattr(self, field.name) is None for field in class_fields[1:])\n\n        if other_fields_are_none and not is_tensor(first_field):\n            try:\n                iterator = iter(first_field)\n                first_field_iterator = True\n            except TypeError:\n                first_field_iterator = False\n\n            # if we provided an iterator as first field and the iterator is a (key, value) iterator\n            # set the associated fields\n            if first_field_iterator:\n                for element in iterator:\n                    if (not isinstance(element, (list, tuple)) or\n                            not len(element) == 2 or\n                            not isinstance(element[0], str)):\n                        break\n                    setattr(self, element[0], element[1])\n                    if element[1] is not None:\n                        self[element[0]] = element[1]\n            elif first_field is not None:\n                self[class_fields[0].name] = first_field\n        else:\n            for field in class_fields:\n                v = getattr(self, field.name)\n                if v is not None:\n                    self[field.name] = v\n\n    def __delitem__(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def 
setdefault(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def pop(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``pop`` on a {self.__class__.__name__} instance.\")\n\n    def update(self, *args, **kwargs):\n        raise Exception(\n            f\"You cannot use ``update`` on a {self.__class__.__name__} instance.\"\n        )\n\n    def __getitem__(self, k):\n        if isinstance(k, str):\n            inner_dict = {k: v for (k, v) in self.items()}\n            return inner_dict[k]\n        else:\n            return self.to_tuple()[k]\n\n    def __setattr__(self, name, value):\n        if name in self.keys() and value is not None:\n            # Don't call self.__setitem__ to avoid recursion errors\n            super().__setitem__(name, value)\n        super().__setattr__(name, value)\n\n    def __setitem__(self, key, value):\n        # Will raise a KeyException if needed\n        super().__setitem__(key, value)\n        # Don't call self.__setattr__ to avoid recursion errors\n        super().__setattr__(key, value)\n\n    def to_tuple(self) -> Tuple[Any]:\n        \"\"\"\n        Convert self to a tuple containing all the attributes/keys that are not ``None``.\n        \"\"\"\n        return tuple(self[k] for k in self.keys())\n\n\n@flax.struct.dataclass\nclass FlaxBaseModelOutput(ModelOutput):\n    \"\"\"\n    Base class for model's outputs, with potential hidden states and attentions.\n    Args:\n        last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):\n            Sequence of hidden-states at the output of the last layer of the model.\n        hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for 
the output of the embeddings + one for the output of each\n            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n        attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,\n            sequence_length, sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    last_hidden_state: jax_xla.DeviceArray = None\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@flax.struct.dataclass\nclass FlaxBaseModelOutputWithPooling(ModelOutput):\n    \"\"\"\n    Base class for model's outputs that also contains a pooling of the last hidden states.\n    Args:\n        last_hidden_state (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, hidden_size)`):\n            Sequence of hidden-states at the output of the last layer of the model.\n        pooler_output (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, hidden_size)`):\n            Last layer hidden-state of the first token of the sequence (classification token) further processed by a\n            Linear layer and a Tanh activation function. 
The Linear layer weights are trained from the next sentence\n            prediction (classification) objective during pretraining.\n        hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each\n            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n        attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,\n            sequence_length, sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    last_hidden_state: jax_xla.DeviceArray = None\n    pooler_output: jax_xla.DeviceArray = None\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@flax.struct.dataclass\nclass FlaxBertForPreTrainingOutput(ModelOutput):\n    \"\"\"\n    Output type of :class:`~transformers.BertForPreTraining`.\n    Args:\n        prediction_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):\n            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n        seq_relationship_logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, 2)`):\n            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation\n            before SoftMax).\n        hidden_states 
(:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each\n            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n        attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,\n            sequence_length, sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    prediction_logits: jax_xla.DeviceArray = None\n    seq_relationship_logits: jax_xla.DeviceArray = None\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@flax.struct.dataclass\nclass FlaxMaskedLMOutput(ModelOutput):\n    \"\"\"\n    Base class for masked language models outputs.\n    Args:\n        logits (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):\n            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n        hidden_states (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for the output of the embeddings + one for the output of each\n            layer) of shape :obj:`(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n    
    attentions (:obj:`tuple(jax_xla.DeviceArray)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):\n            Tuple of :obj:`jax_xla.DeviceArray` (one for each layer) of shape :obj:`(batch_size, num_heads,\n            sequence_length, sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    logits: jax_xla.DeviceArray = None\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@flax.struct.dataclass\nclass FlaxSequenceClassifierOutput(ModelOutput):\n    \"\"\"\n    Base class for outputs of sentence classification models.\n    Args:\n        logits (`jnp.ndarray` of shape `(batch_size, config.num_labels)`):\n            Classification (or regression if config.num_labels==1) scores (before SoftMax).\n        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):\n            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape\n            `(batch_size, sequence_length, hidden_size)`.\n            Hidden-states of the model at the output of each layer plus the initial embedding outputs.\n        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):\n            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n            sequence_length)`.\n            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n            heads.\n    \"\"\"\n\n    logits: jnp.ndarray = None\n    hidden_states: Optional[Tuple[jnp.ndarray]] = None\n    attentions: Optional[Tuple[jnp.ndarray]] = None\n\n\ndef 
softmax_cross_entropy(logits, labels):\n    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)\n\n\nclass TrainState(train_state.TrainState):\n    \"\"\"This is an extended version of flax.training.train_state.TrainState.\n\n    This class wraps the logic for creating the master weight copy in\n    mixed precision training.\n    \"\"\"\n    master_copy: flax.core.FrozenDict[str, Any]\n    dynamic_scale: Optional[dynamic_scale_lib.DynamicScale]\n\n    def apply_gradients(self, *, grads, **kwargs):\n        \"\"\"Updates `step`, `params`, `opt_state` and `**kwargs` in return value.\n        Note that internally this function calls `.tx.update()` followed by a call\n        to `optax.apply_updates()` to update `params` and `opt_state`.\n        Args:\n          grads: Gradients that have the same pytree structure as `.params`.\n          **kwargs: Additional dataclass attributes that should be `.replace()`-ed.\n        Returns:\n          An updated instance of `self` with `step` incremented by one, `params`\n          and `opt_state` updated by applying `grads`, and additional attributes\n          replaced as specified by `kwargs`.\n        \"\"\"\n        if self.master_copy is None:\n            master_params = self.params\n        else:\n            master_params = self.master_copy\n\n        updates, new_opt_state = self.tx.update(grads, self.opt_state,\n                                                master_params)\n        new_master_params = optax.apply_updates(master_params, updates)\n\n        if self.master_copy is None:\n            new_master_copy = None\n            new_params = new_master_params\n        else:\n            new_master_copy = new_master_params\n            new_params = jax.tree_util.tree_map(\n                lambda x: jnp.asarray(x, dtype=jnp.float16), new_master_params)\n\n            # A hack to make the donation works perfectly in gradient accumulation:\n            # We need the accumulate_grad to take the old 
params as input.\n            new_params_flat, tree = jax.tree_util.tree_flatten(new_params)\n            old_params_flat, _ = jax.tree_util.tree_flatten(self.params)\n            new_params_flat = [\n                x + 0.0 * y for x, y in zip(new_params_flat, old_params_flat)\n            ]\n            new_params = jax.tree_util.tree_unflatten(tree, new_params_flat)\n\n        return self.replace(\n            step=self.step + 1,\n            params=new_params,\n            master_copy=new_master_copy,\n            opt_state=new_opt_state,\n            **kwargs,\n        )\n\n    @classmethod\n    def create(cls, *, apply_fn, params, tx, use_master_copy=False, **kwargs):\n        \"\"\"Creates a new instance with `step=0` and initialized `opt_state`.\"\"\"\n        if use_master_copy:\n            master_copy = jax.tree_util.tree_map(\n                lambda x: jnp.asarray(x, dtype=jnp.float32), params)\n            params = jax.tree_util.tree_map(\n                lambda x: jnp.asarray(x, dtype=jnp.float16), params)\n            opt_state = tx.init(master_copy)\n        else:\n            master_copy = None\n            opt_state = tx.init(params)\n\n        return cls(\n            step=np.array(0, dtype=np.int32),\n            apply_fn=apply_fn,\n            params=params,\n            master_copy=master_copy,\n            tx=tx,\n            opt_state=opt_state,\n            **kwargs,\n        )\n\n    @classmethod\n    def create_aval(cls,\n                    *,\n                    apply_fn,\n                    params,\n                    tx,\n                    use_master_copy=False,\n                    **kwargs):\n        \"\"\"Creates a new instance with `step=0` and initialized `opt_state`.\"\"\"\n        opt_state = jax.eval_shape(tx.init, params)\n\n        if use_master_copy:\n            master_copy = params\n            params = jax.eval_shape(\n                lambda p: jax.tree_util.tree_map(\n                    lambda x: jnp.asarray(x, 
dtype=jnp.float16), p), params)\n        else:\n            master_copy = None\n\n        return cls(\n            step=np.array(0, dtype=np.int32),\n            apply_fn=apply_fn,\n            params=params,\n            master_copy=master_copy,\n            tx=tx,\n            opt_state=opt_state,\n            **kwargs,\n        )\n\n\nclass DynamicScale(struct.PyTreeNode):\n    \"\"\"This is the same as flax.optim.DynamicScale, except that\n  jax.value_and_grad is replaced by alpa.value_and_grad.\n\n  Dynamic loss scaling for mixed precision gradients.\n\n  For many models gradient computations in float16 will result in numerical\n  issues because small/large gradients are flushed to zero/infinity.\n  Dynamic loss scaling is an algorithm that aims to find the largest scalar\n  multiple for which the gradient does not overflow. This way the risk of\n  underflow is minimized.\n\n  The `value_and_grad` method mimics `jax.value_and_grad`. Besides the loss\n  and gradients it also outputs an updated `DynamicScale` instance with the\n  current loss scale factor. 
This method also returns a boolean value indicating\n  whether the gradients are finite.\n\n  Example::\n\n    def loss_fn(p):\n      return jnp.asarray(p, jnp.float16) ** 2\n    p = jnp.array(1., jnp.float32)\n\n    dyn_scale = optim.DynamicScale(growth_interval=10)\n    compute_grad = jax.jit(lambda ds, p: ds.value_and_grad(loss_fn)(p))\n    for _ in range(100):\n      dyn_scale, is_fin, loss, grad = compute_grad(dyn_scale, p)\n      p += jnp.where(is_fin, 0.01 * grad, 0.)\n      print(loss)\n\n  Jax currently cannot execute conditionals efficiently on GPUs therefore we\n  selectively ignore the gradient update using `jax.numpy.where` in case of\n  non-finite gradients.\n\n  Attributes:\n    growth_factor: how much to grow the scalar after a period of finite\n      gradients (default: 2.).\n    backoff_factor: how much to shrink the scalar after a non-finite gradient\n      (default: 0.5).\n    growth_interval: after how many steps of finite gradients the scale should\n      be increased (default: 2000).\n    fin_steps: indicates how many gradient steps in a row have been finite.\n    scale: the current scale by which the loss is multiplied.\n  \"\"\"\n    growth_factor: float = struct.field(pytree_node=False, default=2.0)\n    backoff_factor: float = struct.field(pytree_node=False, default=0.5)\n    growth_interval: int = struct.field(pytree_node=False, default=2000)\n    fin_steps: Array = 0\n    scale: Array = 65536.0\n\n    def value_and_grad(\n        self,\n        fun: Callable[..., Any],\n        argnums: Union[int, Sequence[int]] = 0,\n        has_aux: bool = False,\n        axis_name: Optional[str] = None,\n    ) -> Callable[..., DynamicScaleResult]:\n        \"\"\"Wrapper around `jax.value_and_grad`.\n\n    Args:\n      fun: Function to be differentiated. 
Its arguments at positions specified\n        by ``argnums`` should be arrays, scalars, or standard Python containers.\n        It should return a scalar (which includes arrays with shape ``()``\n        but not arrays with shape ``(1,)`` etc.)\n      argnums: Optional, integer or sequence of integers. Specifies which\n        positional argument(s) to differentiate with respect to (default 0).\n      has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where\n        the first element is considered the output of the mathematical function\n        to be differentiated and the second element is auxiliary data.\n        Default False.\n      axis_name: If an axis is given the gradients will be averaged across\n        replicas (default: None).\n    Returns:\n      A function that takes the same arguments as `fun` and\n      returns a DynamicScaleResult\n    \"\"\"\n\n        @functools.wraps(fun)\n        def loss_wrapper(*args):\n            aux = fun(*args)\n            if has_aux:\n                return (self.scale * aux[0], aux[1])\n            else:\n                return self.scale * aux\n\n        grad_fn = value_and_grad(loss_wrapper, argnums, has_aux)\n\n        def grad_fn_wrapper(*args):\n            aux, grad = grad_fn(*args)\n            aux = (aux[0] / self.scale, aux[1]) if has_aux else aux / self.scale\n\n            grad = jax.tree_util.tree_map(\n                lambda g: jnp.asarray(g, jnp.float32) / self.scale, grad)\n            if axis_name is not None:\n                grad = lax.pmean(grad, axis_name)\n\n            finite = jnp.array(True)\n            for g in jax.tree_util.tree_leaves(grad):\n                finite &= jnp.all(lax.is_finite(g))\n\n            grow = self.fin_steps == self.growth_interval\n            fin_scale = jnp.where(grow & finite,\n                                  self.scale * self.growth_factor, self.scale)\n            inf_scale = self.scale * self.backoff_factor\n            new_scale = 
jnp.where(finite, fin_scale, inf_scale)\n            new_fin_steps = jnp.where(grow | (~finite), 0, self.fin_steps + 1)\n\n            new_self = self.replace(fin_steps=new_fin_steps, scale=new_scale)\n            return DynamicScaleResult(new_self, finite, aux, grad)\n\n        return grad_fn_wrapper\n"
  },
  {
    "path": "alpa/model/moe.py",
    "content": "# flake8: noqa\n\"\"\"Model definition of Mixture of Expert model.\"\"\"\nfrom dataclasses import dataclass\nfrom functools import partial\nfrom typing import Callable, Optional, Tuple\n\nimport numpy as np\n\nimport flax\nfrom flax import linen as nn\nfrom flax.training import train_state\nfrom flax.linen.attention import dot_product_attention_weights\nfrom flax.linen.initializers import lecun_normal\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\nfrom jax.nn import one_hot\n\nfrom alpa.model.bert_model import (FlaxBaseModelOutput,\n                                   FlaxBaseModelOutputWithPooling,\n                                   FlaxBertAttention, FlaxBertEmbeddings,\n                                   FlaxBertIntermediate, FlaxBertLayer,\n                                   FlaxBertOutput, FlaxMaskedLMOutput)\nfrom alpa.model.model_util import TrainState\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\n\n\nclass MoEConfig:\n\n    def __init__(\n            self,\n            vocab_size=30522,\n            hidden_size=768,\n            num_hidden_layers=12,\n            num_attention_heads=12,\n            intermediate_size=3072,\n            hidden_act=\"gelu\",\n            hidden_dropout_prob=0.1,\n            attention_probs_dropout_prob=0.1,\n            max_position_embeddings=512,\n            type_vocab_size=0,\n            initializer_range=0.02,\n            layer_norm_eps=1e-12,\n            gradient_checkpointing=False,\n            position_embedding_type=\"absolute\",\n            use_cache=True,\n            tie_word_embeddings=True,\n            expert_group_size=8192,  # S in the paper\n            expert_number=128,  # E in the paper\n            add_manual_pipeline_markers=False,\n            pipeline_mp_size=0,\n            **kwargs):\n        self.vocab_size = vocab_size\n        self.hidden_size = hidden_size\n        self.num_hidden_layers = num_hidden_layers\n        
self.num_attention_heads = num_attention_heads\n        self.hidden_act = hidden_act\n        self.intermediate_size = intermediate_size\n        self.hidden_dropout_prob = hidden_dropout_prob\n        self.attention_probs_dropout_prob = attention_probs_dropout_prob\n        self.max_position_embeddings = max_position_embeddings\n        self.type_vocab_size = type_vocab_size\n        self.initializer_range = initializer_range\n        self.layer_norm_eps = layer_norm_eps\n        self.gradient_checkpointing = gradient_checkpointing\n        self.position_embedding_type = position_embedding_type\n        self.use_cache = use_cache\n        self.expert_group_size = expert_group_size\n        self.expert_number = expert_number\n        self.tie_word_embeddings = tie_word_embeddings\n        self.add_manual_pipeline_markers = add_manual_pipeline_markers\n        self.pipeline_mp_size = pipeline_mp_size\n\n\ndef top2_gating_dummy(gates):  # [GSE] -> [GSEC, GSEC]\n    \"\"\"A temporary dummy implementation.\"\"\"\n    G, S, E = gates.shape\n    C = 2 * S // E\n    gates = jnp.reshape(gates, (G, S, E, 1))\n    combined_weights = jnp.broadcast_to(gates, (G, S, E, C))\n    dispatch_mask = combined_weights\n    return combined_weights, dispatch_mask\n\n\ndef top2_gating(gates):  # GSE -> (GSEC, GSEC)\n    \"\"\"Modified from https://github.com/tensorflow/lingvo/blob/\n    b885b91d4b5361c971a998b810fc58f83baa625f/lingvo/core/gshard_layers.py#L1787\n\n    # TODO(lmzheng): add the auxiliary loss. 
add 'random' policy for the second expert.\n    \"\"\"\n    G, S, E = gates.shape\n    C = 2 * S // E\n\n    mask_dtype = jnp.int32\n\n    index_1 = jnp.argmax(gates, axis=-1)  # GS\n    mask_1 = one_hot(index_1, E, dtype=mask_dtype)  # GSE\n    gate_1 = jnp.einsum(\"GSE,GSE->GS\", gates, mask_1)  # GS\n\n    gates_without_top_1 = gates * (1 - mask_1)\n\n    index_2 = jnp.argmax(gates_without_top_1, axis=-1)  # GSE\n    mask_2 = one_hot(index_2, E, dtype=mask_dtype)\n    gate_2 = jnp.einsum(\"GSE,GSE->GS\", gates_without_top_1, mask_2)\n\n    pos_1 = jnp.cumsum(mask_1, axis=-2) - mask_1\n    mask_1 *= pos_1 < C\n    pos_1 = jnp.einsum(\"GSE,GSE->GS\", pos_1, mask_1)\n\n    mask_1_count = jnp.sum(mask_1, axis=-2)\n    mask_1_flat = jnp.sum(mask_1, axis=-1)\n\n    pos_2 = (jnp.cumsum(mask_2, axis=-2) - mask_2) + jnp.expand_dims(\n        mask_1_count, -2)\n    mask_2 *= pos_2 < C\n    pos_2 = jnp.einsum(\"GSE,GSE->GS\", pos_2, mask_2)\n\n    mask_2_flat = jnp.sum(mask_2, axis=-1)\n\n    gate_1 *= mask_1_flat\n    gate_2 *= mask_2_flat\n\n    denom = gate_1 + gate_2\n    denom = jnp.where(denom > 0, denom, jnp.ones_like(denom))\n    gate_1 /= denom\n    gate_2 /= denom\n\n    a = jnp.expand_dims(gate_1 * mask_1_flat, -1) * one_hot(\n        index_1, E, dtype=gates.dtype)\n    b = one_hot(pos_1, C, dtype=gates.dtype)\n    first_part_of_combine_tensor = jnp.einsum(\"GSE,GSC->GSEC\", a, b)\n\n    a = jnp.expand_dims(gate_2 * mask_2_flat, -1) * one_hot(\n        index_2, E, dtype=gates.dtype)\n    b = one_hot(pos_2, C, dtype=gates.dtype)\n    second_part_of_combine_tensor = jnp.einsum(\"GSE,GSC->GSEC\", a, b)\n\n    combined_tensor = first_part_of_combine_tensor + second_part_of_combine_tensor\n    dispatch_tensor = combined_tensor.astype(jnp.bool_)\n\n    return combined_tensor, dispatch_tensor\n\n\nclass FlaxPositionWiseMoELayer(nn.Module):\n    config: MoEConfig\n    kernel_init: Callable[..., np.ndarray] = lecun_normal()\n    dtype: jnp.dtype = jnp.float32  # the 
dtype of the computation\n\n    @nn.compact\n    def __call__(self, inputs):\n        S = self.config.expert_group_size\n        M = self.config.hidden_size\n        H = self.config.intermediate_size\n        E = self.config.expert_number\n\n        wg = self.param(\"wg\", self.kernel_init, (\n            M,\n            E,\n        ))\n        wi = self.param(\"wi\", self.kernel_init, (\n            E,\n            M,\n            H,\n        ))\n        wo = self.param(\"wo\", self.kernel_init, (\n            E,\n            H,\n            M,\n        ))\n\n        inputs = jnp.asarray(inputs, self.dtype)\n        wg = jnp.asarray(wg, self.dtype)\n        wi = jnp.asarray(wi, self.dtype)\n        wo = jnp.asarray(wo, self.dtype)\n\n        reshaped_inputs = jnp.reshape(inputs, (-1, S, M))\n        gates = jax.nn.softmax(jnp.einsum(\"GSM,ME->GSE\", reshaped_inputs, wg))\n        combined_weights, dispatch_mask = top2_gating(gates)\n        dispatched_expert_inputs = jnp.einsum(\"GSEC,GSM->EGCM\", dispatch_mask,\n                                              reshaped_inputs)\n        h = jnp.einsum(\"EGCM,EMH->EGCH\", dispatched_expert_inputs, wi)\n        h = nn.relu(h)\n        expert_outputs = jnp.einsum(\"EGCH,EHM->GECM\", h, wo)\n        outputs = jnp.einsum(\"GSEC,GECM->GSM\", combined_weights, expert_outputs)\n        outputs = jnp.reshape(outputs, inputs.shape)\n        return outputs\n\n\nclass FlaxMoELayer(nn.Module):\n    config: MoEConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n        self.attention = FlaxBertAttention(self.config, dtype=self.dtype)\n        self.moe = FlaxPositionWiseMoELayer(self.config, dtype=self.dtype)\n        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)\n        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                      dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                
 attention_mask,\n                 deterministic: bool = True,\n                 output_attentions: bool = False):\n\n        if not isinstance(deterministic, bool):\n            # A temporary hack to work around the bug in flax.nn.remat\n            # Using `nn.remat(concrete=True)` works for regular use cases\n            # (e.g., train_step, init) but does not work for init_dummy.\n            # So we still need this hack.\n            deterministic = True\n            output_attentions = True\n\n        attention_outputs = self.attention(hidden_states,\n                                           attention_mask,\n                                           deterministic=deterministic,\n                                           output_attentions=output_attentions)\n        attention_output = attention_outputs[0]\n\n        hidden_states = self.moe(attention_output)\n        hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n        hidden_states = self.LayerNorm(hidden_states + attention_output)\n\n        outputs = (hidden_states,)\n\n        if output_attentions:\n            outputs += (attention_outputs[1],)\n        return outputs\n\n\nclass FlaxMoELayerCollection(nn.Module):\n    config: MoEConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n\n    def setup(self):\n\n        if self.config.gradient_checkpointing:\n            trans_func = partial(nn.remat, concrete=True)\n        else:\n            trans_func = lambda x: x\n\n        assert self.config.num_hidden_layers % 2 == 0\n        layers = []\n        for i in range(self.config.num_hidden_layers):\n            if i % 2 == 0:\n                layers.append(\n                    trans_func(FlaxMoELayer)(self.config,\n                                             name=str(i),\n                                             dtype=self.dtype))\n            else:\n                layers.append(\n                    trans_func(FlaxBertLayer)(self.config,\n          
                                    name=str(i),\n                                              dtype=self.dtype))\n        self.layers = layers\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n\n        for i, layer in enumerate(self.layers):\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = layer(hidden_states,\n                                  attention_mask,\n                                  deterministic=deterministic,\n                                  output_attentions=output_attentions)\n\n            hidden_states = layer_outputs[0]\n\n            if output_attentions:\n                all_attentions += (layer_outputs[1],)\n\n            if self.config.add_manual_pipeline_markers:\n                layers_per_stage = self.config.num_hidden_layers // self.config.pipeline_mp_size\n                assert self.config.num_hidden_layers % self.config.pipeline_mp_size == 0\n                if i % layers_per_stage == layers_per_stage - 1 and i != len(\n                        self.layers) - 1:\n                    mark_pipeline_boundary()\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return FlaxBaseModelOutput(last_hidden_state=hidden_states,\n                                   hidden_states=all_hidden_states,\n                                   attentions=all_attentions)\n\n\nclass FlaxMoEEncoder(nn.Module):\n    config: MoEConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of 
the computation\n\n    def setup(self):\n        self.layer = FlaxMoELayerCollection(self.config, dtype=self.dtype)\n\n    def __call__(\n        self,\n        hidden_states,\n        attention_mask,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        return self.layer(\n            hidden_states,\n            attention_mask,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n\nclass FlaxMoEModule(nn.Module):\n    config: MoEConfig\n    dtype: jnp.dtype = jnp.float32  # the dtype of the computation\n    add_pooling_layer: bool = True\n\n    def setup(self):\n        self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype)\n        self.encoder = FlaxMoEEncoder(self.config, dtype=self.dtype)\n        if self.add_pooling_layer:\n            self.pooler = FlaxBertPooler(self.config, dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        hidden_states = self.embeddings(input_ids,\n                                        token_type_ids,\n                                        position_ids,\n                                        attention_mask,\n                                        deterministic=deterministic)\n        outputs = self.encoder(\n            hidden_states,\n            attention_mask,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n        hidden_states = 
outputs[0]\n        pooled = self.pooler(hidden_states) if self.add_pooling_layer else None\n\n        if not return_dict:\n            # if pooled is None, don't return it\n            if pooled is None:\n                return (hidden_states,) + outputs[1:]\n            return (hidden_states, pooled) + outputs[1:]\n\n        return FlaxBaseModelOutputWithPooling(\n            last_hidden_state=hidden_states,\n            pooler_output=pooled,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n\n\nclass FlaxMoEForLMModule(nn.Module):\n    config: MoEConfig\n    dtype: jnp.dtype = jnp.float32\n    bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros\n\n    def setup(self):\n        self.transformers = FlaxMoEModule(config=self.config,\n                                          add_pooling_layer=False,\n                                          dtype=self.dtype)\n\n        if self.config.tie_word_embeddings:\n            self.decoder = None\n        else:\n            self.decoder = nn.Dense(self.config.vocab_size,\n                                    dtype=self.dtype,\n                                    use_bias=False)\n        self.decoder_bias = self.param(\"bias\", self.bias_init,\n                                       (self.config.vocab_size,))\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask,\n        token_type_ids,\n        position_ids,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n    ):\n        # Model\n        outputs = self.transformers(\n            input_ids,\n            attention_mask,\n            token_type_ids,\n            position_ids,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n        )\n\n  
      hidden_states = outputs[0]\n        if self.config.tie_word_embeddings:\n            shared_embedding = self.transformers.variables[\"params\"][\n                \"embeddings\"][\"word_embeddings\"][\"embedding\"]\n            assert self.decoder is None\n            logits = hidden_states @ shared_embedding.T\n        else:\n            assert self.decoder is not None\n            logits = self.decoder(hidden_states)\n\n        logits += jnp.asarray(self.decoder_bias, self.dtype)\n\n        # Compute the prediction scores\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return FlaxMaskedLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n        )\n"
  },
  {
    "path": "alpa/model/unet_2d.py",
    "content": "\"\"\"\nThis file is modified from multiple files in\nhttps://github.com/huggingface/diffusers/blob/main/src/diffusers/models\n\"\"\"\n\n# Copyright 2022 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\nimport math\nfrom typing import Tuple, Union\nimport flax\nimport flax.linen as nn\nimport jax\nfrom jax.experimental.maps import FrozenDict\nimport jax.numpy as jnp\n\nfrom alpa import mark_pipeline_boundary\nfrom alpa.model.bert_model import BertConfig\nfrom alpa.model.model_util import ModelOutput\n\n\n# FIXME: not from bert config\nclass UNet2DConfig(BertConfig):\n\n    def __init__(self,\n                 *,\n                 sample_size: int = 32,\n                 in_channels: int = 4,\n                 out_channels: int = 4,\n                 layers_per_block: int = 2,\n                 freq_shift: int = 0,\n                 num_groups: int = 4,\n                 **kwargs):\n        super().__init__(**kwargs)\n        self.sample_size = sample_size\n        self.in_channels = in_channels,\n        self.out_channels = out_channels\n        self.layers_per_block = layers_per_block\n        self.freq_shift = freq_shift\n        # Group Norm factor\n        self.num_groups = num_groups\n\n\n@flax.struct.dataclass\nclass FlaxUNet2DConditionOutput(ModelOutput):\n    \"\"\"\n    Args:\n        sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`):\n            Hidden states conditioned on `encoder_hidden_states` 
input. Output of last layer of model.\n    \"\"\"\n\n    sample: jnp.ndarray\n\n\n##### Embeddings - Do not add pipeline marker at this level\ndef get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):\n    \"\"\"\n    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.\n    :param timesteps: a 1-D tensor of N indices, one per batch element.\n                      These may be fractional.\n    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the\n    embeddings. :return: an [N x dim] tensor of positional embeddings.\n    \"\"\"\n    half_dim = embedding_dim // 2\n    emb = math.log(10000) / (half_dim - freq_shift)\n    emb = jnp.exp(jnp.arange(half_dim) * -emb)\n    emb = timesteps[:, None] * emb[None, :]\n    emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)\n    return emb\n\n\nclass FlaxTimestepEmbedding(nn.Module):\n    r\"\"\"\n    Time step Embedding Module. 
Learns embeddings for input time steps.\n    Args:\n        time_embed_dim (`int`, *optional*, defaults to `32`):\n                Time step embedding dimension\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n                Parameters `dtype`\n    \"\"\"\n    time_embed_dim: int = 32\n    dtype: jnp.dtype = jnp.float32\n\n    @nn.compact\n    def __call__(self, temb):\n        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype,\n                        name=\"linear_1\")(temb)\n        temb = nn.silu(temb)\n        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype,\n                        name=\"linear_2\")(temb)\n        return temb\n\n\nclass FlaxTimesteps(nn.Module):\n    r\"\"\"\n    Wrapper Module for sinusoidal Time step Embeddings as described in https://arxiv.org/abs/2006.11239\n    Args:\n        dim (`int`, *optional*, defaults to `32`):\n                Time step embedding dimension\n    \"\"\"\n    dim: int = 32\n    freq_shift: float = 1\n\n    @nn.compact\n    def __call__(self, timesteps):\n        return get_sinusoidal_embeddings(timesteps,\n                                         self.dim,\n                                         freq_shift=self.freq_shift)\n\n\n##### ResNetBlocks - Do not add pipeline marker at this level\nclass FlaxUpsample2D(nn.Module):\n    out_channels: int\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.conv = nn.Conv(\n            self.out_channels,\n            kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=((1, 1), (1, 1)),\n            dtype=self.dtype,\n        )\n\n    def __call__(self, hidden_states):\n        batch, height, width, channels = hidden_states.shape\n        hidden_states = jax.image.resize(\n            hidden_states,\n            shape=(batch, height * 2, width * 2, channels),\n            method=\"nearest\",\n        )\n        hidden_states = self.conv(hidden_states)\n        return hidden_states\n\n\nclass 
FlaxDownsample2D(nn.Module):\n    out_channels: int\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.conv = nn.Conv(\n            self.out_channels,\n            kernel_size=(3, 3),\n            strides=(2, 2),\n            padding=((1, 1), (1, 1)),  # padding=\"VALID\",\n            dtype=self.dtype,\n        )\n\n    def __call__(self, hidden_states):\n        # pad = ((0, 0), (0, 1), (0, 1), (0, 0))  # pad height and width dim\n        # hidden_states = jnp.pad(hidden_states, pad_width=pad)\n        hidden_states = self.conv(hidden_states)\n        return hidden_states\n\n\nclass FlaxResnetBlock2D(nn.Module):\n    in_channels: int\n    config: UNet2DConfig\n    out_channels: int = None\n    use_nin_shortcut: bool = None\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        out_channels = (self.in_channels\n                        if self.out_channels is None else self.out_channels)\n\n        self.norm1 = nn.GroupNorm(num_groups=self.config.num_groups,\n                                  epsilon=1e-5)\n        self.conv1 = nn.Conv(\n            out_channels,\n            kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=((1, 1), (1, 1)),\n            dtype=self.dtype,\n        )\n\n        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)\n\n        self.norm2 = nn.GroupNorm(num_groups=self.config.num_groups,\n                                  epsilon=1e-5)\n        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)\n        self.conv2 = nn.Conv(\n            out_channels,\n            kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=((1, 1), (1, 1)),\n            dtype=self.dtype,\n        )\n\n        use_nin_shortcut = (self.in_channels != out_channels\n                            if self.use_nin_shortcut is None else\n                            self.use_nin_shortcut)\n\n        self.conv_shortcut = None\n        if use_nin_shortcut:\n            
self.conv_shortcut = nn.Conv(\n                out_channels,\n                kernel_size=(1, 1),\n                strides=(1, 1),\n                padding=\"VALID\",\n                dtype=self.dtype,\n            )\n\n    def __call__(self, hidden_states, temb, deterministic=True):\n        residual = hidden_states\n        hidden_states = self.norm1(hidden_states)\n        hidden_states = nn.swish(hidden_states)\n        hidden_states = self.conv1(hidden_states)\n\n        temb = self.time_emb_proj(nn.swish(temb))\n        temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)\n        hidden_states = hidden_states + temb\n\n        hidden_states = self.norm2(hidden_states)\n        hidden_states = nn.swish(hidden_states)\n        hidden_states = self.dropout(hidden_states, deterministic)\n        hidden_states = self.conv2(hidden_states)\n\n        if self.conv_shortcut is not None:\n            residual = self.conv_shortcut(residual)\n\n        return hidden_states + residual\n\n\n##### Attentions - Do not add pipeline marker at this level\nclass FlaxAttentionBlock(nn.Module):\n    r\"\"\"\n    A Flax multi-head attention module as described in: https://arxiv.org/abs/1706.03762\n    Parameters:\n        query_dim (:obj:`int`):\n            Input hidden states dimension\n        heads (:obj:`int`, *optional*, defaults to 8):\n            Number of heads\n        dim_head (:obj:`int`, *optional*, defaults to 64):\n            Hidden states dimension inside each head\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    query_dim: int\n    heads: int = 8\n    dim_head: int = 64\n    dropout: float = 0.0\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        inner_dim = self.dim_head * self.heads\n        self.scale = self.dim_head**-0.5\n\n        # Weights were exported with old names {to_q, to_k, 
to_v, to_out}\n        self.query = nn.Dense(inner_dim,\n                              use_bias=False,\n                              dtype=self.dtype,\n                              name=\"to_q\")\n        self.key = nn.Dense(inner_dim,\n                            use_bias=False,\n                            dtype=self.dtype,\n                            name=\"to_k\")\n        self.value = nn.Dense(inner_dim,\n                              use_bias=False,\n                              dtype=self.dtype,\n                              name=\"to_v\")\n\n        self.proj_attn = nn.Dense(self.query_dim,\n                                  dtype=self.dtype,\n                                  name=\"to_out_0\")\n\n    def reshape_heads_to_batch_dim(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size, seq_len, head_size,\n                                dim // head_size)\n        tensor = jnp.transpose(tensor, (0, 2, 1, 3))\n        tensor = tensor.reshape(batch_size * head_size, seq_len,\n                                dim // head_size)\n        return tensor\n\n    def reshape_batch_dim_to_heads(self, tensor):\n        batch_size, seq_len, dim = tensor.shape\n        head_size = self.heads\n        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len,\n                                dim)\n        tensor = jnp.transpose(tensor, (0, 2, 1, 3))\n        tensor = tensor.reshape(batch_size // head_size, seq_len,\n                                dim * head_size)\n        return tensor\n\n    def __call__(self, hidden_states, context=None, deterministic=True):\n        context = hidden_states if context is None else context\n\n        query_proj = self.query(hidden_states)\n        key_proj = self.key(context)\n        value_proj = self.value(context)\n\n        query_states = self.reshape_heads_to_batch_dim(query_proj)\n        key_states = 
self.reshape_heads_to_batch_dim(key_proj)\n        value_states = self.reshape_heads_to_batch_dim(value_proj)\n\n        # compute attentions\n        attention_scores = jnp.einsum(\"b i d, b j d->b i j\", query_states,\n                                      key_states)\n        attention_scores = attention_scores * self.scale\n        attention_probs = nn.softmax(attention_scores, axis=2)\n\n        # attend to values\n        hidden_states = jnp.einsum(\"b i j, b j d -> b i d\", attention_probs,\n                                   value_states)\n        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)\n        hidden_states = self.proj_attn(hidden_states)\n        return hidden_states\n\n\nclass FlaxBasicTransformerBlock(nn.Module):\n    r\"\"\"\n    A Flax transformer block layer with `GLU` (Gated Linear Unit) activation function as described in:\n    https://arxiv.org/abs/1706.03762\n    Parameters:\n        dim (:obj:`int`):\n            Inner hidden states dimension\n        n_heads (:obj:`int`):\n            Number of heads\n        d_head (:obj:`int`):\n            Hidden states dimension inside each head\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    dim: int\n    n_heads: int\n    d_head: int\n    dropout: float = 0.0\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        # self attention\n        self.attn1 = FlaxAttentionBlock(self.dim,\n                                        self.n_heads,\n                                        self.d_head,\n                                        self.dropout,\n                                        dtype=self.dtype)\n        # cross attention\n        self.attn2 = FlaxAttentionBlock(self.dim,\n                                        self.n_heads,\n                                        self.d_head,\n                            
            self.dropout,\n                                        dtype=self.dtype)\n        self.ff = FlaxGluFeedForward(dim=self.dim,\n                                     dropout=self.dropout,\n                                     dtype=self.dtype)\n        self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)\n        self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)\n        self.norm3 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype)\n\n    def __call__(self, hidden_states, context, deterministic=True):\n        # self attention\n        residual = hidden_states\n        hidden_states = self.attn1(self.norm1(hidden_states),\n                                   deterministic=deterministic)\n        hidden_states = hidden_states + residual\n\n        # cross attention\n        residual = hidden_states\n        hidden_states = self.attn2(self.norm2(hidden_states),\n                                   context,\n                                   deterministic=deterministic)\n        hidden_states = hidden_states + residual\n\n        # feed forward\n        residual = hidden_states\n        hidden_states = self.ff(self.norm3(hidden_states),\n                                deterministic=deterministic)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass FlaxSpatialTransformer(nn.Module):\n    r\"\"\"\n    A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:\n    https://arxiv.org/pdf/1506.02025.pdf\n    Parameters:\n        in_channels (:obj:`int`):\n            Input number of channels\n        n_heads (:obj:`int`):\n            Number of heads\n        d_head (:obj:`int`):\n            Hidden states dimension inside each head\n        depth (:obj:`int`, *optional*, defaults to 1):\n            Number of transformers block\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to 
jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    n_heads: int\n    d_head: int\n    depth: int = 1\n    dropout: float = 0.0\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)\n\n        inner_dim = self.n_heads * self.d_head\n        self.proj_in = nn.Conv(\n            inner_dim,\n            kernel_size=(1, 1),\n            strides=(1, 1),\n            padding=\"VALID\",\n            dtype=self.dtype,\n        )\n\n        self.transformer_blocks = [\n            FlaxBasicTransformerBlock(inner_dim,\n                                      self.n_heads,\n                                      self.d_head,\n                                      dropout=self.dropout,\n                                      dtype=self.dtype)\n            for _ in range(self.depth)\n        ]\n\n        self.proj_out = nn.Conv(\n            inner_dim,\n            kernel_size=(1, 1),\n            strides=(1, 1),\n            padding=\"VALID\",\n            dtype=self.dtype,\n        )\n\n    def __call__(self, hidden_states, context, deterministic=True):\n        batch, height, width, channels = hidden_states.shape\n        residual = hidden_states\n        hidden_states = self.norm(hidden_states)\n        hidden_states = self.proj_in(hidden_states)\n\n        hidden_states = hidden_states.reshape(batch, height * width, channels)\n\n        for transformer_block in self.transformer_blocks:\n            hidden_states = transformer_block(hidden_states,\n                                              context,\n                                              deterministic=deterministic)\n\n        hidden_states = hidden_states.reshape(batch, height, width, channels)\n\n        hidden_states = self.proj_out(hidden_states)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass FlaxGluFeedForward(nn.Module):\n    r\"\"\"\n    Flax module that 
encapsulates two Linear layers separated by a gated linear unit activation from:\n    https://arxiv.org/abs/2002.05202\n    Parameters:\n        dim (:obj:`int`):\n            Inner hidden states dimension\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    dim: int\n    dropout: float = 0.0\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        # The second linear layer needs to be called\n        # net_2 for now to match the index of the Sequential layer\n        self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)\n        self.net_2 = nn.Dense(self.dim, dtype=self.dtype)\n\n    def __call__(self, hidden_states, deterministic=True):\n        hidden_states = self.net_0(hidden_states)\n        hidden_states = self.net_2(hidden_states)\n        return hidden_states\n\n\nclass FlaxGEGLU(nn.Module):\n    r\"\"\"\n    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from\n    https://arxiv.org/abs/2002.05202.\n    Parameters:\n        dim (:obj:`int`):\n            Input hidden states dimension\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    dim: int\n    dropout: float = 0.0\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        inner_dim = self.dim * 4\n        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype)\n\n    def __call__(self, hidden_states, deterministic=True):\n        hidden_states = self.proj(hidden_states)\n        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=2)\n        return hidden_linear * nn.gelu(hidden_gelu)\n\n\n##### UNetBlocks - Add pipeline marker at this level\nclass FlaxCrossAttnDownBlock2D(nn.Module):\n    r\"\"\"\n  
  Cross Attention 2D Downsizing block - original architecture from Unet transformers:\n    https://arxiv.org/abs/2103.06104\n    Parameters:\n        in_channels (:obj:`int`):\n            Input channels\n        out_channels (:obj:`int`):\n            Output channels\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        num_layers (:obj:`int`, *optional*, defaults to 1):\n            Number of attention blocks layers\n        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):\n            Number of attention heads of each spatial transformer block\n        add_downsample (:obj:`bool`, *optional*, defaults to `True`):\n            Whether to add downsampling layer before each final output\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    out_channels: int\n    config: UNet2DConfig\n    add_downsample: bool = True\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        resnets = []\n        attentions = []\n\n        for i in range(self.config.layers_per_block):\n            in_channels = self.in_channels if i == 0 else self.out_channels\n\n            res_block = FlaxResnetBlock2D(\n                in_channels=in_channels,\n                config=self.config,\n                out_channels=self.out_channels,\n                dtype=self.dtype,\n            )\n            resnets.append(res_block)\n\n            attn_block = FlaxSpatialTransformer(\n                in_channels=self.out_channels,\n                n_heads=self.config.num_attention_heads,\n                d_head=self.out_channels // self.config.num_attention_heads,\n                depth=1,\n                dtype=self.dtype,\n            )\n            attentions.append(attn_block)\n\n        self.resnets = resnets\n        self.attentions = attentions\n\n        if self.add_downsample:\n            self.downsamplers_0 = 
FlaxDownsample2D(self.out_channels,\n                                                   dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 temb,\n                 encoder_hidden_states,\n                 deterministic=True):\n        output_states = ()\n\n        for idx, (resnet, attn) in enumerate(zip(self.resnets,\n                                                 self.attentions)):\n            hidden_states = resnet(hidden_states,\n                                   temb,\n                                   deterministic=deterministic)\n            hidden_states = attn(hidden_states,\n                                 encoder_hidden_states,\n                                 deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                if idx != self.config.layers_per_block - 1:\n                    mark_pipeline_boundary()\n            output_states += (hidden_states,)\n\n        if self.add_downsample:\n            hidden_states = self.downsamplers_0(hidden_states)\n            output_states += (hidden_states,)\n        if self.config.add_manual_pipeline_markers:\n            mark_pipeline_boundary()\n\n        return hidden_states, output_states\n\n\nclass FlaxDownBlock2D(nn.Module):\n    r\"\"\"\n    Flax 2D downsizing block\n\n    Parameters:\n        in_channels (:obj:`int`):\n            Input channels\n        out_channels (:obj:`int`):\n            Output channels\n        config (:obj:`UNet2DConfig`):\n            UNet Global Config\n        add_downsample (:obj:`bool`, *optional*, defaults to `True`):\n            Whether to add downsampling layer before each final output\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    out_channels: int\n    config: UNet2DConfig\n    add_downsample: bool = True\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        resnets = 
[]\n\n        for i in range(self.config.layers_per_block):\n            in_channels = self.in_channels if i == 0 else self.out_channels\n\n            res_block = FlaxResnetBlock2D(\n                in_channels=in_channels,\n                config=self.config,\n                out_channels=self.out_channels,\n                dtype=self.dtype,\n            )\n            resnets.append(res_block)\n        self.resnets = resnets\n\n        if self.add_downsample:\n            self.downsamplers_0 = FlaxDownsample2D(self.out_channels,\n                                                   dtype=self.dtype)\n\n    def __call__(self, hidden_states, temb, deterministic=True):\n        output_states = ()\n\n        for idx, resnet in enumerate(self.resnets):\n            hidden_states = resnet(hidden_states,\n                                   temb,\n                                   deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                if idx != self.config.layers_per_block - 1:\n                    mark_pipeline_boundary()\n            output_states += (hidden_states,)\n\n        if self.add_downsample:\n            hidden_states = self.downsamplers_0(hidden_states)\n            output_states += (hidden_states,)\n        if self.config.add_manual_pipeline_markers:\n            # delaying the boundary here reduces the communciation memory\n            mark_pipeline_boundary()\n\n        return hidden_states, output_states\n\n\nclass FlaxCrossAttnUpBlock2D(nn.Module):\n    r\"\"\"\n    Cross Attention 2D Upsampling block - original architecture from Unet transformers:\n    https://arxiv.org/abs/2103.06104\n    Parameters:\n        in_channels (:obj:`int`):\n            Input channels\n        out_channels (:obj:`int`):\n            Output channels\n        dropout (:obj:`float`, *optional*, defaults to 0.0):\n            Dropout rate\n        num_layers (:obj:`int`, *optional*, defaults to 1):\n            Number of attention 
blocks layers\n        attn_num_head_channels (:obj:`int`, *optional*, defaults to 1):\n            Number of attention heads of each spatial transformer block\n        add_upsample (:obj:`bool`, *optional*, defaults to `True`):\n            Whether to add upsampling layer before each final output\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    out_channels: int\n    prev_output_channel: int\n    config: UNet2DConfig\n    add_upsample: bool = True\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        resnets = []\n        attentions = []\n\n        for i in range(self.config.layers_per_block):\n            res_skip_channels = self.in_channels if (\n                i == self.config.layers_per_block - 1) else self.out_channels\n            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels\n\n            res_block = FlaxResnetBlock2D(\n                in_channels=resnet_in_channels + res_skip_channels,\n                config=self.config,\n                out_channels=self.out_channels,\n                dtype=self.dtype,\n            )\n            resnets.append(res_block)\n\n            attn_block = FlaxSpatialTransformer(\n                in_channels=self.out_channels,\n                n_heads=self.config.num_attention_heads,\n                d_head=self.out_channels // self.config.num_attention_heads,\n                depth=1,\n                dtype=self.dtype,\n            )\n            attentions.append(attn_block)\n\n        self.resnets = resnets\n        self.attentions = attentions\n\n        if self.add_upsample:\n            self.upsamplers_0 = FlaxUpsample2D(self.out_channels,\n                                               dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 res_hidden_states_tuple,\n                 temb,\n                 encoder_hidden_states,\n            
     deterministic=True):\n        for resnet, attn in zip(self.resnets, self.attentions):\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = jnp.concatenate((hidden_states, res_hidden_states),\n                                            axis=-1)\n\n            hidden_states = resnet(hidden_states,\n                                   temb,\n                                   deterministic=deterministic)\n            hidden_states = attn(hidden_states,\n                                 encoder_hidden_states,\n                                 deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                mark_pipeline_boundary()\n\n        if self.add_upsample:\n            hidden_states = self.upsamplers_0(hidden_states)\n\n        return hidden_states\n\n\nclass FlaxUpBlock2D(nn.Module):\n    r\"\"\"\n    Flax 2D upsampling block\n\n    Parameters:\n        in_channels (:obj:`int`):\n            Input channels\n        out_channels (:obj:`int`):\n            Output channels\n        prev_output_channel (:obj:`int`):\n            Output channels from the previous block\n        config (:obj:`UNet2DConfig`):\n            UNet Global Config\n        add_downsample (:obj:`bool`, *optional*, defaults to `True`):\n            Whether to add downsampling layer before each final output\n        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    out_channels: int\n    prev_output_channel: int\n    config: UNet2DConfig\n    add_upsample: bool = True\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        resnets = []\n\n        for i in range(self.config.layers_per_block + 1):\n            res_skip_channels = self.in_channels if (\n                i == self.config.layers_per_block) else 
self.out_channels\n            resnet_in_channels = (self.prev_output_channel\n                                  if i == 0 else self.out_channels)\n\n            res_block = FlaxResnetBlock2D(\n                in_channels=resnet_in_channels + res_skip_channels,\n                config=self.config,\n                out_channels=self.out_channels,\n                dtype=self.dtype,\n            )\n            resnets.append(res_block)\n\n        self.resnets = resnets\n\n        if self.add_upsample:\n            self.upsamplers_0 = FlaxUpsample2D(self.out_channels,\n                                               dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 res_hidden_states_tuple,\n                 temb,\n                 deterministic=True):\n        for resnet in self.resnets:\n            # pop res hidden states\n            res_hidden_states = res_hidden_states_tuple[-1]\n            res_hidden_states_tuple = res_hidden_states_tuple[:-1]\n            hidden_states = jnp.concatenate((hidden_states, res_hidden_states),\n                                            axis=-1)\n\n            hidden_states = resnet(hidden_states,\n                                   temb,\n                                   deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                mark_pipeline_boundary()\n\n        if self.add_upsample:\n            hidden_states = self.upsamplers_0(hidden_states)\n        return hidden_states\n\n\nclass FlaxUNetMidBlock2DCrossAttn(nn.Module):\n    r\"\"\"\n    Cross Attention 2D Mid-level block - original architecture from Unet transformers: https://arxiv.org/abs/2103.06104\n    Parameters:\n        in_channels (:obj:`int`):\n            Input channels\n        config (:obj:`UNet2DConfig`):\n            UNet Global Config\n        num_layers (:obj:`int`, *optional*, defaults to 1):\n            Number of attention blocks layers\n        dtype 
(:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):\n            Parameters `dtype`\n    \"\"\"\n    in_channels: int\n    config: UNet2DConfig\n    num_layers: int = 1\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        # there is always at least one resnet\n        resnets = [\n            FlaxResnetBlock2D(\n                in_channels=self.in_channels,\n                config=self.config,\n                out_channels=self.in_channels,\n                dtype=self.dtype,\n            )\n        ]\n\n        attentions = []\n\n        for _ in range(self.num_layers):\n            attn_block = FlaxSpatialTransformer(\n                in_channels=self.in_channels,\n                n_heads=self.config.num_attention_heads,\n                d_head=self.in_channels // self.config.num_attention_heads,\n                depth=1,\n                dtype=self.dtype,\n            )\n            attentions.append(attn_block)\n\n            res_block = FlaxResnetBlock2D(\n                in_channels=self.in_channels,\n                config=self.config,\n                out_channels=self.in_channels,\n                dtype=self.dtype,\n            )\n            resnets.append(res_block)\n\n        self.resnets = resnets\n        self.attentions = attentions\n\n    def __call__(self,\n                 hidden_states,\n                 temb,\n                 encoder_hidden_states,\n                 deterministic=True):\n        hidden_states = self.resnets[0](hidden_states, temb)\n        for attn, resnet in zip(self.attentions, self.resnets[1:]):\n            hidden_states = attn(hidden_states,\n                                 encoder_hidden_states,\n                                 deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                mark_pipeline_boundary()\n            hidden_states = resnet(hidden_states,\n                                   temb,\n                                   
deterministic=deterministic)\n            if self.config.add_manual_pipeline_markers:\n                mark_pipeline_boundary()\n\n        return hidden_states\n\n\n##### UNet2D\nclass FlaxUNet2DConditionModel(nn.Module):\n    r\"\"\"\n    FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a\n    timestep and returns sample shaped output.\n    This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library\n    implements for all the models (such as downloading or saving, etc.)\n    Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)\n    subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to\n    general usage and behavior.\n    Finally, this model supports inherent JAX features such as:\n    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)\n    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)\n    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)\n    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)\n    Parameters:\n        config (:obj:`UNet2DConfig`):\n            UNet Global Config\n        down_block_types (`Tuple[str]`, *optional*, defaults to `(\"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"CrossAttnDownBlock2D\", \"DownBlock2D\")`):\n            The tuple of downsample blocks to use. 
The corresponding class names will be: \"FlaxCrossAttnDownBlock2D\",\n            \"FlaxCrossAttnDownBlock2D\", \"FlaxCrossAttnDownBlock2D\", \"FlaxDownBlock2D\"\n        up_block_types (`Tuple[str]`, *optional*, defaults to `(\"UpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\",)`):\n            The tuple of upsample blocks to use. The corresponding class names will be: \"FlaxUpBlock2D\",\n            \"FlaxCrossAttnUpBlock2D\", \"FlaxCrossAttnUpBlock2D\", \"FlaxCrossAttnUpBlock2D\"\n        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):\n            The tuple of output channels for each block.\n        cross_attention_dim (`int`, *optional*, defaults to 768):\n            The dimension of the cross attention features.\n    \"\"\"\n\n    config: UNet2DConfig\n    down_block_types: Tuple[str] = (\n        \"CrossAttnDownBlock2D\",\n        \"CrossAttnDownBlock2D\",\n        \"CrossAttnDownBlock2D\",\n        \"DownBlock2D\",\n    )\n    up_block_types: Tuple[str] = (\"UpBlock2D\", \"CrossAttnUpBlock2D\",\n                                  \"CrossAttnUpBlock2D\", \"CrossAttnUpBlock2D\")\n    block_out_channels: Tuple[int] = (320, 640, 1280, 1280)\n    cross_attention_dim: int = 768\n    dtype: jnp.dtype = jnp.float32\n\n    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:\n        # init input tensors\n        sample_shape = (1, self.config.in_channels, self.config.sample_size,\n                        self.config.sample_size)\n        sample = jnp.zeros(sample_shape, dtype=jnp.float32)\n        timesteps = jnp.ones((1,), dtype=jnp.int32)\n        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim),\n                                          dtype=jnp.float32)\n\n        params_rng, dropout_rng = jax.random.split(rng)\n        rngs = {\"params\": params_rng, \"dropout\": dropout_rng}\n\n        return self.init(rngs, sample, timesteps,\n                         
encoder_hidden_states)[\"params\"]\n\n    def setup(self):\n        block_out_channels = self.block_out_channels\n        time_embed_dim = block_out_channels[0] * 4\n\n        # input\n        self.conv_in = nn.Conv(\n            block_out_channels[0],\n            kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=((1, 1), (1, 1)),\n            dtype=self.dtype,\n        )\n\n        # time\n        self.time_proj = FlaxTimesteps(block_out_channels[0],\n                                       freq_shift=self.config.freq_shift)\n        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim,\n                                                    dtype=self.dtype)\n\n        # down\n        down_blocks = []\n        output_channel = block_out_channels[0]\n        for i, down_block_type in enumerate(self.down_block_types):\n            input_channel = output_channel\n            output_channel = block_out_channels[i]\n            is_final_block = i == len(block_out_channels) - 1\n\n            if down_block_type == \"CrossAttnDownBlock2D\":\n                down_block_cls = FlaxCrossAttnDownBlock2D\n            else:\n                down_block_cls = FlaxDownBlock2D\n            down_block = down_block_cls(\n                in_channels=input_channel,\n                out_channels=output_channel,\n                config=self.config,\n                add_downsample=not is_final_block,\n                dtype=self.dtype,\n            )\n\n            down_blocks.append(down_block)\n        self.down_blocks = down_blocks\n\n        # mid\n        self.mid_block = FlaxUNetMidBlock2DCrossAttn(\n            in_channels=block_out_channels[-1],\n            config=self.config,\n            dtype=self.dtype,\n        )\n\n        # up\n        up_blocks = []\n        reversed_block_out_channels = list(reversed(block_out_channels))\n        output_channel = reversed_block_out_channels[0]\n        for i, up_block_type in enumerate(self.up_block_types):\n       
     prev_output_channel = output_channel\n            output_channel = reversed_block_out_channels[i]\n            input_channel = reversed_block_out_channels[min(\n                i + 1,\n                len(block_out_channels) - 1)]\n\n            is_final_block = i == len(block_out_channels) - 1\n\n            if up_block_type == \"CrossAttnUpBlock2D\":\n                up_block_cls = FlaxCrossAttnUpBlock2D\n            else:\n                up_block_cls = FlaxUpBlock2D\n            up_block = up_block_cls(\n                in_channels=input_channel,\n                out_channels=output_channel,\n                prev_output_channel=prev_output_channel,\n                config=self.config,\n                add_upsample=not is_final_block,\n                dtype=self.dtype,\n            )\n\n            up_blocks.append(up_block)\n            prev_output_channel = output_channel\n        self.up_blocks = up_blocks\n\n        # out\n        self.conv_norm_out = nn.GroupNorm(num_groups=self.config.num_groups,\n                                          epsilon=1e-5)\n        self.conv_out = nn.Conv(\n            self.config.out_channels,\n            kernel_size=(3, 3),\n            strides=(1, 1),\n            padding=((1, 1), (1, 1)),\n            dtype=self.dtype,\n        )\n\n    def __call__(\n        self,\n        sample,\n        timesteps,\n        encoder_hidden_states,\n        return_dict: bool = True,\n        train: bool = False,\n    ) -> Union[FlaxUNet2DConditionOutput, Tuple]:\n        \"\"\"r\n        Args:\n            sample (`jnp.ndarray`): (channel, height, width) noisy inputs tensor\n            timestep (`jnp.ndarray` or `float` or `int`): timesteps\n            encoder_hidden_states (`jnp.ndarray`): (channel, height, width) encoder hidden states\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a\n             
   plain tuple.\n            train (`bool`, *optional*, defaults to `False`):\n                Use deterministic functions and disable dropout when not training.\n        Returns:\n            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:\n            [`~models.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is the sample tensor.\n        \"\"\"\n        # 1. time\n        if not isinstance(timesteps, jnp.ndarray):\n            timesteps = jnp.array([timesteps], dtype=jnp.int32)\n        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:\n            timesteps = timesteps.astype(dtype=jnp.float32)\n            timesteps = jnp.expand_dims(timesteps, 0)\n\n        t_emb = self.time_proj(timesteps)\n        t_emb = self.time_embedding(t_emb)\n\n        # 2. pre-process\n        # (B, img_channel, sample_size, sample_size) -> (B, SS, SS, img_channel)\n        sample = jnp.transpose(sample, (0, 2, 3, 1))\n        # (B, SS, SS, block_out_channels[0])\n        sample = self.conv_in(sample)\n        if self.config.add_manual_pipeline_markers:\n            mark_pipeline_boundary()\n\n        # 3. down\n        down_block_res_samples = (sample,)\n        for down_block in self.down_blocks:\n            if isinstance(down_block, FlaxCrossAttnDownBlock2D):\n                sample, res_samples = down_block(sample,\n                                                 t_emb,\n                                                 encoder_hidden_states,\n                                                 deterministic=not train)\n            else:\n                sample, res_samples = down_block(sample,\n                                                 t_emb,\n                                                 deterministic=not train)\n\n            down_block_res_samples += res_samples\n\n        # 4. 
mid\n        sample = self.mid_block(sample,\n                                t_emb,\n                                encoder_hidden_states,\n                                deterministic=not train)\n\n        # 5. up\n        for up_block in self.up_blocks:\n            res_samples = down_block_res_samples[-(\n                self.config.layers_per_block + 1):]\n            down_block_res_samples = down_block_res_samples[:-(\n                self.config.layers_per_block + 1)]\n            if isinstance(up_block, FlaxCrossAttnUpBlock2D):\n                sample = up_block(\n                    sample,\n                    temb=t_emb,\n                    encoder_hidden_states=encoder_hidden_states,\n                    res_hidden_states_tuple=res_samples,\n                    deterministic=not train,\n                )\n            else:\n                sample = up_block(sample,\n                                  temb=t_emb,\n                                  res_hidden_states_tuple=res_samples,\n                                  deterministic=not train)\n\n        # 6. 
post-process\n        sample = self.conv_norm_out(sample)\n        sample = nn.silu(sample)\n        sample = self.conv_out(sample)\n        sample = jnp.transpose(sample, (0, 3, 1, 2))\n\n        if not return_dict:\n            return (sample,)\n\n        return FlaxUNet2DConditionOutput(sample=sample)\n\n\ndef get_unet_2d(sample_size,\n                down_block_types,\n                up_block_types,\n                block_out_channels,\n                in_channels=4,\n                out_channels=4,\n                dropout=0.0,\n                layers_per_block=2,\n                num_attention_heads=8,\n                freq_shift=0,\n                num_groups=4,\n                dtype=jnp.float32,\n                add_manual_pipeline_markers=True):\n    # Begin with Configs of Attention layers in the UNet_2D\n    hidden_act = \"gelu\"\n    hidden_size = block_out_channels[-1]\n    # Check block out channels: only the last does not do upsampling\n    assert block_out_channels[-1] == block_out_channels[-2]\n    cross_attention_dim = block_out_channels[-1]\n    config = UNet2DConfig(\n        hidden_size=hidden_size,\n        num_attention_heads=num_attention_heads,\n        intermediate_size=hidden_size * 4,\n        hidden_dropout_prob=dropout,\n        attention_probs_dropout_prob=dropout,\n        hidden_act=hidden_act,\n        add_manual_pipeline_markers=add_manual_pipeline_markers,\n        # UNet New configs\n        sample_size=sample_size,\n        in_channels=in_channels,\n        out_channels=out_channels,\n        layers_per_block=layers_per_block,\n        freq_shift=freq_shift,\n        num_groups=num_groups)\n    return FlaxUNet2DConditionModel(config,\n                                    down_block_types,\n                                    up_block_types,\n                                    block_out_channels,\n                                    cross_attention_dim=cross_attention_dim,\n                                    
dtype=dtype)\n\n\nif __name__ == \"__main__\":\n    down_block_types: Tuple[str] = (\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n        \"DownBlock2D\",\n    )\n    up_block_types: Tuple[str] = (\"UpBlock2D\", \"UpBlock2D\", \"UpBlock2D\",\n                                  \"UpBlock2D\")\n    block_out_channels: Tuple[int] = (32, 64, 128, 128)\n    channel = 3\n    sample_size = 24\n    model = get_unet_2d(sample_size,\n                        down_block_types,\n                        up_block_types,\n                        block_out_channels,\n                        cross_attention_dim=128)\n    rng = jax.random.PRNGKey(0)\n    batch = 5\n    sample = jnp.ones((batch, channel, sample_size, sample_size))\n    encoder_hidden_states = jnp.ones(\n        (batch, (sample_size // 2**(len(block_out_channels) - 1))**2,\n         block_out_channels[-1]))\n    timestep = 1\n    params = model.init(rng, sample, timestep, encoder_hidden_states)\n"
  },
  {
    "path": "alpa/model/wide_resnet.py",
    "content": "\"\"\"The definition of wide-resnet.\n\nModified from https://github.com/google/flax/blob/main/examples/imagenet/models.py.\nsee also: https://arxiv.org/pdf/1605.07146.pdf\n\"\"\"\n# Copyright 2021 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nfrom functools import partial\nfrom typing import Any, Callable, Sequence, Tuple\n\nfrom flax import linen as nn\nfrom flax.training import train_state, dynamic_scale as dynamic_scale_lib\nimport jax.numpy as jnp\n\nModuleDef = Any\n\n\nclass TrainState(train_state.TrainState):\n    batch_stats: Any\n    dynamic_scale: dynamic_scale_lib.DynamicScale\n\n\nclass ResNetBlock(nn.Module):\n    \"\"\"ResNet block.\"\"\"\n    filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable\n    width_factor: int\n    strides: Tuple[int, int] = (1, 1)\n\n    @nn.compact\n    def __call__(\n        self,\n        x,\n    ):\n        assert self.width_factor == 1\n\n        residual = x\n        y = self.conv(self.filters, (3, 3), self.strides)(x)\n        y = self.norm()(y)\n        y = self.act(y)\n        y = self.conv(self.filters, (3, 3))(y)\n        y = self.norm(scale_init=nn.initializers.zeros)(y)\n\n        if residual.shape != y.shape:\n            residual = self.conv(self.filters, (1, 1),\n                                 self.strides,\n                                 name='conv_proj')(residual)\n            residual = self.norm(name='norm_proj')(residual)\n\n        
return self.act(residual + y)\n\n\nclass BottleneckResNetBlock(nn.Module):\n    \"\"\"Bottleneck ResNet block.\"\"\"\n    filters: int\n    conv: ModuleDef\n    norm: ModuleDef\n    act: Callable\n    width_factor: int\n    strides: Tuple[int, int] = (1, 1)\n\n    @nn.compact\n    def __call__(self, x):\n        residual = x\n        y = self.conv(self.filters, (1, 1))(x)\n        y = self.norm()(y)\n        y = self.act(y)\n        y = self.conv(self.filters * self.width_factor, (3, 3), self.strides)(y)\n        y = self.norm()(y)\n        y = self.act(y)\n        y = self.conv(self.filters * 4, (1, 1))(y)\n        y = self.norm(scale_init=nn.initializers.zeros)(y)\n\n        if residual.shape != y.shape:\n            residual = self.conv(self.filters * 4, (1, 1),\n                                 self.strides,\n                                 name='conv_proj')(residual)\n            residual = self.norm(name='norm_proj')(residual)\n\n        return self.act(residual + y)\n\n\nclass ResNet(nn.Module):\n    \"\"\"ResNetV1.\"\"\"\n    stage_sizes: Sequence[int]\n    block_cls: ModuleDef\n    num_classes: int\n    num_filters: int\n    width_factor: int\n    dtype: Any = jnp.float32\n    act: Callable = nn.relu\n\n    @nn.compact\n    def __call__(self, x, train: bool = True):\n        conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)\n        norm = partial(nn.BatchNorm,\n                       use_running_average=not train,\n                       momentum=0.9,\n                       epsilon=1e-5,\n                       dtype=self.dtype)\n\n        x = conv(self.num_filters, (7, 7), (2, 2),\n                 padding=[(3, 3), (3, 3)],\n                 name='conv_init')(x)\n        x = norm(name='bn_init')(x)\n        x = nn.relu(x)\n        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')\n        for i, block_size in enumerate(self.stage_sizes):\n            for j in range(block_size):\n                strides = (2, 2) if i > 0 and j == 0 else 
(1, 1)\n                x = self.block_cls(self.num_filters * 2**i,\n                                   strides=strides,\n                                   conv=conv,\n                                   norm=norm,\n                                   width_factor=self.width_factor,\n                                   act=self.act)(x)\n        x = jnp.mean(x, axis=(1, 2))\n        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)\n        x = jnp.asarray(x, self.dtype)\n        return x\n\n\nmodel_configs = {\n    0: {\n        \"stage_sizes\": [],\n        \"block_cls\": ResNetBlock\n    },\n    18: {\n        \"stage_sizes\": [2, 2, 2, 2],\n        \"block_cls\": ResNetBlock\n    },\n    34: {\n        \"stage_sizes\": [3, 4, 6, 3],\n        \"block_cls\": ResNetBlock\n    },\n    50: {\n        \"stage_sizes\": [3, 4, 6, 3],\n        \"block_cls\": BottleneckResNetBlock\n    },\n    101: {\n        \"stage_sizes\": [3, 4, 23, 3],\n        \"block_cls\": BottleneckResNetBlock\n    },\n    152: {\n        \"stage_sizes\": [3, 8, 36, 3],\n        \"block_cls\": BottleneckResNetBlock\n    },\n    200: {\n        \"stage_sizes\": [3, 24, 36, 3],\n        \"block_cls\": BottleneckResNetBlock\n    }\n}\n\n\ndef get_wide_resnet(num_layers, width_factor, num_filters, num_classes, dtype):\n    model_config = model_configs[num_layers]\n    model_config[\"width_factor\"] = width_factor\n    model_config[\"num_filters\"] = num_filters\n    model_config[\"num_classes\"] = num_classes\n    model_config[\"dtype\"] = dtype\n\n    return ResNet(**model_config)\n"
  },
  {
    "path": "alpa/monkey_patch.py",
    "content": "\"\"\"Monkey patch other python libraries.\"\"\"\n# pylint: disable=protected-access, unused-argument\nfrom functools import partial\n\nimport numpy as np\nimport jax\nfrom jax import core, lax, numpy as jnp\nfrom jax._src import dtypes, random as jax_src_random\nfrom jax._src.lib import xla_client as xc\nfrom jax._src.lib import xla_bridge as jax_src_lib_xla_bridge\nfrom jax._src.lib.mlir.dialects import mhlo\nfrom jax._src.lib.xla_bridge import get_backend as default_get_backend\nfrom jax.core import Primitive\nfrom jax.interpreters import pxla\nfrom jax.interpreters import xla, mlir\nfrom jax.interpreters.xla import xops\nimport flax\n\nfrom alpa.global_env import global_config, is_worker\n\n########################################\n##### Monkey patch the Jax backend\n########################################\n\noverride_backend = None\n\n\ndef set_override_backend(backend):\n    \"\"\"Enable the JAX backend monkey patch.\"\"\"\n    global override_backend\n    override_backend = backend\n\n\ndef override_get_backend(*args, **kwargs):\n    \"\"\"Override the `get_backend` in JAX to use PJRT backend managed by Alpa.\"\"\"\n    if override_backend is not None:\n        return override_backend\n    return default_get_backend(*args, **kwargs)\n\n\nif is_worker:\n    jax_src_lib_xla_bridge.get_backend = override_get_backend\n    jax.lib.xla_bridge.get_backend = override_get_backend\n\n########################################\n##### Monkey patch Jax\n########################################\n\n\n# Monkey patch random generator to use the stateful random generator.\n# This can simplify the computational graph for dropout.\ndef fast_uniform(key, shape=(), dtype=dtypes.float_, minval=0.0, maxval=1.0):\n    dtype = dtypes.canonicalize_dtype(dtype)\n    shape = core.as_named_shape(shape)\n    minval = jnp.asarray(minval, dtype)\n    maxval = jnp.asarray(maxval, dtype)\n    return lax.rng_uniform(minval, maxval, shape.positional)\n\n\ndef rng_normal(mu, 
sigma, shape):\n    \"\"\"Stateful PRNG generator. Experimental and its use is discouraged.\n\n    Returns random numbers following normal distribution with (mu, sigma)\n\n    You should use jax.random for most purposes; this function exists only for\n    niche use cases with special performance requirements.\n\n    This API may be removed at any time.\n    \"\"\"\n    return rng_normal_p.bind(mu, sigma, shape=tuple(shape))\n\n\ndef _rng_normal_abstract_eval(mu, sigma, *, shape):\n    if mu.dtype != sigma.dtype:\n        raise ValueError(\n            f\"Arguments to rng_normal must have identical dtypes, got \"\n            f\"{mu.dtype} and {sigma.dtype}.\")\n    if mu.shape != () or sigma.shape != ():\n        raise ValueError(f\"Arguments to rng_normal must be scalars; got shapes \"\n                         f\"{mu.shape} and {sigma.shape}.\")\n    return mu.update(shape=shape,\n                     dtype=mu.dtype,\n                     weak_type=(mu.weak_type and sigma.weak_type))\n\n\ndef _rng_normal_translation_rule(ctx, avals_in, avals_out, mu, sigma, *, shape):\n    c = ctx.builder\n    xla_shape = xc.Shape.array_shape(c.get_shape(mu).xla_element_type(), shape)\n    return [xops.RngNormal(mu, sigma, xla_shape)]\n\n\nrng_normal_p = Primitive(\"rng_normal\")\nrng_normal_p.def_impl(partial(xla.apply_primitive, rng_normal_p))\nrng_normal_p.def_abstract_eval(_rng_normal_abstract_eval)\nxla.register_translation(rng_normal_p, _rng_normal_translation_rule)\n\n\ndef _rng_normal_lowering(ctx, mu, sigma, *, shape):\n    aval_out, = ctx.avals_out\n    shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),\n                               canonicalize_types=False)\n    return mhlo.RngOp(mu, sigma, shape,\n                      mhlo.RngDistributionAttr.get(\"NORMAL\")).results\n\n\nmlir.register_lowering(rng_normal_p, _rng_normal_lowering)\n\n\ndef fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0):\n    dtype = dtypes.canonicalize_dtype(dtype)\n  
  shape = core.as_named_shape(shape)\n    mu = jnp.asarray(mu, dtype)\n    sigma = jnp.asarray(sigma, dtype)\n    return rng_normal(mu, sigma, shape.positional)\n\n\ndef fast_truncated_normal(key, lower, upper, shape=None, dtype=dtypes.float_):\n    dtype = dtypes.canonicalize_dtype(dtype)\n    if shape is not None:\n        shape = core.as_named_shape(shape)\n    out = fast_normal(key, shape=shape, dtype=dtype)\n    lower = lax.convert_element_type(lower, dtype)\n    upper = lax.convert_element_type(upper, dtype)\n    return jnp.clip(\n        out,\n        lax.nextafter(lax.stop_gradient(lower), np.array(np.inf, dtype=dtype)),\n        lax.nextafter(lax.stop_gradient(upper), np.array(-np.inf, dtype=dtype)))\n\n\ndef fast_bernoulli(key, p=np.float32(0.5), shape=None):\n    dtype = dtypes.canonicalize_dtype(lax.dtype(p))\n    return jax.random.uniform(key, shape, dtype) < p\n\n\ndef remove_fold_in(key, data):\n    return key\n\n\nrng_primitives = [lax.rng_uniform_p, rng_normal_p]\n\n# Monkey patch random generator to use the stateful random generator.\nbackup_random_uniform = jax.random.uniform\nbackup_random_truncated_normal = jax.random.truncated_normal\nbackup_random_normal = jax.random.normal\nbackup_random_bernoulli = jax.random.bernoulli\nbackup_random_foldin = jax.random.fold_in\n\n\ndef monkey_patch_random():\n    jax.random.uniform = fast_uniform\n    jax.random.truncated_normal = fast_truncated_normal\n    jax.random.normal = fast_normal\n    jax.random.bernoulli = fast_bernoulli\n    jax.random.fold_in = remove_fold_in\n\n    jax_src_random.uniform = fast_uniform\n    jax_src_random.truncated_normal = fast_truncated_normal\n    jax_src_random.normal = fast_normal\n    jax_src_random.bernoulli = fast_bernoulli\n    jax_src_random.fold_in = remove_fold_in\n\n\ndef restore_random():\n    jax.random.uniform = backup_random_uniform\n    jax.random.truncated_normal = backup_random_truncated_normal\n    jax.random.normal = backup_random_normal\n    
jax.random.bernoulli = backup_random_bernoulli\n    jax.random.fold_in = backup_random_foldin\n\n    jax_src_random.uniform = backup_random_uniform\n    jax_src_random.truncated_normal = backup_random_truncated_normal\n    jax_src_random.normal = backup_random_normal\n    jax_src_random.bernoulli = backup_random_bernoulli\n    jax_src_random.fold_in = backup_random_foldin\n\n\n# Support using pickle on ShardingSpec\ndef sharding_spec_getstate(self):\n    sharding = []\n    for x in self.sharding:\n        if isinstance(x, pxla.NoSharding):\n            sharding.append((0,))\n        elif isinstance(x, pxla.Chunked):\n            sharding.append((1, x.chunks))\n        elif isinstance(x, pxla.Unstacked):\n            sharding.append((2, x.size))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n    mesh_mapping = []\n    for x in self.mesh_mapping:\n        if isinstance(x, pxla.ShardedAxis):\n            mesh_mapping.append((0, x.axis))\n        elif isinstance(x, pxla.Replicated):\n            mesh_mapping.append((1, x.replicas))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n    return (sharding, mesh_mapping)\n\n\ndef sharding_spec_setstate(self, state_tuple):\n    sharding_encoding, mesh_mapping_encoding = state_tuple\n\n    sharding = []\n    for x in sharding_encoding:\n        if x[0] == 0:\n            sharding.append(pxla.NoSharding())\n        elif x[0] == 1:\n            sharding.append(pxla.Chunked(x[1]))\n        elif x[0] == 2:\n            sharding.append(pxla.Unstacked(x[1]))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n\n    mesh_mapping = []\n    for x in mesh_mapping_encoding:\n        if x[0] == 0:\n            mesh_mapping.append(pxla.ShardedAxis(x[1]))\n        elif x[0] == 1:\n            mesh_mapping.append(pxla.Replicated(x[1]))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n\n    # pylint: disable=unnecessary-dunder-call\n    
self.__init__(\n        sharding=sharding,\n        mesh_mapping=mesh_mapping,\n    )\n\n\npxla.ShardingSpec.__getstate__ = sharding_spec_getstate\npxla.ShardingSpec.__setstate__ = sharding_spec_setstate\n\n########################################\n##### Monkey patch Flax\n########################################\n\n\n# Monkey patch the nn.Embed in flax to use onehot + matmul instead of\n# gather/scatter,\n# because we currently do not support 2d partition of gather/scatter.\ndef embed_call_one_hot(self, inputs):\n    dtype = self.dtype\n    if global_config.flax_always_use_fp16_embedding:\n        dtype = jnp.float16\n    expanded = jax.nn.one_hot(inputs, self.embedding.shape[0], dtype=dtype)\n    ret = expanded @ jnp.asarray(self.embedding, dtype)\n    return ret\n\n\n# Monkey patch the nn.Embed in flax to add a fp16 conversion.\n# This is used for manual pipeline marker.\ndef embed_setup(self):\n    self.embedding = self.param(\"embedding\", self.embedding_init,\n                                (self.num_embeddings, self.features),\n                                self.param_dtype)\n    if self.dtype == jnp.float16:\n        self.embedding_fp16 = self.embedding.astype(jnp.float16)\n\n\nflax.linen.Embed.setup = embed_setup\nflax.linen.Embed.__call__ = embed_call_one_hot\n\n\n# Monkey patch a new method \"init_dummy\" to flax's Module.\n# This function initializes all weights with ones for testing/benchmark\n# purposes.\n# This function is much faster than the standard initialization.\ndef init_dummy(self, *args, **kwargs):\n    avals = jax.eval_shape(self.init, *args, **kwargs)\n    return jax.tree_util.tree_map(lambda x: jnp.full(x.shape, 1e-8, x.dtype),\n                                  avals)\n\n\nflax.linen.module.Module.init_dummy = init_dummy\n"
  },
  {
    "path": "alpa/parallel_method.py",
    "content": "\"\"\"Methods for parallelzing a function.\n\nAlpa classifies common parallel techniques into two categories:\n1. Shard parallelism or intra-operator parallelism. This includes data\n   parallelism, operator parallelism (or tensor model parallelism), expert\n   parallelism, zero optimizer and their combinations.\n2. Pipeline parallelism or inter-operator parallleism.\nPlease refer to the Alpa paper (https://arxiv.org/abs/2201.12023) for more\ndetails.\n\nBased on this, alpa provides two base parallel methods:\n- ShardParallel: which only uses shard parallelsim.\n- PipeshardParallel: which combines pipeline parallelism and shard parallelism.\n\"\"\"\nfrom abc import ABC, abstractmethod\nfrom typing import Callable, Optional, Sequence, Union, Any\n\nfrom jax import linear_util as lu\nfrom jax._src import traceback_util\nfrom jax.core import AbstractValue\nfrom jax.interpreters import pxla\nfrom jax.tree_util import PyTreeDef\nimport numpy as np\n\nfrom alpa.create_state_parallel import compile_create_state_executable\nfrom alpa.follow_parallel import compile_follow_parallel_executable\nfrom alpa.device_mesh import (PhysicalDeviceMesh, VirtualPhysicalMesh,\n                              LocalPhysicalDeviceMesh, get_global_physical_mesh,\n                              get_global_virtual_physical_mesh)\nfrom alpa.pipeline_parallel.compile_executable import compile_pipeshard_executable\nfrom alpa.pipeline_parallel.local_pipeline import compile_local_pipeline_executable\nfrom alpa.pipeline_parallel.layer_construction import (LayerOption,\n                                                       AutoLayerOption,\n                                                       ManualLayerOption)\nfrom alpa.pipeline_parallel.stage_construction import (StageOption,\n                                                       AutoStageOption,\n                                                       ManualStageOption,\n                                                       
UniformStageOption)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption, LogicalDeviceMesh\nfrom alpa.shard_parallel.compile_executable import compile_shard_executable\nfrom alpa.shard_parallel.manual_sharding import ManualShardingOption\n\ntraceback_util.register_exclusion(__file__)\n\n\nclass ParallelMethod(ABC):\n    \"\"\"Methods for parallelizing a function.\"\"\"\n\n    @abstractmethod\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        \"\"\"Compile an executable.\"\"\"\n        raise NotImplementedError()\n\n\nclass ShardParallel(ParallelMethod):\n    \"\"\"Use shard parallelism to parallelize a function.\n\n    Args:\n        devices: Specify the devices to use. If it is None, use all devices\n          in the cluster.\n        num_micro_batches: The number of micro batches for gradient\n          accumulation.\n        auto_sharding_option: The options of the auto-sharding solver.\n    \"\"\"\n\n    def __init__(self,\n                 devices: Optional[Union[LogicalDeviceMesh,\n                                         PhysicalDeviceMesh]] = None,\n                 num_micro_batches: Optional[int] = None,\n                 auto_sharding_option: Optional[AutoShardingOption] = None,\n                 manual_sharding_option: Optional[ManualShardingOption] = None):\n        self.devices = devices\n        self.num_micro_batches = num_micro_batches\n        self.as_option = auto_sharding_option or AutoShardingOption()\n        self.ms_option = manual_sharding_option\n\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        
donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        # Resolve the polymorphism in arguments\n        if self.devices is None:\n            mesh = get_global_physical_mesh(create_if_not_exist=True)\n            # Use 1d mesh by default\n            mesh = mesh.get_logical_mesh().flatten()\n        elif isinstance(self.devices, (list, tuple)):\n            mesh = LocalPhysicalDeviceMesh(self.devices)\n        else:\n            mesh = self.devices\n\n        assert isinstance(mesh, (PhysicalDeviceMesh, LogicalDeviceMesh))\n\n        return compile_shard_executable(fun, in_tree, out_tree_thunk,\n                                        static_argnums, donated_invars,\n                                        batch_invars, mesh,\n                                        self.num_micro_batches, self.as_option,\n                                        self.ms_option, *avals)\n\n\nclass DataParallel(ShardParallel):\n    \"\"\"\n    Use vanilla data parallelism.\n    This method syncs gradients by using all-reduce.\n    \"\"\"\n\n    def __init__(self,\n                 devices: Optional[Union[LogicalDeviceMesh,\n                                         PhysicalDeviceMesh]] = None,\n                 num_micro_batches: Optional[int] = None):\n        as_option = AutoShardingOption(force_data_parallel=True,\n                                       prefer_reduce_scatter=False)\n        super().__init__(devices, num_micro_batches, as_option)\n\n\nclass Zero2Parallel(ShardParallel):\n    \"\"\"\n    Use zero-2 based data parallelism. This method\n    1. replaces all-reduce by reduce-scatter and all-gather.\n    2. 
partitions more tensors such as optimizer states.\n    \"\"\"\n\n    def __init__(self,\n                 devices: Optional[Union[LogicalDeviceMesh,\n                                         PhysicalDeviceMesh]] = None,\n                 num_micro_batches: Optional[int] = None):\n        as_option = AutoShardingOption(force_data_parallel=True,\n                                       prefer_reduce_scatter=True)\n        super().__init__(devices, num_micro_batches, as_option)\n\n\nclass Zero3Parallel(ShardParallel):\n    \"\"\"\n    Use zero-3 based data parallelism.\n    Note that this method is experimental and not fully tested.\n    \"\"\"\n\n    def __init__(self,\n                 devices: Optional[Union[LogicalDeviceMesh,\n                                         PhysicalDeviceMesh]] = None,\n                 num_micro_batches: Optional[int] = None):\n        as_option = AutoShardingOption(force_zero_stage_3=True)\n        super().__init__(devices, num_micro_batches, as_option)\n\n\nclass PipeshardParallel(ParallelMethod):\n    \"\"\"\n    Use pipeshard parallelism which combines pipeline parallelism and\n    shard parallelism.\n\n    Args:\n        devices: Specify the devices to use. 
If it is None, use all the devices\n          in the cluster.\n        num_micro_batches: The number of micro batches for gradient\n          accumulation.\n        default_auto_sharding_option: The default options of the auto-sharding\n          solver.\n        pipeline_schedule: The pipeline schedule.\n          Possible choices: {\"1f1b\", \"gpipe\", \"inference\"}\n        layer_option: Options of grouping basic operators to layers.\n          Possible choices are {\"manual\", alpa.AutoLayerOption,\n                                 alpa.ManualLayerOption}\n        stage_option: Options of grouping layers into pipeline stages.\n          Possible choices are {\"uniform\", \"auto\", alpa.AutoStageOption,\n                                 alpa.ManualStageOption}\n        stage_input_shardings: Options of input sharding specs for each stage.\n          Shape: [num_pipeline_stages, num_input_vars_in_hlo_module].\n    \"\"\"\n\n    def __init__(\n            self,\n            devices: Optional[VirtualPhysicalMesh] = None,\n            num_micro_batches: int = 1,\n            default_auto_sharding_option: Optional[AutoShardingOption] = None,\n            pipeline_schedule: str = \"1f1b\",\n            layer_option: Optional[Union[LayerOption, str]] = None,\n            stage_option: Optional[Union[StageOption, str]] = None,\n            stage_input_shardings: Optional[Sequence[Sequence[\n                pxla.ShardingSpec]]] = None,\n            manual_sharding_option: ManualShardingOption = None):\n        self.devices = devices\n        self.num_micro_batches = num_micro_batches\n        self.as_option = (default_auto_sharding_option or\n                          AutoShardingOption(prefer_reduce_scatter=True))\n        self.pipeline_schedule = pipeline_schedule\n        if layer_option == \"manual\":\n            layer_option = ManualLayerOption()\n        self.layer_option = layer_option or AutoLayerOption(layer_num=2)\n        if stage_option == \"auto\":\n     
       stage_option = AutoStageOption(\n                submesh_physical_shape_space=\"power_of_two\",\n                submesh_logical_shape_space=\"single_node_model_parallel\",\n                stage_imbalance_tolerance=np.inf,\n                use_hlo_cost_model=False,\n                profiling_database_filename=None,\n                cached_profile_result=None,\n            )\n        elif stage_option == \"uniform\":\n            stage_option = UniformStageOption()\n        self.stage_option = stage_option or UniformStageOption()\n        self.stage_input_shardings = stage_input_shardings\n        assert not (stage_input_shardings is not None and\n                    manual_sharding_option is not None)\n        self.manual_sharding_option = manual_sharding_option\n\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        # Resolve the polymorphism in arguments\n        if self.devices is None:\n            mesh = get_global_virtual_physical_mesh()\n            assert mesh is not None, (\n                \"Please run `alpa.init()` to initialize alpa.\")\n        else:\n            mesh = self.devices\n\n        assert isinstance(mesh, VirtualPhysicalMesh)\n\n        return compile_pipeshard_executable(\n            fun, in_tree, out_tree_thunk, static_argnums, donated_invars,\n            batch_invars, mesh, self.num_micro_batches, self.pipeline_schedule,\n            self.as_option, self.layer_option, self.stage_option, None,\n            self.stage_input_shardings, self.manual_sharding_option, *avals)\n\n\ndef get_3d_parallel_method(num_micro_batches: int,\n                           data_parallel: int,\n                           operator_parallel: int,\n                           
pipeline_parallel: int,\n                           allow_degenerate_into_shard_parallel: bool = True,\n                           manual_layer_num: int = None,\n                           manual_sharding_option: ManualShardingOption = None):\n    \"\"\"\n    Get a parallel method for 3D parallelism, which regularly combines\n    data parallelism, operator parallelism and pipeline parallelism.\n    \"\"\"\n    # Validity check\n    virtual_mesh = get_global_virtual_physical_mesh()\n    num_devices = virtual_mesh.num_devices\n    num_devices_per_host = virtual_mesh.num_devices_per_host\n    if data_parallel == -1:\n        data_parallel = (num_devices // operator_parallel // pipeline_parallel)\n    assert num_devices % data_parallel == 0\n    assert num_devices % operator_parallel == 0\n    assert num_devices % pipeline_parallel == 0\n    assert (num_devices == data_parallel * operator_parallel *\n            pipeline_parallel)\n    pp = pipeline_parallel\n\n    # Decide logical and physical mesh shapes\n    logical_mesh_shape = (data_parallel, operator_parallel)\n    num_mesh_devices = np.prod(logical_mesh_shape)\n    if num_mesh_devices <= num_devices_per_host:\n        physical_mesh_shape = (1, num_mesh_devices)\n    else:\n        assert num_mesh_devices % num_devices_per_host == 0\n        physical_mesh_shape = (num_mesh_devices // num_devices_per_host,\n                               num_devices_per_host)\n\n    # If no pipeline parallel, degenerate into shard parallel\n    if pp == 1 and allow_degenerate_into_shard_parallel:\n        return ShardParallel(num_micro_batches=num_micro_batches,\n                             auto_sharding_option=AutoShardingOption(\n                                 prefer_reduce_scatter=True,\n                                 force_batch_dim_to_mesh_dim=0),\n                             devices=get_global_physical_mesh(\n                                 create_if_not_exist=True).get_logical_mesh(\n                                  
   [data_parallel, operator_parallel]))\n\n    # Return pipeshard parallel\n    if manual_layer_num is not None:\n        assert manual_layer_num % pp == 0\n        layer_option = ManualLayerOption()\n        stage_option = UniformStageOption(pp, physical_mesh_shape,\n                                          logical_mesh_shape, {})\n    else:\n        layer_option = AutoLayerOption(layer_num=pp, eps=0.1)\n        stage_option = ManualStageOption(\n            forward_stage_layer_ids=[[i] for i in range(pp)],\n            submesh_physical_shapes=[physical_mesh_shape] * pp,\n            submesh_logical_shapes=[logical_mesh_shape] * pp,\n            submesh_autosharding_option_dicts=[{}] * pp)\n    return PipeshardParallel(\n        devices=virtual_mesh,\n        num_micro_batches=num_micro_batches,\n        default_auto_sharding_option=AutoShardingOption(\n            enable_auto_sharding=manual_sharding_option is None,\n            prefer_reduce_scatter=True,\n            force_batch_dim_to_mesh_dim=0,\n        ),\n        layer_option=layer_option,\n        stage_option=stage_option,\n        manual_sharding_option=manual_sharding_option)\n\n\nclass LocalPipelineParallel(ParallelMethod):\n    \"\"\"\n    Run pipeline parallel on a single device.\n    This is only used for debugging.\n    \"\"\"\n\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        return compile_local_pipeline_executable(fun, *avals)\n\n\nclass CreateStateParallel(ParallelMethod):\n    \"\"\"\n    Follow a train_step function to create the initial states distributedly.\n\n    Args:\n        train_step: The training step function.\n          See notes below for requirements.\n        other_args: Other arguments for calling 
the train_step function.\n\n    Notes:\n        To use this parallel method, the function being parallelized should\n        return a single output `state`. Then train_step should take `state`\n        as the first argument and `other_args` as successive arguments.\n        See `tests/test_create_state.py` for example usages.\n    \"\"\"\n\n    def __init__(self, train_step: \"ParallelizedFunc\",\n                 other_args: Sequence[Any]):\n        # pylint: disable=import-outside-toplevel\n        from alpa.api import ParallelizedFunc\n        assert isinstance(train_step, ParallelizedFunc)\n\n        self.train_step = train_step\n        self.other_args = other_args\n\n        # TODO(lmzheng): support more flexible signatures.\n        # For example, the state does not have to be the first argument.\n\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        return compile_create_state_executable(fun, in_tree, out_tree_thunk,\n                                               static_argnums, donated_invars,\n                                               self.train_step, self.other_args,\n                                               *avals)\n\n\nclass FollowParallel(ParallelMethod):\n    \"\"\"\n    Parallelize a function given its input placement specs.\n\n    Args:\n        num_micro_batches: The number of micro batches.\n        get_input_placement_specs: A callback function that returns\n          the input placement specs.\n        pipeline_schedule: The pipeline schedule.\n          Possible choices: {\"1f1b\", \"gpipe\", \"inference\"}\n        layer_option: Options of grouping basic operators to layers.\n          Possible choices: {\"auto\", \"manual\"}.\n    \"\"\"\n\n    def 
__init__(self,\n                 src_func: \"ParallelizedFunc\",\n                 num_micro_batches: Optional[int] = None,\n                 get_input_placement_specs: Callable = None,\n                 pipeline_schedule: str = \"inference\",\n                 layer_option: str = \"follow\"):\n        self.src_func = src_func\n        self.num_micro_batches = num_micro_batches\n\n        if get_input_placement_specs is None:\n\n            def default_get():\n                executable = src_func.get_last_executable()\n                input_placement_specs = executable.get_input_placement_specs()\n                train_state, batch = input_placement_specs\n                return train_state.params, batch\n\n            get_input_placement_specs = default_get\n\n        self.get_input_placement_specs = get_input_placement_specs\n        self.pipeline_schedule = pipeline_schedule\n        self.layer_option = layer_option\n\n    def compile_executable(\n        self,\n        fun: lu.WrappedFun,\n        in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef],\n        static_argnums: Sequence[int],\n        donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool],\n        *avals: Sequence[AbstractValue],\n    ):\n        input_placement_specs = self.get_input_placement_specs()\n        return compile_follow_parallel_executable(\n            fun, in_tree, out_tree_thunk, static_argnums, donated_invars,\n            batch_invars, self.src_func, self.num_micro_batches,\n            input_placement_specs, self.pipeline_schedule, self.layer_option,\n            *avals)\n"
  },
  {
    "path": "alpa/parallel_plan.py",
    "content": "\"\"\"\nThe data strcutures to save all configurations/strategies of\na parallel execution plan.\n\"\"\"\nfrom dataclasses import dataclass\nfrom typing import Sequence, Tuple\n\nimport numpy as np\nfrom jax.core import ShapedArray\nfrom jax.interpreters import pxla\n\n\n@dataclass\nclass PlacementSpec:\n    \"\"\"Specify how a tensor is stored distributedly.\"\"\"\n    aval: ShapedArray\n    mesh_ids: Sequence[int]\n    sharding_specs: Sequence[pxla.ShardingSpec]\n\n\n@dataclass\nclass StagePlan:\n    \"\"\"The parallel plan for a single sharded stage.\"\"\"\n    build_random_seed: int\n    logical_mesh_shape: Tuple[int]\n    all_gather_threshold: int\n    all_reduce_threshold: int\n    auto_sharding_option: \"AutoShardingOption\"\n    auto_sharding_solution_vector: np.ndarray\n    auto_sharding_objective: int\n\n\n@dataclass\nclass PipelinePlan:\n    \"\"\"The parallel plan for a pipeline.\"\"\"\n    pipeline_schedule: str\n    layer_option: \"LayerOption\"\n    manual_stage_option: \"ManualStageOption\"\n\n\n@dataclass\nclass ClusterInfo:\n    num_hosts: int\n    num_devices_per_host: int\n\n\n@dataclass\nclass ParallelPlan:\n    \"\"\"The global parallel plan.\"\"\"\n    cluster_info: ClusterInfo\n    num_micro_batches: int\n    auto_sharding_option: \"AutoShardingOption\"\n    pipeline_plan: PipelinePlan\n    input_placement_specs: Sequence[PlacementSpec]\n\n\ndef plan_to_method(plan: ParallelPlan) -> \"ParallelMethod\":\n    \"\"\"Convert a parallel plan to a parallel method.\"\"\"\n    # pylint: disable=import-outside-toplevel\n    from alpa.parallel_method import ShardParallel, PipeshardParallel\n\n    if plan.pipeline_plan is None:\n        return ShardParallel(num_micro_batches=plan.num_micro_batches,\n                             auto_sharding_option=plan.auto_sharding_option)\n    else:\n        return PipeshardParallel(\n            num_micro_batches=plan.num_micro_batches,\n            
default_auto_sharding_option=plan.auto_sharding_option,\n            pipeline_schedule=plan.pipeline_plan.pipeline_schedule,\n            layer_option=plan.pipeline_plan.layer_option,\n            stage_option=plan.pipeline_plan.manual_stage_option)\n"
  },
  {
    "path": "alpa/pipeline_parallel/__init__.py",
    "content": ""
  },
  {
    "path": "alpa/pipeline_parallel/apply_grad.py",
    "content": "\"\"\"Transformations and utilities to process gradient accumulation and\napply_gradient.\"\"\"\nimport logging\nfrom typing import Sequence, Dict, Tuple\n\nfrom jax._src.util import safe_map\nfrom jax.core import (Primitive, Var, Jaxpr, ClosedJaxpr, DropVar, Literal,\n                      get_aval, raise_to_shaped, JaxprEqn)\nfrom jax.interpreters import xla\nfrom jax.lax import add_p, div_p, and_p, or_p\nfrom jaxlib import xla_client as xc\nimport numpy as np\n\nfrom alpa.pipeline_parallel.computation import JaxPipelineComputation\nfrom alpa.pipeline_parallel.primitive_def import (pipeline_p,\n                                                  mark_pipeline_jaxpreqn)\nfrom alpa.util import (OrderedSet, clone_jaxpr, clone_jaxpr_eqn,\n                       get_var_mapping, mesh_ids_hash, new_jaxpr_eqn,\n                       slices_to_jaxpr)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n# pylint: disable=redefined-builtin\nunsafe_map, map = map, safe_map  # type: ignore\nAPPLY_GRAD_MARKER_SUFFIX = 'apply_grad'\n\n\ndef _filter_literal(vars):\n    return [v for v in vars if isinstance(v, Var)]\n\n\ndef _filter_droped(vars):\n    return [v for v in vars if not isinstance(v, DropVar)]\n\n\ndef _pipeline_marker_analysis(compute_eqns):\n    \"\"\"Get vars as inputs and outputs of layers\"\"\"\n    layer_invars = set()\n    pipeline_outvars = {}\n    marker_cnt = 0\n    for eqn in compute_eqns:\n        if eqn.primitive is pipeline_p:\n            if eqn.params['mark_type'] == 'end':\n                for v in _filter_droped(eqn.outvars):\n                    pipeline_outvars[v] = marker_cnt\n                marker_cnt += 1\n            elif eqn.params['mark_type'] == 'start':\n                layer_invars.update(_filter_literal(eqn.invars))\n    return layer_invars, pipeline_outvars\n\n\ndef _insert_to_pipeline_marker(marker, new_inv, mapping):\n    invs = list(marker.invars)\n    outvs = list(marker.outvars)\n    for inv in 
new_inv:\n        invs.append(inv)\n        outvs.append(mapping[inv])\n    return clone_jaxpr_eqn(marker, invs, outvs)\n\n\ndef _rewrite_compute_eqns(eqns, eqn_moved_to, gensym_fn):\n    \"\"\"Insert unmarked eqns(eqn_moved_to) to compute eqn sequence.\"\"\"\n    marker_cnt = 0\n    new_eqns = []\n    for eqn in eqns:\n        if eqn.primitive is not pipeline_p:\n            pass\n        elif eqn.params['mark_type'] == 'start':\n            cur_pipeline_start_idx = len(new_eqns)\n        elif marker_cnt not in eqn_moved_to:\n            marker_cnt += 1\n        else:\n            appended_eqns = eqn_moved_to[marker_cnt]\n            i_marker = new_eqns[cur_pipeline_start_idx]\n            o_marker = eqn\n            layer_invar_map = {\n                inv: outv\n                for inv, outv in zip(i_marker.invars, i_marker.outvars)\n                if isinstance(inv, Var) and not isinstance(outv, DropVar)\n            }\n            layer_outvar_map = {\n                outv: inv\n                for inv, outv in zip(o_marker.invars, o_marker.outvars)\n                if isinstance(inv, Var) and not isinstance(outv, DropVar)\n            }\n            # collect and create all vars, then rewrite and create eqns\n            inserted_invars = OrderedSet()\n            inserted_outvars = OrderedSet()\n            for eq in appended_eqns:\n                # collect and create all used and output vars\n                eq_new_invs = []\n                for inv in eq.invars:\n                    if isinstance(inv, Var):\n                        if inv in layer_outvar_map:\n                            # this layer defines the invar, use pre-marker ver.\n                            eq_new_invs.append(layer_outvar_map[inv])\n                        else:\n                            if inv not in layer_invar_map:\n                                # add new invar from other layers\n                                layer_invar_map[inv] = gensym_fn(inv.aval)\n                
                inserted_invars.add(inv)\n                            eq_new_invs.append(layer_invar_map[inv])\n                    else:\n                        eq_new_invs.append(inv)\n                eq_new_outvs = []\n                for outv in eq.outvars:\n                    if isinstance(outv, DropVar):\n                        eq_new_outvs.append(outv)\n                    else:\n                        new_mapped = gensym_fn(outv.aval)\n                        layer_outvar_map[outv] = new_mapped\n                        inserted_outvars.add(new_mapped)\n                        eq_new_outvs.append(new_mapped)\n                # create the new eqn\n                new_eqns.append(clone_jaxpr_eqn(eq, eq_new_invs, eq_new_outvs))\n\n            # create the new in marker\n            new_eqns[cur_pipeline_start_idx] = _insert_to_pipeline_marker(\n                i_marker, inserted_invars, layer_invar_map)\n            layer_outvar_map = {v: k for k, v in layer_outvar_map.items()}\n            eqn = _insert_to_pipeline_marker(o_marker, inserted_outvars,\n                                             layer_outvar_map)\n            marker_cnt += 1\n\n        new_eqns.append(eqn)\n    return new_eqns\n\n\ndef _get_delayed_eqns(compute_eqns, layer_invars, pipeline_outvars, gensym_fn):\n    \"\"\"\n    Get eqns that can be delayed to apply gradient stage and rewrite eqns that\n    cannot do so by moving them into a layer.\n\n    An example of cannot delayed vars is: x is computed in layer0, and sent to\n    layer1 and layer2. 
There is grad(x) = grad_1(x) + grad_2(x), but the\n    grad(weight) depends on grad(x) and is in the acc_grad period, so we cannot\n    delay it to the apply_grad period.\n    \"\"\"\n    cross_layer_grad_eqns = []\n    new_compute_eqns = []\n    moved_to_layer_eqns = []\n\n    marked_vars = set()\n    used_vars = set()\n    out_marker = True\n    for eqn in reversed(compute_eqns):\n        invars = _filter_literal(eqn.invars)\n        outvars = _filter_droped(eqn.outvars)\n        used_outvars = used_vars.intersection(outvars)\n        if eqn.primitive is pipeline_p:\n            # invars of a pipeline end marker is marked\n            if eqn.params['mark_type'] == 'end':\n                marked_vars.update(invars)\n                out_marker = False\n            else:\n                out_marker = True\n            new_compute_eqns.append(eqn)\n        else:\n            # we don't want to do dce here, because it may make its operand be\n            # considered as cross layer grad, and then moved across microbatch\n            # boundary, which is harder to analyze.\n            if len(outvars) == 0 and out_marker:\n                continue\n            # only if an eqn is not used and is out marker will be it moved\n            # after microbatch boundary. 
Those inside a microbatch boundary is\n            # handled by later DCE.\n            elif not used_outvars and out_marker:\n                cross_layer_grad_eqns.append(eqn)\n                continue\n            elif marked_vars.issuperset(used_outvars):\n                # eqn is marked if all outvars are marked, then mark its invars.\n                marked_vars.update(invars)\n                new_compute_eqns.append(eqn)\n            else:\n                assert not marked_vars.intersection(\n                    outvars), f\"'{eqn}' is partially marked.\"\n                if layer_invars.intersection(outvars):\n                    # move the marked var to the latest stage producing some of\n                    # its invars.\n                    moved_to_layer_eqns.append(eqn)\n                    # update layer invars and marked vars.\n                    layer_invars.update(invars)\n                    marked_vars.update(outvars)\n                else:\n                    cross_layer_grad_eqns.append(eqn)\n                    continue\n        used_vars.update(invars)\n\n    new_compute_eqns = list(reversed(new_compute_eqns))\n    cross_layer_grad_eqns = list(reversed(cross_layer_grad_eqns))\n    eqn_moved_to = {}\n    for eqn in reversed(moved_to_layer_eqns):\n        invars = _filter_literal(eqn.invars)\n        outvars = _filter_droped(eqn.outvars)\n        moved_to = max(pipeline_outvars[v] for v in invars)\n        eqn_moved_to.setdefault(moved_to, []).append(eqn)\n        pipeline_outvars.update({v: moved_to for v in outvars})\n    if eqn_moved_to:\n        new_compute_eqns = _rewrite_compute_eqns(new_compute_eqns, eqn_moved_to,\n                                                 gensym_fn)\n    return cross_layer_grad_eqns, new_compute_eqns\n\n\ndef _rewrite_microbatch_bound(microbatch_bound, delayed_eqns, gensym_fn):\n    \"\"\"\n    Rewrite the microbatch bound because some eqns are moved from microbatched\n    part of the graph to non-microbatched 
part.\n    \"\"\"\n    microbatch_bound_in_to_outs = {}\n    for invar, outvar in zip(microbatch_bound.invars, microbatch_bound.outvars):\n        if isinstance(invar, Var) and not isinstance(outvar, DropVar):\n            microbatch_bound_in_to_outs[invar] = outvar\n    delayed_invars = OrderedSet()\n    delayed_outvars = OrderedSet()\n    for eqn in delayed_eqns:\n        delayed_invars.update(_filter_literal(eqn.invars))\n        delayed_outvars.update(_filter_droped(eqn.outvars))\n    delayed_invars.difference_update(delayed_outvars)\n    delayed_invars.difference_update(microbatch_bound_in_to_outs.keys())\n    delayed_outvars.intersection_update(microbatch_bound_in_to_outs.keys())\n    for invar in delayed_invars:\n        microbatch_bound_in_to_outs[invar] = gensym_fn(invar.aval)\n    # rewrite the microbatch_bound\n    new_microbatch_bound_invars = []\n    new_microbatch_bound_outvars = []\n    for idx, var in enumerate(microbatch_bound.invars + list(delayed_invars)):\n        # remove vars now defined after microbatch_bound.\n        if isinstance(var, Var) and var in delayed_outvars:\n            continue\n        new_microbatch_bound_invars.append(var)\n        # add vars now used after microbatch_bound.\n        new_microbatch_bound_outvars.append(\n            microbatch_bound.outvars[idx] if idx < len(microbatch_bound.invars)\n            else microbatch_bound_in_to_outs[var])\n    new_microbatch_bound = clone_jaxpr_eqn(microbatch_bound,\n                                           new_microbatch_bound_invars,\n                                           new_microbatch_bound_outvars)\n    return new_microbatch_bound, microbatch_bound_in_to_outs\n\n\ndef _rewrite_delayed_gradient_sum_eqns(delayed_eqns,\n                                       microbatch_bound_in_to_outs):\n    \"\"\"Change args of eqns that are delayed to the non-microbatched part.\"\"\"\n    new_apply_eqns = []\n    for eqn in delayed_eqns:\n        invars = [\n            
microbatch_bound_in_to_outs[var] if isinstance(var, Var) and\n            var in microbatch_bound_in_to_outs else var for var in eqn.invars\n        ]\n        outvars = [\n            microbatch_bound_in_to_outs[var] if not isinstance(var, DropVar) and\n            var in microbatch_bound_in_to_outs else var for var in eqn.outvars\n        ]\n        new_apply_eqns.append(clone_jaxpr_eqn(eqn, invars, outvars))\n    return new_apply_eqns\n\n\ndef _value_to_literal(value, dtype):\n    literal_val = np.array(value, dtype)\n    return Literal(literal_val, raise_to_shaped(get_aval(literal_val)))\n\n\n# TODO(yonghao): delaying the cross layer grad accmulation increases memory\n# cost, but may not decrease communication: if c=a+b is delayed, both a and\n# b are accumulated, so the memory cost is more than when only accumulate c.\n# If layer that outputs a(called layer_a, and the same applys for b) is\n# merged with layer_b to the same stage, they do not need any communication,\n# so the communication does not benefit from the rewrite.\ndef _rewrite_cross_layer_grad(compute_eqns, microbatch_bound, apply_eqns,\n                              gensym_fn, closed_jaxpr):\n    \"\"\"\n    If a parameter is used in multiple stages, its gradient is computed in\n    multiple stages and then added together. 
We accumulate the results on each\n    stage, and add them together exactly at the start of apply grad period.\n\n    A common use case is the tied embedding in language models.\n    \"\"\"\n    layer_invars, pipeline_outvars = _pipeline_marker_analysis(compute_eqns)\n    # Those eqn directly use output of pipeline end is delayed to apply grad.\n    cross_layer_grad_eqns, new_compute_eqns = _get_delayed_eqns(\n        compute_eqns, layer_invars, pipeline_outvars, gensym_fn)\n    # Rewrite microbatch_bound and cross_layer_grad eqns.\n    (new_microbatch_bound,\n     microbatch_bound_in_to_outs) = _rewrite_microbatch_bound(\n         microbatch_bound, cross_layer_grad_eqns, gensym_fn)\n    # rewrite cross layer grad eqns and insert them to the top of apply eqns.\n    new_apply_eqns = _rewrite_delayed_gradient_sum_eqns(\n        cross_layer_grad_eqns, microbatch_bound_in_to_outs)\n    new_apply_eqns += apply_eqns\n    new_global_outvars = list(closed_jaxpr.jaxpr.outvars)\n    for idx in range(len(new_global_outvars)):\n        var = new_global_outvars[idx]\n        if isinstance(var, Literal):\n            continue\n        if isinstance(var, Var) and var in microbatch_bound_in_to_outs:\n            new_global_outvars[idx] = microbatch_bound_in_to_outs[var]\n    closed_jaxpr = clone_jaxpr(closed_jaxpr,\n                               eqns=new_compute_eqns + [new_microbatch_bound] +\n                               new_apply_eqns,\n                               outvars=new_global_outvars)\n    return closed_jaxpr\n\n\ndef _remove_replicated_marked_var(closed_jaxpr: ClosedJaxpr):\n    \"\"\"Some variables are marked multiple times with the same marker.\n    This pass removes them.\n    \"\"\"\n    new_eqns = []\n    var_map = {}\n    mb_idx = None\n    for eqn in closed_jaxpr.eqns:\n        if eqn.primitive == pipeline_p:\n            eqn_map = {}\n            new_invars = []\n            new_outvars = []\n            if eqn.params['mark_type'] == 'grad':\n              
  mb_idx = len(new_eqns)\n            for inv, outv in zip(eqn.invars, eqn.outvars):\n                if isinstance(outv, DropVar):\n                    continue\n                if isinstance(inv, Var):\n                    if inv in var_map:\n                        var_map[outv] = var_map[inv]\n                        continue\n                    elif inv in eqn_map:\n                        var_map[outv] = eqn_map[inv]\n                        continue\n                if isinstance(inv, Var):\n                    eqn_map[inv] = outv\n                new_invars.append(inv)\n                new_outvars.append(outv)\n            new_eqns.append(clone_jaxpr_eqn(eqn, new_invars, new_outvars))\n            continue\n        new_invars = [get_var_mapping(var_map, v) for v in eqn.invars]\n        new_eqns.append(clone_jaxpr_eqn(eqn, new_invars))\n    sliced_eqns = new_eqns[:mb_idx], [new_eqns[mb_idx]], new_eqns[mb_idx + 1:]\n    new_outvars = [\n        get_var_mapping(var_map, v) for v in closed_jaxpr.jaxpr.outvars\n    ]\n    return clone_jaxpr(closed_jaxpr, outvars=new_outvars,\n                       eqns=new_eqns), sliced_eqns\n\n\ndef jaxpr_have_apply_grad(closed_jaxpr: ClosedJaxpr):\n    \"\"\"Returns True if the jaxpr has apply_grad.\"\"\"\n    return any(eqn.primitive is pipeline_p and eqn.params['mark_type'] == 'grad'\n               for eqn in closed_jaxpr.eqns)\n\n\ndef split_compute_grad_and_apply_grad(closed_jaxpr: ClosedJaxpr, gensym_fn,\n                                      num_microbatch: int,\n                                      inference_mode: bool):\n    \"\"\"Split the train_step jaxpr into two parts: compute_grad and\n    apply_grad. 
These two parts are separated by a gradient marker generated\n    by `alpa.grad`.\"\"\"\n    # Locate the marker\n    split_eqn = None\n    for idx, eqn in enumerate(closed_jaxpr.eqns):\n        if eqn.primitive is pipeline_p and eqn.params['mark_type'] == 'grad':\n            split_eqn = eqn\n            split_idx = idx\n    if split_eqn is None:\n        if not inference_mode:\n            logger.warning(\n                'Missing microbatch_bound between compute and apply. '\n                'Assume there is no apply gradient step. '\n                'Hint: replace jax.grad by alpa.grad.')\n        dummy_jaxpr = ClosedJaxpr(Jaxpr([], [], [], []), [])\n        invars = list(closed_jaxpr.jaxpr.outvars) if num_microbatch > 1 else []\n        outvars = list(closed_jaxpr.jaxpr.outvars) if num_microbatch > 1 else []\n        dummy_bound = new_jaxpr_eqn(invars, outvars, pipeline_p, {\n            'mark_type': 'grad',\n            'name': ''\n        })\n        return closed_jaxpr, closed_jaxpr, dummy_jaxpr, dummy_bound\n    sliced_eqns = [\n        closed_jaxpr.eqns[:split_idx], split_eqn,\n        closed_jaxpr.eqns[split_idx + 1:]\n    ]\n    # Some equations are not marked. This pass moves them either into apply grad\n    # or a layer.\n    closed_jaxpr = _rewrite_cross_layer_grad(*sliced_eqns, gensym_fn,\n                                             closed_jaxpr)\n    closed_jaxpr, sliced_eqns = _remove_replicated_marked_var(closed_jaxpr)\n    # Reconstruct jaxpr\n    sliced_jaxprs = slices_to_jaxpr(closed_jaxpr, sliced_eqns)\n    compute_grad, _, apply_grad = sliced_jaxprs  # pylint: disable=unbalanced-tuple-unpacking\n    split_eqn = sliced_eqns[1][0]\n    if len(apply_grad.eqns) == 0:\n        logger.warning(\n            'the apply gradient part is empty. 
Hint: apply() after alpa.grad')\n    assert len(split_eqn.invars) == len(split_eqn.outvars)\n    invars_without_dropvar = []\n    outvars_without_dropvar = []\n    for invar, outvar in zip(split_eqn.invars, split_eqn.outvars):\n        if not isinstance(outvar, DropVar):\n            invars_without_dropvar.append(invar)\n            outvars_without_dropvar.append(outvar)\n    split_eqn = clone_jaxpr_eqn(split_eqn, invars_without_dropvar,\n                                outvars_without_dropvar)\n    return closed_jaxpr, compute_grad, apply_grad, split_eqn\n\n\ndef _get_post_to_pre_marker_mapping(compute_jaxpr):\n    \"\"\"\n    Get a dict that maps an out_var of a pipeline marker to\n    its corresponding in_var.\n    \"\"\"\n    post_marker_outs = _filter_droped(compute_jaxpr.jaxpr.outvars)\n    # Currently, assume no grad is literal\n    assert len(post_marker_outs) == len(compute_jaxpr.jaxpr.outvars)\n    post_marker_outs = OrderedSet(post_marker_outs)\n    # from post_marker_outs to post_to_pre_marker_outs(cross pipeline marker)\n    post_to_pre_marker_outs = {}\n    pre_to_post_marker_outs = {}\n    for eqn in reversed(compute_jaxpr.eqns):\n        if eqn.primitive is pipeline_p:\n            for i, outvar in enumerate(eqn.outvars):\n                if outvar in post_marker_outs:\n                    post_to_pre_marker_outs[outvar] = eqn.invars[i]\n                    pre_to_post_marker_outs[eqn.invars[i]] = outvar\n                elif outvar in pre_to_post_marker_outs:\n                    # in case that:\n                    #   invar = compute gradient\n                    #   invar' = pipeline end(invar)\n                    #   outvar = pipeline start(invar')\n                    #   final = pipeline end(outvar)\n                    # post_to_pre_marker_outs[final] = invar' instead of outvar\n                    final_outvar = pre_to_post_marker_outs[outvar]\n                    post_to_pre_marker_outs[final_outvar] = eqn.invars[i]\n                    
pre_to_post_marker_outs[eqn.invars[i]] = final_outvar\n    for outvar in post_marker_outs:\n        assert outvar in post_to_pre_marker_outs, (\n            'all outputs should be captured by pipeline marker')\n    return post_to_pre_marker_outs\n\n\ndef _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr, to_reduce_pre_marker_outs,\n                                      reduce_invars, reduce_outvars, gensym_fn):\n    new_eqns = []\n    pipe_start = None\n    pipe_eqns = []\n    to_acc = []\n    to_reduce_pre_marker_outs = OrderedSet(to_reduce_pre_marker_outs)\n    for eqn in compute_jaxpr.eqns:\n        if eqn.primitive is pipeline_p:\n            if eqn.params['mark_type'] == 'start':\n                pipe_start = eqn\n                for outvar in eqn.outvars:\n                    if (not isinstance(outvar, DropVar) and\n                            outvar in to_reduce_pre_marker_outs):\n                        # collect to_reduce_pre_marker_outs in this computation\n                        to_acc.append(outvar)\n                continue\n            if eqn.params['mark_type'] == 'end':\n                # add grad used in this computation in pipeline start\n                reduce_invar_post_pipe = {\n                    outvar: gensym_fn(outvar.aval) for outvar in to_acc\n                }\n                reduce_outvar_pre_pipe = {\n                    outvar: gensym_fn(outvar.aval) for outvar in to_acc\n                }\n                new_pipe_start = mark_pipeline_jaxpreqn(\n                    pipe_start.invars + map(lambda x: reduce_invars[x], to_acc),\n                    pipe_start.outvars +\n                    # pylint: disable=cell-var-from-loop\n                    map(lambda x: reduce_invar_post_pipe[x], to_acc),\n                    pipe_start.params['name'],\n                    pipe_start.params['mark_type'])\n                new_eqns.append(new_pipe_start)\n                # add normal eqns\n                new_eqns.extend(pipe_eqns)\n              
  # add acc grad(adds)\n                for gradient in to_acc:\n                    new_eqns.append(\n                        new_jaxpr_eqn(\n                            [reduce_invar_post_pipe[gradient], gradient],\n                            [reduce_outvar_pre_pipe[gradient]], add_p, {}))\n                # add grad created in this computation in pipeline end\n                new_pipe_end = mark_pipeline_jaxpreqn(\n                    # pylint: disable=cell-var-from-loop\n                    eqn.invars +\n                    map(lambda x: reduce_outvar_pre_pipe[x], to_acc),\n                    eqn.outvars + map(lambda x: reduce_outvars[x], to_acc),\n                    eqn.params['name'],\n                    eqn.params['mark_type'])\n                new_eqns.append(new_pipe_end)\n                pipe_start = None\n                pipe_eqns = []\n                to_acc = []\n                continue\n        pipe_eqns.append(eqn)\n        for outvar in eqn.outvars:\n            if (not isinstance(outvar, DropVar) and\n                    outvar in to_reduce_pre_marker_outs):\n                # collect to_reduce_pre_marker_outs in this computation\n                to_acc.append(outvar)\n    return new_eqns\n\n\n# TODO(yonghao): support not only reduction and concate. 
Some outputs may not\n# rely on batch dimension.\ndef compute_grad_to_accumulate_grad(\n        compute_jaxpr: ClosedJaxpr, microbatch_bound: JaxprEqn,\n        reduction_vector: Sequence[bool], gensym_fn,\n        num_microbatch) -> Tuple[ClosedJaxpr, JaxprEqn, Dict[Var, Var]]:\n    \"\"\"Transform compute_grad jaxpr with pipeline markers into accumulate_grad\n    jaxpr.\n\n    Args:\n        compute_jaxpr: the original jaxpr\n        microbatch_bound: The boundary eqn that separates compute_grad and\n          apply_grad.\n        reduction_vector: if the outvar is reduced(accumulated) or not\n        gensym_fn: gensym function\n\n    Returns:\n        acc_grad_jaxpr: The accumulate grad jaxpr\n        microbatch_bound: The updated microbatch boundary\n        reduced_in_to_out: From accumulated gradient inputs to outputs\n    \"\"\"\n    if num_microbatch <= 1:\n        return compute_jaxpr, microbatch_bound, {}\n\n    post_to_pre_marker_outs = _get_post_to_pre_marker_mapping(compute_jaxpr)\n    to_reduce_pre_marker_outs = []\n    for var, reduced in zip(compute_jaxpr.jaxpr.outvars, reduction_vector):\n        if reduced:\n            to_reduce_pre_marker_outs.append(post_to_pre_marker_outs[var])\n    # generate new variables\n    reduced_invars = {\n        outvar: gensym_fn(outvar.aval) for outvar in to_reduce_pre_marker_outs\n    }\n    reduced_outvars = {\n        outvar: gensym_fn(outvar.aval) for outvar in to_reduce_pre_marker_outs\n    }\n    # modify output, here all grads are acc_grad\n    new_glob_outvars = []\n    new_glob_invars = compute_jaxpr.jaxpr.invars + []\n    update_outs = {}\n    reduced_in_to_out = {}\n    for outvar, reduced in zip(compute_jaxpr.jaxpr.outvars, reduction_vector):\n        if not reduced:\n            new_glob_outvars.append(outvar)\n            update_outs[outvar] = outvar\n        elif isinstance(outvar, Var):\n            assert outvar in post_to_pre_marker_outs\n            pre_marker_outvar = 
post_to_pre_marker_outs[outvar]\n            reduced_outvar = reduced_outvars[pre_marker_outvar]\n            reduced_invar = reduced_invars[pre_marker_outvar]\n\n            new_glob_outvars.append(reduced_outvar)\n            new_glob_invars.append(reduced_invar)\n            update_outs[outvar] = reduced_outvar\n            reduced_in_to_out[reduced_invar] = reduced_outvar\n        else:\n            raise NotImplementedError('outputs cannot be Literal')\n    # rewrite eqns\n    new_eqns = _rewrite_jaxpr_to_reduced_outputs(compute_jaxpr,\n                                                 to_reduce_pre_marker_outs,\n                                                 reduced_invars,\n                                                 reduced_outvars, gensym_fn)\n\n    new_closed_jaxpr = clone_jaxpr(compute_jaxpr, new_glob_invars,\n                                   new_glob_outvars, new_eqns)\n\n    microbatch_bound_invars = [update_outs[x] for x in microbatch_bound.invars]\n    microbatch_bound = clone_jaxpr_eqn(microbatch_bound,\n                                       microbatch_bound_invars)\n    return new_closed_jaxpr, microbatch_bound, reduced_in_to_out\n\n\ndef _get_apply_grad_outvar_constraints(pipeline_stages, stage_to_mesh,\n                                       global_invars, donated_invars,\n                                       donation_mapping):\n    \"\"\"Infer outvar constraints of apply gradient based on donation.\"\"\"\n    outvar_mesh = {}\n    donated_global_vars = {\n        invar for invar, donate in zip(global_invars, donated_invars) if donate\n    }\n    for stage_idx, stage in enumerate(pipeline_stages):\n        for invar in stage.invars:\n            if invar in donated_global_vars:\n                outvar_mesh.setdefault(donation_mapping[invar],\n                                       OrderedSet()).add(\n                                           stage_to_mesh[stage_idx])\n    return outvar_mesh\n\n\ndef 
process_apply_gradient(apply_grad_jaxpr, microbatch_bound, pipeline_stages,\n                           stage_to_mesh, gensym_func, num_meshes,\n                           global_invars, global_outvars, donated_invars,\n                           profiling, mesh_num_devices):\n    \"\"\"Slice apply_grad jaxpr into stages and assign them to the corresponding\n    meshes.\"\"\"\n    # Process apply gradient:\n    # change invars of apply grad to outvars of accumulate grad\n    gradients = microbatch_bound.outvars\n    apply_in_to_acc_out = dict(zip(gradients, microbatch_bound.invars))\n\n    gradvar_to_mesh = get_var_to_mesh(gradients, pipeline_stages, stage_to_mesh,\n                                      apply_in_to_acc_out)\n\n    # update donation mapping\n    donation_mapping = {}\n    for idx, invar in enumerate(global_invars):\n        if donated_invars[idx]:\n            donation_mapping[invar] = global_outvars[idx]\n    # create outvar constraints\n    outvar_mesh = _get_apply_grad_outvar_constraints(pipeline_stages,\n                                                     stage_to_mesh,\n                                                     global_invars,\n                                                     donated_invars,\n                                                     donation_mapping)\n\n    sliced_apply_grad_stages, apply_grad_placement, allreduce_groups = (\n        slice_apply_gradient(apply_grad_jaxpr, gradvar_to_mesh, outvar_mesh,\n                             num_meshes, len(pipeline_stages), donation_mapping,\n                             gensym_func, profiling, mesh_num_devices))\n    sliced_apply_grad_stages, out_map = apply_grad_add_marker(\n        sliced_apply_grad_stages,\n        apply_in_to_acc_out,\n        gensym_func,\n        computation=True)\n    global_outvars = [get_var_mapping(out_map, var) for var in global_outvars]\n\n    return (sliced_apply_grad_stages, apply_grad_placement, global_outvars,\n            
allreduce_groups)\n\n\ndef replace_all_with(closed_jaxpr: ClosedJaxpr, mapping):\n    \"\"\"Replace all variables in a jaxpr given the mapping.\"\"\"\n\n    def map_var(var):\n        return get_var_mapping(mapping, var)\n\n    new_glob_invars = [map_var(var) for var in closed_jaxpr.jaxpr.invars]\n    new_glob_outvars = [map_var(var) for var in closed_jaxpr.jaxpr.outvars]\n    new_eqns = []\n    for eqn in closed_jaxpr.eqns:\n        new_invars = [map_var(var) for var in eqn.invars]\n        new_outvars = [map_var(var) for var in eqn.outvars]\n        new_eqns.append(clone_jaxpr_eqn(eqn, new_invars, new_outvars))\n    new_jaxpr = clone_jaxpr(closed_jaxpr, new_glob_invars, new_glob_outvars,\n                            new_eqns)\n    return new_jaxpr\n\n\ndef apply_grad_get_mean(apply_grad_jaxpr, global_outvars, gradients, gensym_fn,\n                        num_microbatch, reduce_invars):\n    \"\"\"\n    Get the mean of input (accumulated) gradients and run apply gradient.\n\n    If the input is output, after this transform it outputs the divided version.\n    \"\"\"\n    mapping = {}\n    new_eqns = []\n    invar_set = OrderedSet(apply_grad_jaxpr.jaxpr.invars)\n    outvar_set = OrderedSet(apply_grad_jaxpr.jaxpr.outvars)\n    for invar, reduce in zip(gradients, reduce_invars):\n        if not reduce:\n            mapping[invar] = invar\n            continue\n        div_out = gensym_fn(invar.aval)\n        new_eqns.append(\n            new_jaxpr_eqn([\n                invar,\n                _value_to_literal(num_microbatch, invar.aval.dtype),\n            ], [div_out], div_p, {}))\n        mapping[invar] = div_out\n    replaced = replace_all_with(apply_grad_jaxpr, mapping)\n    final_invars = list(apply_grad_jaxpr.jaxpr.invars)\n    final_outvars = list(replaced.jaxpr.outvars)\n    for invar, reduce in zip(gradients, reduce_invars):\n        if not reduce:\n            continue\n        if invar not in invar_set:\n            final_invars.append(invar)\n        
if invar in global_outvars and invar not in outvar_set:\n            # use the divided version to replace the original one\n            final_outvars.append(mapping[invar])\n    new_eqns.extend(replaced.jaxpr.eqns)\n    new_jaxpr = clone_jaxpr(apply_grad_jaxpr, final_invars, final_outvars,\n                            new_eqns)\n    global_outvars = [get_var_mapping(mapping, var) for var in global_outvars]\n    return new_jaxpr, global_outvars\n\n\ncross_mesh_allreduce_p = Primitive('__builtin$CrossMeshAllReduce')\n_primitive_to_str = {add_p: b'SUM', and_p: b'AND', or_p: b'OR'}\n\n\ndef _cross_mesh_allreduce_xla_translation(c, *args, **kwargs):\n    call_name = b'__builtin$CrossMeshAllReduce'\n    assert len(args) == 1\n    input_params = args[0]\n    input_shape = c.get_shape(input_params)\n    op_type = _primitive_to_str[kwargs['type']]\n    opaque = op_type + b';' + mesh_ids_hash(kwargs['group_meshes'])\n\n    # TODO(yonghao): the has_side_effect is to prevent CSE of the allreduce.\n    # It might be replaced by adding its outvar to output\n    sharding = xc.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    c.set_sharding(sharding)\n    output = xc.ops.CustomCall(c,\n                               call_name,\n                               operands=(input_params,),\n                               shape=input_shape,\n                               has_side_effect=True,\n                               opaque=opaque)\n    c.clear_sharding()\n    return output\n\n\nxla.translations[cross_mesh_allreduce_p] = _cross_mesh_allreduce_xla_translation\n\n\ndef _init_eqn_var_mesh(closed_jaxpr, var_mesh):\n    eqn_mesh = []\n    var_mesh = dict(var_mesh)\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        eqn_mesh.append(OrderedSet())\n        for var in eqn.invars:\n            if isinstance(var, Var):\n                var_mesh.setdefault(var, OrderedSet())\n        for var in eqn.outvars:\n            if not isinstance(var, DropVar):\n             
   var_mesh.setdefault(var, OrderedSet())\n        if eqn.primitive != cross_mesh_allreduce_p:\n            continue\n        mesh_ids = eqn.params['group_meshes']\n        for var, mesh_id in zip(eqn.invars, mesh_ids):\n            var_mesh[var].add(mesh_id)\n        var_mesh[eqn.outvars[0]] = OrderedSet(mesh_ids)\n        eqn_mesh[eqn_idx] = OrderedSet(mesh_ids)\n    return eqn_mesh, var_mesh\n\n\ndef _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh):\n    changed = False\n    for invar in closed_jaxpr.jaxpr.invars:\n        if invar in donation_mapping:\n            outvar = donation_mapping[invar]\n            outvar_at = var_mesh[outvar]\n            invar_at = var_mesh[invar]\n            if invar_at.difference(outvar_at):\n                outvar_at.update(invar_at)\n                changed = True\n            if outvar_at.difference(invar_at):\n                invar_at.update(outvar_at)\n    return changed\n\n\ndef _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping, eqn_mesh,\n                                   var_mesh):\n    \"\"\"Propagate var_at_mesh from output to make sure all operands are ready.\"\"\"\n    # Different from forward propagation, the eqn should be at to any mesh of\n    # any outvar. 
Now the semantic switches from 'can be at' to 'is at'\n    changed = False\n    for reversed_idx, eqn in enumerate(reversed(closed_jaxpr.eqns)):\n        eqn_idx = len(closed_jaxpr.eqns) - 1 - reversed_idx\n        post_at_mesh = eqn_mesh[eqn_idx]\n        at_mesh = OrderedSet()\n        for outvar in eqn.outvars:\n            if not isinstance(outvar, DropVar):\n                at_mesh.update(var_mesh[outvar])\n        if not at_mesh:\n            continue\n        if (not post_at_mesh or at_mesh.difference(post_at_mesh)):\n            changed = True\n            post_at_mesh.update(at_mesh)\n            if eqn.primitive != cross_mesh_allreduce_p:\n                for invar in eqn.invars:\n                    if isinstance(invar, Var):\n                        var_mesh[invar].update(at_mesh)\n    changed |= _propagate_with_donation(closed_jaxpr, donation_mapping,\n                                        var_mesh)\n    return changed\n\n\ndef _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, aggressive):\n    \"\"\"\n    Propagate the can/may be at info for eqns and vars not yet allocated.\n\n    Can at mode is conservative. It computes the intersection of all invars'\n    meshes. When var_0 is at mesh_0 and var_1 at mesh_0,1, the eqn can only be\n    at mesh 0.\n\n    May at mode is to handle those cannot be solved by can at mode. That is,\n    at one point, the intersection of all invars' meshes is empty. Then there\n    should have some redundant computation and memory consumptions.\n\n    TODO: Currently we only use the first element of all available candidates in\n    both mode, but for 'may at' mode, we need to pick the one with the least\n    redundancy using some estimation. 
For 'can at' mode, a round-robin is better\n    \"\"\"\n    var_infered_at = {}\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        if eqn_mesh[eqn_idx]:\n            continue\n        eqn_infered_at = None\n        # For invar_0 available at mesh_0, invar_1 available at mesh_0,1\n        # the outvar is better at mesh_0 instead of mesh_0,1\n        for var in eqn.invars:\n            if not isinstance(var, Var):\n                continue\n            if var_mesh[var]:\n                invar_infered_at = var_mesh[var]\n            elif var in var_infered_at and var_infered_at[var]:\n                invar_infered_at = var_infered_at[var]\n            else:\n                invar_infered_at = None\n            if invar_infered_at:\n                if eqn_infered_at is None:\n                    eqn_infered_at = OrderedSet(invar_infered_at)\n                else:\n                    if aggressive:\n                        eqn_infered_at.update(invar_infered_at)\n                    else:\n                        eqn_infered_at.intersection_update(invar_infered_at)\n        if eqn_infered_at:\n            for var in eqn.outvars:\n                if not isinstance(var, DropVar):\n                    var_infered_at[var] = OrderedSet(eqn_infered_at)\n    changed = False\n    for var in closed_jaxpr.jaxpr.outvars:\n        if (not isinstance(var, DropVar) and not var_mesh[var]):\n            if var in var_infered_at:\n                var_mesh[var] = OrderedSet([list(var_infered_at[var])[0]])\n            elif aggressive:\n                var_mesh[var] = OrderedSet([0])\n            else:\n                continue\n            changed = True\n    return changed\n\n\ndef _apply_grad_group_vars(closed_jaxpr: ClosedJaxpr, var_mesh, num_mesh):\n    \"\"\"Slice the input, output and consts of the jaxpr based on var_mesh.\"\"\"\n    global_invars = closed_jaxpr.jaxpr.invars\n    invars = [[] for _ in range(num_mesh)]\n    outvars = [[] for _ in range(num_mesh)]\n    
constvars = [[] for _ in range(num_mesh)]\n    consts = [[] for _ in range(num_mesh)]\n    # grouping invars and outvars\n    for invar in global_invars:\n        for mesh in var_mesh[invar]:\n            invars[mesh].append(invar)\n    for outvar in closed_jaxpr.jaxpr.outvars:\n        for mesh in var_mesh[outvar]:\n            outvars[mesh].append(outvar)\n    # grouping consts and constvars\n    for aval, var in zip(closed_jaxpr.consts, closed_jaxpr.jaxpr.constvars):\n        for mesh in var_mesh[var]:\n            consts[mesh].append(aval)\n            constvars[mesh].append(var)\n    return invars, outvars, consts, constvars\n\n\n# Binary operators that satisfies the associativity and commutativity\n_reducable_operators = set([add_p, and_p, or_p])\n\n\nclass ApplyGradRewriter:\n    \"\"\"\n    Rewrite apply grad jaxpr to avoid replicated computation by inserting\n    cross-mesh allreduce.\n    \"\"\"\n\n    def __init__(self, apply_grad_jaxpr: ClosedJaxpr, var_mesh):\n        self.jaxpr = apply_grad_jaxpr\n        self.eqns = apply_grad_jaxpr.jaxpr.eqns\n        self.outvars = apply_grad_jaxpr.jaxpr.outvars\n        self.var_mesh = dict(var_mesh)\n        self.eqn_mesh = {}\n        self.var_use: Dict[Var, OrderedSet] = {}\n        self.var_def: Dict[Var, int] = {}\n\n    def _reducable(self, eqn):\n        \"\"\"An eqn is reducable if it is a reducable and scalar operation\"\"\"\n        # the is_scalar is to avoid a large all-reduce for tied-embedding\n        # it can be improved by adding computation-communication tradeoff\n        return (eqn.primitive in _reducable_operators and\n                eqn.outvars[0].aval.shape == ())\n\n    def _forward_propagate(self):\n        \"\"\"\n        A conservative propagation that stops when the eqn's invars are from\n        multiple meshes.\n        \"\"\"\n        self.eqn_mesh = {}\n        self.var_use = {}\n        self.var_def = {}\n        for eqn_idx, eqn in enumerate(self.eqns):\n            for invar in 
_filter_literal(eqn.invars):\n                self.var_use.setdefault(invar, OrderedSet()).add(eqn_idx)\n            for outvar in _filter_droped(eqn.outvars):\n                self.var_def[outvar] = eqn_idx\n        has_color = OrderedSet([\n            self.var_def[k]\n            for k in self.var_mesh\n            if (len(self.var_mesh[k]) > 0 and k in self.var_def)\n        ])\n        q = list(has_color)\n        while len(q) > 0:\n            for outv in _filter_droped(self.eqns[q[0]].outvars):\n                if outv not in self.var_use:\n                    continue\n                used_eqns = self.var_use[outv]\n                has_color.update(used_eqns)\n                for e_id in used_eqns.difference(has_color):\n                    q.append(e_id)\n            q = q[1:]\n\n        # Propagate the first round\n        for eqn_idx, eqn in enumerate(self.eqns):\n            at_mesh = OrderedSet()\n            for invar in _filter_literal(eqn.invars):\n                at_mesh.update(self.var_mesh.setdefault(invar, OrderedSet()))\n            # TODO(yonghao): round robin this and use it in later positions\n            if len(at_mesh) == 0 and eqn_idx not in has_color:\n                at_mesh = OrderedSet([0])\n            if len(at_mesh) == 1:\n                for invar in _filter_literal(eqn.invars):\n                    self.var_mesh.setdefault(invar,\n                                             OrderedSet()).update(at_mesh)\n            self.eqn_mesh[eqn_idx] = list(at_mesh)\n            for outvar in _filter_droped(eqn.outvars):\n                self.var_mesh[outvar] = OrderedSet(at_mesh)\n\n    def _reducable_chain_lookup(self, eqn_idx, num_mesh):\n        \"\"\"\n        Pattern matching. For y = x_0 op x_1 op x_2 ... op x_n, it is as\n        y_0 = x_0 op x_1, y_1 = y_0 op x_2, ... in jaxpr. This function collects\n        all such x_0, x_1, ... 
x_n by making sure that intermediates like y_0 &\n        y_1 are not used elsewhere.\n\n        Returns:\n            mesh_vars: list of variables being reduced in a certain mesh.\n            final_var: The final outvar(the y above)\n            removed: Indices of eqns being removed. They compute intermediates.\n            literals: Literals along with the reduction\n        \"\"\"\n        # List[mesh_idx -> List[Vars]]\n        mesh_vars = [[] for _ in range(num_mesh)]\n        literals = []\n        eqn = self.eqns[eqn_idx]\n        nxt_idx, nxt_eqn = eqn_idx, eqn\n        reducable_chain = []\n        while self._reducable(nxt_eqn) and (nxt_eqn.primitive == eqn.primitive):\n            cur_idx, cur_eqn = nxt_idx, nxt_eqn\n            reducable_chain.append(cur_idx)\n            outv_use = self.var_use.setdefault(cur_eqn.outvars[0], OrderedSet())\n            # If the var is used in multiple places or global output, it is not\n            # a safe intermediate variable and the chain ends.\n            if len(outv_use) != 1 or cur_eqn.outvars[0] in self.outvars:\n                break\n            nxt_idx = list(outv_use)[0]\n            nxt_eqn = self.eqns[nxt_idx]\n        if cur_idx == eqn_idx:\n            return None, None, None, None\n        final_var = cur_eqn.outvars[0]\n        # split eqns on the reducable chain into meshes\n        reducable_set = set(reducable_chain)\n        for reduced_idx in reducable_chain:\n            reduced_eqn = self.eqns[reduced_idx]\n            for op in reduced_eqn.invars:\n                # We can assign all literals to mesh 0 cuz they'll be optimized\n                # by arithmetic simplification.\n                if isinstance(op, Literal):\n                    mesh_vars[0].append(op)\n                    continue\n                def_idx = self.var_def[op]\n                if def_idx not in reducable_set:\n                    def_meshes = self.eqn_mesh[def_idx]\n                    # TODO(yonghao): round-robin 
this\n                    mesh_vars[list(def_meshes)[0]].append(op)\n        return mesh_vars, final_var, reducable_chain[:-1], literals\n\n    def _rewrite_eqns(self, primitive, mesh_vars, gensym_fn, outvar, literals):\n        # rewrite according to splits\n        # TODO: in some cases the literal can lead to final result(True&or_p)\n        appended_eqns = []\n        allreduce_vars = []\n        mesh_ids = []\n        literal_handled = False\n        for mesh_id, per_mesh_vars in enumerate(mesh_vars):\n            cur_val = None\n            for v in per_mesh_vars:\n                if cur_val is None:\n                    # This is the first var in the mesh for the chain\n                    cur_val = v\n                    continue\n                new_var = gensym_fn(cur_val.aval)\n                # accumulate in-mesh result\n                appended_eqns.append(\n                    new_jaxpr_eqn([cur_val, v], [new_var], primitive, {}))\n                cur_val = new_var\n            if cur_val is not None:\n                if not literal_handled:\n                    for literal in literals:\n                        new_var = gensym_fn(cur_val.aval)\n                        appended_eqns.append(\n                            new_jaxpr_eqn([cur_val, literal], [new_var],\n                                          primitive, {}))\n                        cur_val = new_var\n                    literal_handled = True\n                allreduce_vars.append(cur_val)\n                mesh_ids.append(mesh_id)\n        # modify the end of reduce chain eqn into an all-reduce.\n        # The allreduce will be immediately replaced by pipeline markers\n        appended_eqns.append(\n            new_jaxpr_eqn(allreduce_vars, [outvar], cross_mesh_allreduce_p, {\n                'type': primitive,\n                'group_meshes': mesh_ids\n            }))\n        return appended_eqns, mesh_ids\n\n    def split_replicated_eqns(self, gensym_fn, num_mesh):\n        
\"\"\"Rewrite apply grad jaxpr to eqns so as to \"\"\"\n        self._forward_propagate()\n        new_eqns_before_var = {}\n        # Try to match the pattern\n        removed_eqns = set()\n        allreduce_groups = OrderedSet()\n        for eqn_idx, eqn in enumerate(self.eqns):\n            if eqn_idx in removed_eqns:\n                continue\n            if (eqn_idx in self.eqn_mesh and len(self.eqn_mesh[eqn_idx]) > 1 and\n                    self._reducable(eqn)):\n                (mesh_vars, final_var, removed,\n                 literals) = self._reducable_chain_lookup(eqn_idx, num_mesh)\n                if mesh_vars is None:\n                    # Only one eqn matches the pattern, skip it\n                    continue\n                removed_eqns.update(removed)\n                appended_eqns, allreduce_group = self._rewrite_eqns(\n                    eqn.primitive, mesh_vars, gensym_fn, final_var, literals)\n                new_eqns_before_var[final_var] = appended_eqns\n                allreduce_groups.add(tuple(allreduce_group))\n        if len(allreduce_groups) > 1:\n            raise NotImplementedError()\n        new_eqns = []\n        for eqn_idx, eqn in enumerate(self.eqns):\n            if eqn_idx in removed_eqns:\n                continue\n            outv = eqn.outvars[0] if len(eqn.outvars) > 0 else None\n            # insert new eqns before the previous last available eqn\n            if (not (outv is None or isinstance(outv, DropVar)) and\n                    outv in new_eqns_before_var):\n                new_eqns.extend(new_eqns_before_var[outv])\n            else:\n                new_eqns.append(eqn)\n        return clone_jaxpr(self.jaxpr, eqns=new_eqns), tuple(allreduce_groups)\n\n    @staticmethod\n    def rewrite_allreduce(closed_jaxpr: ClosedJaxpr, rewrite_to_dummy,\n                          num_devices, gensym_fn):\n        \"\"\"For cross-mesh allreduce, rewrite its invar to make it legal.\"\"\"\n        vars = set()\n        
new_eqns = []\n        vars.update([\n            inv for inv in closed_jaxpr.jaxpr.invars\n            if not isinstance(inv, Var)\n        ])\n        for eqn in closed_jaxpr.eqns:\n            if eqn.primitive == cross_mesh_allreduce_p:\n                new_invars = set(eqn.invars).intersection(vars)\n                assert len(new_invars) == 1\n                if rewrite_to_dummy:\n                    zero = _value_to_literal(0, eqn.outvars[0].aval.dtype)\n                    invs = list(new_invars) + [zero]\n                    new_eqn = new_jaxpr_eqn(invs, list(eqn.outvars), add_p, {})\n                else:\n                    if eqn.params['type'] == add_p:\n                        inv = list(new_invars)[0]\n                        outv = gensym_fn(inv.aval)\n                        div_eqn = new_jaxpr_eqn([\n                            inv,\n                            _value_to_literal(num_devices, inv.aval.dtype)\n                        ], [outv], div_p, {})\n                        new_eqns.append(div_eqn)\n                        new_invars = [outv]\n                    new_eqn = new_jaxpr_eqn(list(new_invars), list(eqn.outvars),\n                                            eqn.primitive, dict(eqn.params))\n                new_eqns.append(new_eqn)\n            else:\n                new_eqns.append(eqn)\n            for v in eqn.outvars:\n                if not isinstance(v, DropVar):\n                    vars.add(v)\n        return clone_jaxpr(closed_jaxpr, eqns=new_eqns)\n\n\ndef _no_allreduce(eqns):\n    for eqn in eqns:\n        if eqn.primitive == cross_mesh_allreduce_p:\n            return False\n    return True\n\n\ndef slice_apply_gradient(closed_jaxpr: ClosedJaxpr, grad_mesh: Dict[Var, int],\n                         outvar_mesh: Dict[Var, OrderedSet[int]], num_mesh,\n                         num_stage, donation_mapping: Dict[Var, Var], gensym_fn,\n                         skip_cross_mesh_allreduce, mesh_num_devices):\n    \"\"\"\n    Slice 
the apply gradient jaxpr based on mesh allocation information.\n\n    Args:\n        closed_jaxpr: closed jaxpr of apply_gradient function.\n        grad_mesh: some invars should be at certain mesh;\n            If not in the dict, the variable should be a global parameter.\n        outvar_mesh: some outvars should be at certain mesh.\n        num_mesh: number of meshes. If a mesh does not have apply gradient\n          computation, add an empty jaxpr\n        num_stage: number of stages in the apply gradient computation.\n        donation_mapping: donation mapping for global invars\n        skip_cross_mesh_allreduce: Skip cross mesh allreduce in profiling.\n\n    Returns:\n        jaxprs(List[ClosedJaxpr]): The i-th ClosedJaxpr runs at the i-th\n          cluster.\n        mesh_assignment(Dict[int, int]): The i-th ClosedJaxpr runs at the\n          mesh_assignment[i]-th cluster.\n        allreduce_groups(Tuple[Tuple[int]]): Groups of mesh ids that need to\n          be in the same allreduce group to perform cross-mesh allreduce.\n    \"\"\"\n    var_mesh = {var: OrderedSet([mesh]) for var, mesh in grad_mesh.items()}\n    for var in outvar_mesh:\n        var_mesh.setdefault(var, OrderedSet()).update(outvar_mesh[var])\n    # TODO(yonghao): running the split multiple times until no new splits\n    closed_jaxpr, allreduce_groups = ApplyGradRewriter(\n        closed_jaxpr, var_mesh).split_replicated_eqns(gensym_fn, num_mesh)\n    eqn_mesh, var_mesh = _init_eqn_var_mesh(closed_jaxpr, var_mesh)\n    changed = True\n    _propagate_with_donation(closed_jaxpr, donation_mapping, var_mesh)\n    while changed:\n        changed = _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping,\n                                                 eqn_mesh, var_mesh)\n    changed = _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh,\n                                         False)\n    while changed:\n        changed = _reverse_propagate_var_at_mesh(closed_jaxpr, 
donation_mapping,\n                                                 eqn_mesh, var_mesh)\n    changed = _forward_propagate_at_mesh(closed_jaxpr, eqn_mesh, var_mesh, True)\n    while changed:\n        changed = _reverse_propagate_var_at_mesh(closed_jaxpr, donation_mapping,\n                                                 eqn_mesh, var_mesh)\n\n    sliced_eqns = [[] for _ in range(num_mesh)]\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        if eqn_mesh[eqn_idx]:\n            for mesh in eqn_mesh[eqn_idx]:\n                sliced_eqns[mesh].append(eqn)\n\n    # grouping invars and outvars\n    invars, outvars, consts, constvars = _apply_grad_group_vars(\n        closed_jaxpr, var_mesh, num_mesh)\n\n    jaxprs = []\n    mesh_assignment = {}\n\n    for i in range(num_mesh):\n        if not outvars[i] and _no_allreduce(sliced_eqns[i]):\n            continue\n        computation_idx = num_stage + len(jaxprs)\n        # assign the current computation into mesh i\n        mesh_assignment[computation_idx] = i\n        sliced = Jaxpr(constvars[i], invars[i], outvars[i], sliced_eqns[i])\n        closed_jaxpr = ClosedJaxpr(sliced, consts[i])\n        num_devices = None if skip_cross_mesh_allreduce else mesh_num_devices[i]\n        closed_jaxpr = ApplyGradRewriter.rewrite_allreduce(\n            closed_jaxpr, skip_cross_mesh_allreduce, num_devices, gensym_fn)\n        jaxprs.append(closed_jaxpr)\n\n    return jaxprs, mesh_assignment, allreduce_groups\n\n\ndef apply_grad_add_marker(jaxprs: Sequence[ClosedJaxpr],\n                          apply_in_to_acc_out: Dict[Var, Var],\n                          gensym_fn,\n                          computation=False):\n    \"\"\"Add pipeline markers for sliced apply grads, keep invars and outvars\n    still unless.\n\n    The invar is in apply_in_to_acc_out or invar is outvar:\n    In the first case, the final invar follows the apply_in_to_acc_out;\n    In the second case, the final outvar is recorded in outvar_map.\n\n    
Args:\n        jaxprs: sliced apply grads.\n        apply_in_to_acc_out: which output of accumulate grad corresponds to the\n            invar of apply grad\n        gensym_fn: gensym function of the whole jaxpr.\n        computation: output JaxPipelineComputation or ClosedJaxpr.\n    \"\"\"\n    results = []\n    outvar_map = {}\n    for i, jaxpr in enumerate(jaxprs):\n        new_map = {}\n        for invar in jaxpr.jaxpr.invars:\n            if invar not in apply_in_to_acc_out:\n                new_map[invar] = gensym_fn(invar.aval)\n        for outvar in jaxpr.jaxpr.outvars:\n            if not isinstance(outvar, Var):\n                raise NotImplementedError(\n                    'outvar of apply grad cannot be literal')\n            if outvar in jaxpr.jaxpr.invars:\n                if outvar not in outvar_map:\n                    outvar_map[outvar] = gensym_fn(outvar.aval)\n                continue\n            new_map[outvar] = gensym_fn(outvar.aval)\n        replaced = replace_all_with(jaxpr, new_map).jaxpr\n        new_invars = [\n            get_var_mapping(apply_in_to_acc_out, var)\n            for var in jaxpr.jaxpr.invars\n        ]\n        new_outvars = [\n            get_var_mapping(outvar_map, var) for var in jaxpr.jaxpr.outvars\n        ]\n        name = f'{i}_{APPLY_GRAD_MARKER_SUFFIX}'\n        start_marker = mark_pipeline_jaxpreqn(new_invars,\n                                              replaced.invars,\n                                              name=name,\n                                              mark_type='start')\n        end_marker = mark_pipeline_jaxpreqn(replaced.outvars,\n                                            new_outvars,\n                                            name=name,\n                                            mark_type='end')\n        new_eqns = [start_marker] + replaced.eqns + [end_marker]\n        if computation:\n            results.append(\n                JaxPipelineComputation(\n                    
name, new_invars, new_outvars, new_eqns,\n                    dict(zip(jaxpr.jaxpr.constvars, jaxpr.consts))))\n        else:\n            new_jaxpr = clone_jaxpr(jaxpr, new_invars, new_outvars, new_eqns)\n            results.append(new_jaxpr)\n    outvar_map.update(apply_in_to_acc_out)\n    return results, outvar_map\n\n\ndef get_var_to_mesh(invars: Sequence[Var],\n                    computations: Sequence[JaxPipelineComputation],\n                    computation_to_mesh: Dict[int, int], apply_in_to_acc_out):\n    \"\"\"Get the mapping from variables to mesh.\"\"\"\n    # TODO(yonghao): now assume all gradients are variables(not literal)\n    outvar2mesh = {}\n    for i, computation in enumerate(computations):\n        for var in computation.outvars:\n            if isinstance(var, Var):\n                outvar2mesh[var] = computation_to_mesh[i]\n    return {\n        invar: outvar2mesh[apply_in_to_acc_out[invar]]\n        for invar in invars\n        if ((invar in apply_in_to_acc_out) and\n            (apply_in_to_acc_out[invar] in outvar2mesh))\n    }\n"
  },
  {
    "path": "alpa/pipeline_parallel/compile_executable.py",
    "content": "\"\"\"Compile executables for pipeshard parallelism.\"\"\"\nimport dataclasses\nimport logging\nimport time\nfrom typing import Callable, Sequence, Optional\n\nfrom jax import linear_util as lu\nfrom jax._src.lib import xla_client as xc\nfrom jax.core import gensym, AbstractValue, ClosedJaxpr\nfrom jax.interpreters import pxla\nfrom jax.tree_util import PyTreeDef\n\nfrom alpa.device_mesh import VirtualPhysicalMesh\nfrom alpa.global_env import global_config\nfrom alpa.pipeline_parallel.pipeshard_executable import PipeshardDriverExecutable\nfrom alpa.pipeline_parallel.runtime_emitter import (\n    OverlapFriendlyPipelineInstEmitter, PipelineInstEmitter)\nfrom alpa.pipeline_parallel.schedules import create_pipeline_schedule\nfrom alpa.pipeline_parallel.computation import (\n    create_donation_mapping, generate_computations_from_modules,\n    generate_sharded_xla_computations,\n    generate_sharded_xla_computations_arguments, get_donatable_intermediate,\n    mark_missing_vars_in_backward_computation_pipeline_marks, pipeline_dce,\n    slice_closed_jaxpr_by_full_pipeline_marks, split_donate_invars,\n    XlaShardedPipelineComputation)\nfrom alpa.pipeline_parallel.apply_grad import (\n    apply_grad_get_mean, compute_grad_to_accumulate_grad,\n    process_apply_gradient, split_compute_grad_and_apply_grad)\nfrom alpa.pipeline_parallel.layer_construction import LayerOption\nfrom alpa.pipeline_parallel.schedules import gen_dependency_with_stages\nfrom alpa.pipeline_parallel.stage_construction import (\n    cluster_layers_and_slice_mesh, StageOption)\nfrom alpa.pipeline_parallel.stage_profiling import CompileWorkerPool\nfrom alpa.shard_parallel.auto_sharding import (AutoShardingOption,\n                                               hlo_sharding_to_sharding_spec)\nfrom alpa.shard_parallel.manual_sharding import (ManualShardingOption,\n                                                 ParsedManualShardingOption,\n                                                 
get_flatten_axis_resources,\n                                                 get_intermediate_parsed_spec,\n                                                 parsed_spec_to_opsharding)\nfrom alpa.util import (get_var_mapping, trace_jaxpr_with_micro_batch,\n                       OrderedSet, GradFuncTransformContext)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef compile_pipeshard_executable(\n        fun: lu.WrappedFun, in_tree: PyTreeDef,\n        out_tree_thunk: Callable[[], PyTreeDef], static_argnums: Sequence[int],\n        donated_invars: Sequence[bool], batch_invars: Sequence[bool],\n        virtual_mesh: VirtualPhysicalMesh, num_microbatch: int,\n        pipeline_schedule: str, default_as_option: AutoShardingOption,\n        layer_option: LayerOption, stage_option: StageOption,\n        global_input_shardings: Optional[Sequence[pxla.ShardingSpec]],\n        stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]],\n        manual_shard_options: Optional[ManualShardingOption],\n        *avals: Sequence[AbstractValue]):\n    \"\"\"\n    Compile a callable for pipeshard parallel which combines\n    pipeline parallelism and 2d shard parallelsim.\n\n    Args:\n        fun: The function to be parallelized.\n        global_input_shardings: Forcibly set sharding specs of global\n          input vars.\n        stage_input_shardings: Forcibly set sharding specs of input vars of\n          each stage.\n        manual_sharding_options: pjit style sharding constraints of global input\n          vars.\n    \"\"\"\n    if global_config.backend == \"tpu\":\n        raise NotImplementedError(\"Pipeshard Parallel for tpu is not supported\")\n    debug_compilation_time(None)\n    name_base = f\"{fun.__name__}_pipeshard_parallel\"\n\n    # Apply layer construction to add pipeline markers.\n    with GradFuncTransformContext(layer_option.transform):\n        if pipeline_schedule == \"inference\":\n            f_backup = fun.f\n       
     fun.f = layer_option.transform(fun.f)\n\n        # Trace the function with a micro batch to get the jaxpr.\n        closed_jaxpr, micro_batch_size = trace_jaxpr_with_micro_batch(\n            fun, batch_invars, num_microbatch, avals)\n\n        # Trace again with a full batch.\n        # The full batch is used to derive the reduction operator across\n        # micro batches (e.g., addition, concatenation).\n        if num_microbatch > 1:\n            for store in fun.stores:\n                if store:\n                    store.reset()\n            full_batch_closed_jaxpr, _ = trace_jaxpr_with_micro_batch(\n                fun, batch_invars, 1, avals)\n        else:\n            full_batch_closed_jaxpr = None\n\n        if pipeline_schedule == \"inference\":\n            fun.f = f_backup\n    debug_compilation_time(\"trace\")\n\n    # flatten manual sharding axis resources\n    out_tree = out_tree_thunk()\n    if manual_shard_options is not None:\n        assert global_input_shardings is None\n        parsed_ms_option = get_flatten_axis_resources(manual_shard_options,\n                                                      in_tree, out_tree)\n    else:\n        parsed_ms_option = None\n    pipeshard_config = compile_pipeshard_executable_internal(\n        closed_jaxpr, full_batch_closed_jaxpr, micro_batch_size, donated_invars,\n        batch_invars, virtual_mesh, num_microbatch, pipeline_schedule,\n        default_as_option, stage_option, name_base, global_input_shardings,\n        None, stage_input_shardings, parsed_ms_option)\n\n    executable = PipeshardDriverExecutable(\n        mesh_group=virtual_mesh.launched_physical_mesh_group,\n        pipeshard_config=pipeshard_config,\n        num_batch=num_microbatch,\n        layer_option=layer_option,\n        in_tree=in_tree,\n        out_tree=out_tree,\n        static_argnums=static_argnums)\n    debug_compilation_time(\"driver executable\")\n    return executable\n\n\ndef 
compile_pipeshard_executable_internal(\n        closed_jaxpr: ClosedJaxpr,\n        full_batch_closed_jaxpr: Optional[ClosedJaxpr], micro_batch_size: int,\n        donated_invars: Sequence[bool], batch_invars: Sequence[bool],\n        virtual_mesh: VirtualPhysicalMesh, num_microbatch: int,\n        pipeline_schedule: str, default_as_option: AutoShardingOption,\n        stage_option: StageOption, name_base: str,\n        global_input_shardings: Optional[Sequence[pxla.ShardingSpec]],\n        global_output_shardings: Optional[Sequence[pxla.ShardingSpec]],\n        stage_input_shardings: Optional[Sequence[Sequence[pxla.ShardingSpec]]],\n        parsed_manual_sharding_option: Optional[ParsedManualShardingOption]):\n    \"\"\"\n    Args:\n        fun: The function to be parallelized.\n        global_input_shardings: Forcibly set sharding specs of global\n          input vars.\n        global_output_shardings: Forcibly set sharding specs of global\n          output vars.\n        stage_input_shardings: Forcibly set sharding specs of input vars of\n          each stage.\n    \"\"\"\n    global_invars = closed_jaxpr.jaxpr.invars\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    inference_mode = (pipeline_schedule == \"inference\")\n\n    (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,\n     microbatch_bound, reduction_vector, post_microbatch_bound,\n     accumulator_mapping, acc_grad_invars,\n     acc_grad_outvars) = (split_and_process_layers(closed_jaxpr,\n                                                   full_batch_closed_jaxpr,\n                                                   num_microbatch,\n                                                   inference_mode, gensym_func))\n\n    debug_compilation_time(\"jaxpr operations\")\n\n    (jax_apply_layers,\n     apply_grad_global_info) = slice_apply_grad_for_stage_construction(\n         jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound, global_invars,\n         global_outvars, donated_invars, 
accumulator_mapping, gensym_func,\n         inference_mode)\n\n    # Construct pipeline stages by merging layers\n    (jax_pipeline_stages, stage_to_mesh, sliced_virtual_meshes,\n     manual_stage_option) = cluster_layers_and_slice_mesh(\n         jax_pipeline_layers, virtual_mesh, accumulator_mapping,\n         acc_grad_invars, acc_grad_outvars, num_microbatch, micro_batch_size,\n         jax_apply_layers, apply_grad_global_info, pipeline_schedule,\n         default_as_option, stage_option)\n    num_meshes = len(sliced_virtual_meshes)\n    debug_compilation_time(\"stage construction\")\n\n    # Process apply_gradient and donation\n    num_devices = [vmesh.num_devices for vmesh in sliced_virtual_meshes]\n    (sliced_apply_grad_stages, apply_grad_placement,\n     global_outvars, allreduce_groups) = process_apply_gradient(\n         apply_grad_jaxpr, microbatch_bound, jax_pipeline_stages, stage_to_mesh,\n         gensym_func, num_meshes, global_invars, global_outvars, donated_invars,\n         False, num_devices)\n    jax_all_stages = jax_pipeline_stages + sliced_apply_grad_stages\n\n    donation_mapping = create_donation_mapping(accumulator_mapping,\n                                               donated_invars, global_invars,\n                                               global_outvars)\n    donate_invars_dict, jax_all_stages = split_donate_invars(\n        donation_mapping, jax_all_stages, gensym_func)\n    global_outvars, concat_vars_mapping = _rewrite_global_outvars_post_concate(\n        global_outvars, reduction_vector, microbatch_bound,\n        post_microbatch_bound, gensym_func)\n    debug_compilation_time(\"apply grad\")\n\n    # Generate pipeline schedule and placement\n    dependency, fwd_intermediates = gen_dependency_with_stages(\n        jax_pipeline_stages, num_meshes, sliced_apply_grad_stages)\n    schedule = create_pipeline_schedule(\n        pipeline_schedule,\n        dependency=dependency,\n        meshes=sliced_virtual_meshes,\n        
apply_grad_placement=apply_grad_placement,\n        num_batch=num_microbatch)\n\n    # Forcibly set the sharding specs of global invars and outvars.\n    # FIXME(yonghao): the invar can appear on multiple meshes and thus different\n    # sharding specs\n    if global_input_shardings:\n        assert len(global_input_shardings) == len(global_invars)\n        input_sharding_dict = dict(zip(global_invars, global_input_shardings))\n    else:\n        input_sharding_dict = {}\n    if global_output_shardings:\n        assert len(global_output_shardings) == len(global_outvars)\n        output_sharding_dict = dict(zip(global_outvars,\n                                        global_output_shardings))\n    else:\n        output_sharding_dict = {}\n    if parsed_manual_sharding_option is not None:\n        assert (global_input_shardings is None and\n                global_output_shardings is None)\n        (input_sharding_dicts,\n         output_sharding_dicts) = get_manual_input_output_sharding_specs(\n             jax_all_stages, manual_stage_option.submesh_logical_shapes,\n             parsed_manual_sharding_option, global_invars, global_outvars,\n             schedule.stage_mesh_mapping, fwd_intermediates)\n    else:\n        input_sharding_dicts = [input_sharding_dict] * num_meshes\n        output_sharding_dicts = [output_sharding_dict] * num_meshes\n\n    # Call auto-sharding pass to shard each stage\n    xla_stages, total_flops = shard_each_stage(\n        jax_all_stages, sliced_virtual_meshes, schedule, num_meshes,\n        accumulator_mapping, global_invars, acc_grad_outvars,\n        donate_invars_dict, num_microbatch,\n        manual_stage_option.submesh_logical_shapes,\n        manual_stage_option.submesh_autosharding_option_dicts,\n        default_as_option, input_sharding_dicts, output_sharding_dicts,\n        stage_input_shardings, name_base, gensym_func)\n    total_flops *= num_microbatch\n    debug_compilation_time(\"shard stages\")\n\n    # Launch the 
physical mesh group\n    if virtual_mesh.launched_physical_mesh_group is None:\n        virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)\n    debug_compilation_time(\"launch meshes\")\n\n    # Wrap all things into a distributed runtime\n    # TODO(yonghao): use virtual mesh instead of launched physical group\n    emitter_kwargs = dict(stages=xla_stages,\n                          global_invars=global_invars,\n                          grad_dummy_invars=accumulator_mapping,\n                          global_outvars=global_outvars,\n                          concat_vars_mapping=concat_vars_mapping,\n                          mesh_group=virtual_mesh.launched_physical_mesh_group,\n                          schedule=schedule,\n                          is_batch=batch_invars,\n                          num_batch=num_microbatch,\n                          default_auto_sharding_option=default_as_option,\n                          manual_stage_option=manual_stage_option,\n                          flop_count=total_flops,\n                          allreduce_groups=allreduce_groups)\n    if pipeline_schedule == \"1f1b_overlap_friendly\":\n        emitter_cls = OverlapFriendlyPipelineInstEmitter\n        emitter_kwargs[\"outvar_def_order\"] = [\n            stage.outvars_def_order() for stage in jax_all_stages\n        ]\n    else:\n        emitter_cls = PipelineInstEmitter\n    pipeshard_config = emitter_cls(**emitter_kwargs).compile()\n\n    debug_compilation_time(\"runtime emitter\")\n    return pipeshard_config\n\n\ndef split_and_process_layers(closed_jaxpr, full_batch_closed_jaxpr,\n                             num_microbatch, inference_mode, gensym_func):\n    \"\"\"Split and process the input jaxpr with the following steps:\n\n    1. Split the jaxpr into the compute grad part and the apply grad part.\n    2. Transform the compute grad jaxpr to a accumulate grad jaxpr.\n    3. 
Split the accumulate grad jaxpr into forward and backward pipeline\n       layers.\n    4. Divide the accumulated gradient by the number of microbatches at the\n       start of accumulate gradient.\n\n    \"\"\"\n\n    # Split the jaxpr into compute_grad and apply_grad\n    (closed_jaxpr, compute_grad_jaxpr, apply_grad_jaxpr,\n     microbatch_bound) = split_compute_grad_and_apply_grad(\n         closed_jaxpr, gensym_func, num_microbatch, inference_mode)\n    global_outvars = closed_jaxpr.jaxpr.outvars\n\n    # Transform compute_grad to accumulate_grad\n    # FIXME(yonghao): use apply grad jaxpr returned by this function\n    (reduction_vector, post_microbatch_bound,\n     _) = _get_full_batch_apply_grad(full_batch_closed_jaxpr, microbatch_bound,\n                                     num_microbatch, inference_mode)\n    (acc_grad_jaxpr, microbatch_bound,\n     accumulator_mapping) = compute_grad_to_accumulate_grad(\n         compute_grad_jaxpr, microbatch_bound, reduction_vector, gensym_func,\n         num_microbatch)\n\n    # Slice the jaxpr into layers\n    acc_grad_invars = acc_grad_jaxpr.jaxpr.invars\n    acc_grad_outvars = acc_grad_jaxpr.jaxpr.outvars\n\n    jax_pipeline_layers = slice_closed_jaxpr_by_full_pipeline_marks(\n        acc_grad_jaxpr)\n    if not inference_mode:\n        jax_pipeline_layers = (\n            mark_missing_vars_in_backward_computation_pipeline_marks(\n                jax_pipeline_layers, acc_grad_invars, acc_grad_outvars,\n                gensym_func))\n    # TODO(yonghao): remove this pass. we can clear these vars when rewriting\n    #   compute grad to accumulate grad\n    jax_pipeline_layers = pipeline_dce(jax_pipeline_layers, acc_grad_outvars)\n\n    # Add compute mean and slice apply-grad stages\n    # FIXME (zhuohan): get_mean only works when we use jax.mean to\n    #                  calculate loss. 
It will fail if we use sum.\n    apply_grad_jaxpr, global_outvars = apply_grad_get_mean(\n        apply_grad_jaxpr, global_outvars, microbatch_bound.outvars, gensym_func,\n        num_microbatch, reduction_vector)\n\n    return (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,\n            microbatch_bound, reduction_vector, post_microbatch_bound,\n            accumulator_mapping, acc_grad_invars, acc_grad_outvars)\n\n\ndef get_manual_input_output_sharding_specs(stages, mesh_shapes, ms_option,\n                                           global_invars, global_outvars,\n                                           stage_to_mesh, fwd_intermediates):\n    \"\"\"\n    Split user assigned input and output PartitionSpec into sharding specs for\n    each pipeline stage.\n    \"\"\"\n    invar_set = set(global_invars)\n    outvar_set = set(global_outvars)\n    var_to_pspec = {}\n    handle_invar = False\n    handle_outvar = False\n    # Add global input and output's parsed partition spec.\n    if ms_option.in_parsed_pspec is not None:\n        var_to_pspec.update(dict(zip(global_invars, ms_option.in_parsed_pspec)))\n        handle_invar = True\n    if ms_option.out_parsed_pspec is not None:\n        var_to_pspec.update(\n            dict(zip(global_outvars, ms_option.out_parsed_pspec)))\n        handle_outvar = True\n    # Add pipeline intermediate's parsed partition spec.\n    intermediate_to_pspec = {}\n    if ms_option.pipeline_intermediate_axes is not None:\n        for v in fwd_intermediates:\n            # TODO: This is a simple heuristic: we simply replicate 1d tensors.\n            if len(v.aval.shape) <= 1:\n                continue\n            intermediate_to_pspec[v] = get_intermediate_parsed_spec(\n                ms_option.pipeline_intermediate_axes, len(v.aval.shape))\n\n    submesh_axis_names = ms_option.submesh_axis_names\n    if submesh_axis_names is None:\n        submesh_axis_names = [ms_option.mesh_axis_names] * len(mesh_shapes)\n\n    
def get_vars_to_sharding_specs(variables, mesh_shape, mesh_axis_names):\n        parsed_specs = [\n            (var_to_pspec[v] if v in var_to_pspec else intermediate_to_pspec[v])\n            for v in variables\n        ]\n        avals = [v.aval for v in variables]\n        var_op_shardings = parsed_spec_to_opsharding(parsed_specs, avals,\n                                                     mesh_shape,\n                                                     mesh_axis_names)\n        var_sharding_specs = [\n            hlo_sharding_to_sharding_spec(xc.HloSharding.from_proto(ops), aval,\n                                          mesh_shape)\n            for ops, aval in zip(var_op_shardings, avals)\n        ]\n        return dict(zip(variables, var_sharding_specs))\n\n    invar_shardings = [{}] * len(mesh_shapes)\n    outvar_shardings = [{}] * len(mesh_shapes)\n    for stage_idx, stage in enumerate(stages):\n        mesh_idx = stage_to_mesh[stage_idx]\n        assert len(mesh_idx) == 1\n        mesh_idx = list(mesh_idx)[0]\n        mesh_shape = mesh_shapes[mesh_idx]\n        mesh_axis_names = submesh_axis_names[mesh_idx]\n        # invars\n        if handle_invar:\n            invar_in_global = [var for var in stage.invars if var in invar_set]\n            # add intermediate vars\n            intermediate_var = [\n                var for var in stage.invars if var in intermediate_to_pspec\n            ]\n            invars = invar_in_global + intermediate_var\n            stage_invar_shardings = get_vars_to_sharding_specs(\n                invars, mesh_shape, mesh_axis_names)\n        else:\n            stage_invar_shardings = {}\n        # outvars\n        if handle_outvar:\n            outvar_in_global = [\n                var for var in stage.outvars if var in outvar_set\n            ]\n            stage_outvar_shardings = get_vars_to_sharding_specs(\n                outvar_in_global, mesh_shape, mesh_axis_names)\n        else:\n            stage_outvar_shardings 
= {}\n        invar_shardings[mesh_idx].update(stage_invar_shardings)\n        outvar_shardings[mesh_idx].update(stage_outvar_shardings)\n    return invar_shardings, outvar_shardings\n\n\ndef shard_each_stage(jax_all_stages, virtual_meshes, schedule, num_meshes,\n                     accumulator_mapping, global_invars, acc_grad_outvars,\n                     donate_invars_dict, num_microbatch, logical_mesh_shapes,\n                     autosharding_option_dicts, default_as_option,\n                     input_sharding_dicts, output_sharding_dicts,\n                     stage_input_shardings, name_base, gensym_func):\n    \"\"\"Run intra-op parallelism compilation for a stage.\"\"\"\n    # Initialize donation mapping\n    stage_dict = [[] for _ in range(num_meshes)]\n    stage_id_dict = [[] for _ in range(num_meshes)]\n    dummy_stage_id_dict = [[] for _ in range(num_meshes)]\n    donatable_dict = [[] for _ in range(num_meshes)]\n    mesh_stage_mapping = schedule.mesh_stage_mapping\n    donatable_list = get_donatable_intermediate(\n        jax_all_stages, mesh_stage_mapping,\n        OrderedSet(global_invars).union(accumulator_mapping.keys()))\n\n    if stage_input_shardings is None:\n        stage_input_shardings = [None for _ in range(num_meshes)]\n    assert len(stage_input_shardings) == num_meshes\n\n    for i, stage in enumerate(jax_all_stages):\n        mesh_indices = list(schedule.stage_placement(i))\n        assert len(mesh_indices) == 1\n        mesh_idx = mesh_indices[0]\n        if len(stage.outvars) == 0:\n            # This is a dummy stage, we don't need to shard it\n            dummy_stage_id_dict[mesh_idx].append(i)\n            continue\n        stage_id_dict[mesh_idx].append(i)\n        stage_dict[mesh_idx].append(stage)\n        donatable_dict[mesh_idx].append(donatable_list[i])\n\n    # Call auto-sharding pass on each stage\n    distributed_compile = global_config.pipeline_distributed_compile\n    xla_stages = [None] * len(jax_all_stages)\n    if 
distributed_compile:\n        compile_workers = CompileWorkerPool(num_meshes)\n        compile_fn = lambda w, v: w.run_auto_sharding_pass.remote(*v)  # pylint: disable=unnecessary-lambda-assignment\n        compile_intermediate = [None] * num_meshes\n    total_flops = 0\n    for mesh_idx in range(num_meshes):\n        virtual_mesh = virtual_meshes[mesh_idx]\n        logical_mesh = virtual_mesh.get_logical_mesh(\n            logical_mesh_shapes[mesh_idx])\n        autosharding_option = dataclasses.replace(\n            default_as_option, **autosharding_option_dicts[mesh_idx])\n\n        # Predefined shardings. stage_input_sharding should have shardings for\n        # all parameters, while the sharding dict can have only a portion of\n        # all parameters.\n        input_sharding_dict = input_sharding_dicts[mesh_idx]\n        output_sharding_dict = output_sharding_dicts[mesh_idx]\n        stage_input_sharding = stage_input_shardings[mesh_idx]\n\n        # Setup dummy stages\n        for i in dummy_stage_id_dict[mesh_idx]:\n            xla_stages[i] = XlaShardedPipelineComputation.dummy_computation(\n                jax_all_stages[i].name, logical_mesh.shape, gensym_func)\n\n        stage_donate_invars = [\n            donate_invars_dict[stage_idx]\n            for stage_idx in stage_id_dict[mesh_idx]\n        ]\n        if distributed_compile:\n            hlo, flops = (generate_sharded_xla_computations_arguments(\n                f\"{name_base}_mesh_{mesh_idx}\", stage_dict[mesh_idx],\n                stage_donate_invars, input_sharding_dict, output_sharding_dict,\n                stage_input_sharding))\n            other_kwargs = {\n                \"logical_mesh\": logical_mesh,\n                \"return_mode\": \"stages\",\n                \"as_option\": autosharding_option,\n                \"num_micro_batches\": num_microbatch,\n            }\n            compile_workers.submit(compile_fn, (mesh_idx, hlo, other_kwargs))\n            
compile_intermediate[mesh_idx] = (stage_dict[mesh_idx],\n                                              stage_donate_invars)\n            total_flops += flops\n        else:\n            sharded_xla_stages, flops = generate_sharded_xla_computations(\n                f\"{name_base}_mesh_{mesh_idx}\", stage_dict[mesh_idx],\n                stage_donate_invars, donatable_dict[mesh_idx], acc_grad_outvars,\n                num_microbatch, logical_mesh, autosharding_option,\n                input_sharding_dict, output_sharding_dict, stage_input_sharding)\n            total_flops += flops\n            for i, xla_stage in zip(stage_id_dict[mesh_idx],\n                                    sharded_xla_stages):\n                xla_stages[i] = xla_stage\n\n    if distributed_compile:\n        for _ in range(num_meshes):\n            mesh_idx, (computation_names, computation_hlos,\n                       stage_plan) = compile_workers.get_next_unordered()\n            jax_computations, computation_donate_invars = compile_intermediate[\n                mesh_idx]\n            sharded_xla_stages = generate_computations_from_modules(\n                jax_computations, computation_names, computation_hlos,\n                computation_donate_invars, donatable_dict[mesh_idx],\n                acc_grad_outvars, stage_plan)\n            for i, xla_stage in zip(stage_id_dict[mesh_idx],\n                                    sharded_xla_stages):\n                xla_stages[i] = xla_stage\n        compile_workers.shutdown()\n\n    return xla_stages, total_flops\n\n\ndef slice_apply_grad_for_stage_construction(pipeline_layers, apply_grad_jaxpr,\n                                            microbatch_bound, global_invars,\n                                            global_outvars, donated_invars,\n                                            accumulator_mapping, gensym_func,\n                                            inference_mode):\n    if inference_mode:\n        num_layers = 
len(pipeline_layers)\n        num_mesh = num_layers\n        layer_to_mesh = list(range(num_mesh))\n    else:\n        num_layers = len(pipeline_layers)\n        assert len(pipeline_layers) % 2 == 0\n        num_mesh = num_layers // 2\n        layer_to_mesh = (list(range(num_mesh)) +\n                         list(reversed(range(num_mesh))))\n    (layers, apply_grad_placement, global_outvars,\n     _) = process_apply_gradient(apply_grad_jaxpr, microbatch_bound,\n                                 pipeline_layers, layer_to_mesh, gensym_func,\n                                 num_mesh, global_invars, global_outvars,\n                                 donated_invars, True, None)\n    apply_grad_donation = create_donation_mapping(accumulator_mapping,\n                                                  donated_invars, global_invars,\n                                                  global_outvars)\n    wrap_layers = [None] * num_mesh\n    for layer_idx, mesh_idx in apply_grad_placement.items():\n        wrap_layers[mesh_idx] = layers[layer_idx - num_layers]\n    apply_grad_global_info = apply_grad_donation, global_outvars\n    return wrap_layers, apply_grad_global_info\n\n\ndef _get_full_batch_apply_grad(closed_jaxpr,\n                               microbatch_bound,\n                               num_microbatch,\n                               inference_mode,\n                               batch_dim=0):\n    \"\"\"\n    Compare the micro-batch jaxpr and full-batch jaxpr. 
Return whether\n    the out var's is reduced across micro-batches.\n\n    TODO(yonghao): the reduction vector should be created by a\n    more careful analysis.\n    \"\"\"\n    if num_microbatch == 1:\n        reduced_vector = [True] * len(microbatch_bound.outvars)\n        post_microbatch_bound = microbatch_bound\n        apply_grad_jaxpr = None\n        return reduced_vector, post_microbatch_bound, apply_grad_jaxpr\n\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    (_, _, apply_grad_jaxpr,\n     post_microbatch_bound) = (split_compute_grad_and_apply_grad(\n         closed_jaxpr, gensym_func, num_microbatch, inference_mode))\n    reduced_vector = []\n    for mb_var, var in zip(microbatch_bound.outvars,\n                           post_microbatch_bound.outvars):\n        microbatch_shape = mb_var.aval.shape\n        batch_shape = var.aval.shape\n        if microbatch_shape != batch_shape:\n            expected_microbatched_shape = list(batch_shape)\n            assert expected_microbatched_shape[batch_dim] % num_microbatch == 0\n            expected_microbatched_shape[batch_dim] //= num_microbatch\n            assert tuple(expected_microbatched_shape) == microbatch_shape\n            if len(apply_grad_jaxpr.eqns) > 0:\n                raise NotImplementedError(\n                    \"Some vars marked by gradient markers are not reduced \"\n                    \"but concatenated. 
This case in the training mode \"\n                    \"is not supported yet.\")\n        reduced_vector.append(microbatch_shape == batch_shape)\n\n    return reduced_vector, post_microbatch_bound, apply_grad_jaxpr\n\n\ndef _rewrite_global_outvars_post_concate(global_outvars, reduction_vector,\n                                         microbatch_bound,\n                                         post_microbatch_bound, gensym_func):\n    concat_vars_mapping = {}\n    for idx, reduce in enumerate(reduction_vector):\n        if not reduce:\n            var = microbatch_bound.outvars[idx]\n            actual_aval = post_microbatch_bound.outvars[idx].aval\n            concat_vars_mapping[gensym_func(actual_aval)] = var\n    reversed_mapping = {v: k for k, v in concat_vars_mapping.items()}\n    global_outvars = [\n        get_var_mapping(reversed_mapping, v) for v in global_outvars\n    ]\n    return global_outvars, concat_vars_mapping\n\n\n_tic = None\n\n\ndef debug_compilation_time(message):\n    \"\"\"Print compilation time for debugging.\"\"\"\n    global _tic\n    if message and global_config.print_compilation_time:\n        print(f\"compile_pipeshard_executable::{message}: \"\n              f\"{time.time() - _tic:.2f} s\")\n    _tic = time.time()\n"
  },
  {
    "path": "alpa/pipeline_parallel/computation.py",
    "content": "\"\"\"Pipeline computation definitions.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass, field\nimport logging\nfrom typing import Sequence, Any, Dict, Optional\n\nfrom jax import jit\nfrom jax._src.lib import xla_bridge as xb, xla_extension as xe\nfrom jax._src.util import partial, safe_map\nfrom jax._src import dispatch\nfrom jax.core import (Atom, Var, JaxprEqn, Jaxpr, ClosedJaxpr, DropVar, Literal,\n                      jaxpr_as_fun, gensym, named_call_p, ShapedArray)\nfrom jax.interpreters import pxla\nimport numpy as np\n\nfrom alpa.mesh_executable import PartialGradAccMeshDriverExecutable\nfrom alpa.parallel_plan import StagePlan\nfrom alpa.pipeline_parallel.primitive_def import (mark_hook_jaxpreqn,\n                                                  pipeline_p,\n                                                  mark_pipeline_jaxpreqn)\nfrom alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,\n                                               run_spmd_partitioner_pass,\n                                               get_input_output_sharding_specs,\n                                               hlo_sharding_to_sharding_spec,\n                                               AutoShardingOption)\nfrom alpa.global_env import global_config\nfrom alpa.util import (OrderedSet, clone_jaxpr, clone_jaxpr_eqn,\n                       get_compile_options, jaxpr_to_hlo,\n                       setup_computation_alias, compile_dummy_zero_constant,\n                       get_var_mapping, undefined_sharding_spec_proto,\n                       new_jaxpr_eqn, replicated_sharding_spec_proto)\nfrom alpa.wrapped_hlo import HloStatus, WrappedHlo\n\n# pylint: disable=redefined-builtin\nunsafe_map, map = map, safe_map  # type: ignore\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\n@dataclass\nclass PipelineComputation(ABC):\n    \"\"\"\n    Base class of pipeline computations.\n\n    Attributes:\n 
       name (str): The name of the pipeline computation.\n        invars (Sequence[Var]): The list of input variables, corresponding to\n            the order of the runnable inputs.\n        outvars (Sequence[Var]): The list of output variables, corresponding to\n            the order of the runnable outputs.\n    \"\"\"\n\n    name: str\n    invars: Sequence[Var] = field(default_factory=list)\n    outvars: Sequence[Var] = field(default_factory=list)\n\n    @abstractmethod\n    def get_runnable(self, mesh=None):\n        \"\"\"Compile the computation and get the runnable.\"\"\"\n        raise NotImplementedError()\n\n\n@dataclass\nclass StrVarPipelineComputation:\n    \"\"\"Stringified computation with all Set/Dict have string keys.\"\"\"\n\n    name: str\n    invars: Sequence[str]\n    outvars: Sequence[str]\n\n    @classmethod\n    def from_pipeline_computation(cls,\n                                  pipeline_computation: PipelineComputation):\n        \"\"\"Construct a StrVarPipelineComputation from a PipelineComputation.\"\"\"\n        return cls(\n            name=pipeline_computation.name,\n            invars=[repr(var) for var in pipeline_computation.invars],\n            outvars=[repr(var) for var in pipeline_computation.outvars],\n        )\n\n\n@dataclass\nclass JaxPipelineComputation(PipelineComputation):\n    \"\"\"\n    A pipeline computation defined by Jaxpr.\n\n    Attributes:\n        eqns (Sequence[JaxprEqn]): Jaxpr equations of the pipeline computation.\n        consts_dir: Dict[Atom, Any]: All the constants used in the pipeline\n            computation.\n    \"\"\"\n\n    eqns: Sequence[JaxprEqn] = field(default_factory=list)\n    consts_dir: Dict[Atom, Any] = field(default_factory=dict)\n\n    def closed_jaxpr(self) -> ClosedJaxpr:\n        \"\"\"\n        Get the closed Jaxpr of the pipeline computation.\n\n        Returns:\n            ClosedJaxpr: The result ClosedJaxpr.\n        \"\"\"\n        jaxpr = Jaxpr(\n            
constvars=list(self.consts_dir.keys()),\n            invars=self.invars,\n            outvars=self.outvars,\n            eqns=self.eqns,\n        )\n        closed_jaxpr = ClosedJaxpr(jaxpr, list(self.consts_dir.values()))\n        return closed_jaxpr\n\n    def get_runnable(self, mesh=None):\n        \"\"\"Return a JIT callable of the pipeline computation.\"\"\"\n        closed_jaxpr = self.closed_jaxpr()\n        return jit(jaxpr_as_fun(closed_jaxpr))\n\n    @classmethod\n    def from_closed_jaxpr(cls, name, closed_jaxpr: ClosedJaxpr):\n        \"\"\"Construct a JaxPipelineComputation from a Jaxpr.\"\"\"\n        return cls(name=name,\n                   invars=closed_jaxpr.jaxpr.invars,\n                   outvars=closed_jaxpr.jaxpr.outvars,\n                   eqns=closed_jaxpr.eqns,\n                   consts_dir=dict(\n                       zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts)))\n\n    def outvars_def_order(self):\n        \"\"\"\n        Get the order of outvars by when they are defined in the jaxpr. 
This may\n        be not accurate because XLA optimizations may reorder it, but we only\n        focus on the order of activations which have data dependency so it's ok.\n        \"\"\"\n        outvars = self.outvars\n        assert self.eqns[-1].primitive is pipeline_p\n        assert tuple(self.eqns[-1].outvars) == tuple(outvars)\n        pre_marker_vars = self.eqns[-1].invars\n        pre_marker_vars = {v: idx for idx, v in enumerate(pre_marker_vars)}\n        final_order = []\n        for inv in self.invars:\n            if inv in pre_marker_vars:\n                final_order.append(pre_marker_vars[inv])\n        for eqn in self.eqns:\n            for var in eqn.outvars:\n                if not isinstance(var, DropVar) and var in pre_marker_vars:\n                    final_order.append(pre_marker_vars[var])\n        assert len(final_order) == len(outvars)\n        return [outvars[idx] for idx in final_order]\n\n\n@dataclass\nclass XlaPipelineComputation(PipelineComputation):\n    \"\"\"A pipeline computation defined by XLA HLO Module.\"\"\"\n\n    hlo: WrappedHlo = None\n\n    @classmethod\n    def from_jax_pipeline_computation(\n            cls, jax_pipeline_computation: JaxPipelineComputation):\n        \"\"\"\n        Construct a XlaPipelineComputation from a JaxPipelineComputation.\n\n        Args:\n            jax_pipeline_computation (JaxPipelineComputation): the source\n              JaxPipelineComputation.\n        \"\"\"\n        closed_jaxpr = jax_pipeline_computation.closed_jaxpr()\n        name = f\"pipeline_computation_{jax_pipeline_computation.name}\"\n        donated_invars = (False,) * len(jax_pipeline_computation.invars)\n        hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars)\n\n        return cls(\n            name=jax_pipeline_computation.name,\n            hlo=hlo,\n            invars=jax_pipeline_computation.invars,\n            outvars=jax_pipeline_computation.outvars,\n        )\n\n    def get_runnable(self, mesh=None):\n        
\"\"\"Return a callable of the pipeline computation.\"\"\"\n        out_avals = [var.aval for var in self.outvars]\n        tuple_args = len(self.invars) > 100 and global_config.backend == \"tpu\"\n        backend = xb.get_backend(global_config.backend)\n        device = backend.get_default_device_assignment(1)[0]\n        options = get_compile_options(\n            num_replicas=1,\n            num_partitions=1,\n            device_assignment=(device.id,) if device else None,\n            use_spmd_partitioning=False,\n            parameter_is_tupled_arguments=tuple_args,\n            build_random_seed=global_config.compile_random_seed,\n        )\n\n        xla_computation = self.hlo.get_computation()\n        compiled = backend.compile(xla_computation, compile_options=options)\n        self.hlo.module = compiled.hlo_modules()[0]\n        self.hlo.status = HloStatus.FULLY_OPTIMIZED\n        # pylint: disable=protected-access\n        result_handler = dispatch._result_handler(backend, device, [(\n            aval,\n            True,\n        ) for aval in out_avals])\n        buffer_counts = (None if len(out_avals) == 1 else [\n            dispatch.aval_to_num_buffers(aval) for aval in out_avals\n        ])\n        kept_var_idx = range(len(self.invars))\n        return partial(\n            dispatch._execute_compiled,  # pylint: disable=protected-access\n            self.name,\n            compiled,\n            None,\n            buffer_counts,\n            result_handler,\n            False,\n            (),\n            kept_var_idx,\n            False)\n\n    def get_hlo_text(self):\n        \"\"\"Get the HLO text.\"\"\"\n        return self.hlo.to_string()\n\n\n@dataclass\nclass XlaShardedPipelineComputation(PipelineComputation):\n    \"\"\"\n    A pipeline computation defined by XLA HLO Module.\n    The XLA HLO is annotated by sharding spec.\n    \"\"\"\n\n    hlo: WrappedHlo = None\n    donated_invars: Sequence[bool] = None\n    stage_plan: StagePlan = 
None\n    input_sharding_specs: Sequence[pxla.ShardingSpec] = None\n    output_sharding_specs: Sequence[pxla.ShardingSpec] = None\n    output_acc_grad_indices: Sequence[int] = None\n    donatables: OrderedSet[Var] = None\n\n    @classmethod\n    def dummy_computation(cls, name, logical_mesh_shape, gensym_func):\n        \"\"\"Create a dummy computation.\"\"\"\n        stage_plan = StagePlan(global_config.compile_random_seed,\n                               logical_mesh_shape, 1, 1, AutoShardingOption(),\n                               None, 0)\n        sharding_annotated_hlo = compile_dummy_zero_constant()\n        outvar = gensym_func(ShapedArray((), np.dtype(np.int32)))\n        return cls(\n            name=name,\n            hlo=sharding_annotated_hlo,\n            stage_plan=stage_plan,\n            donated_invars=[],\n            invars=[],\n            outvars=[outvar],\n            output_acc_grad_indices=[],\n            donatables=OrderedSet(),\n        )\n\n    @classmethod\n    def from_auto_sharded_computation(\n            cls,\n            *,\n            jax_pipeline_computation: JaxPipelineComputation,\n            sharding_annotated_hlo: WrappedHlo,\n            stage_plan: StagePlan,\n            donated_invars: Sequence[bool] = None,\n            acc_grad_outvars: Sequence[Var] = (),\n            donatables: OrderedSet[Var] = None):\n        \"\"\"Run auto-sharding optimizer on a Jax pipeline computation.\"\"\"\n        if donatables is None:\n            donatables = OrderedSet()\n\n        if not donated_invars:\n            donated_invars = (False,) * len(jax_pipeline_computation.invars)\n\n        acc_grad_indices = [\n            out_idx\n            for out_idx, outvar in enumerate(jax_pipeline_computation.outvars)\n            if outvar in acc_grad_outvars\n        ]\n\n        return cls(name=jax_pipeline_computation.name,\n                   hlo=sharding_annotated_hlo,\n                   stage_plan=stage_plan,\n                   
donated_invars=donated_invars,\n                   invars=jax_pipeline_computation.invars,\n                   outvars=jax_pipeline_computation.outvars,\n                   output_acc_grad_indices=acc_grad_indices,\n                   donatables=donatables)\n\n    def donate_intermediates(self, computation):\n        \"\"\"Donate intermediate variables.\"\"\"\n        # FIXME (yonghao): this function is not being used.\n        # get sharding annotated hlo module\n        hlo_module = computation.as_hlo_module()\n        donatable = OrderedSet(self.donatables)\n        # get sharding specs\n        hlo_module.infer_spmd_shardings()\n        avals = [var.aval for var in self.invars]\n        out_avals = [var.aval for var in self.outvars]\n        logical_mesh_shape = self.stage_plan.logical_mesh_shape\n        input_shardings = hlo_module.spmd_parameters_shardings()\n        input_sharding_specs = [\n            hlo_sharding_to_sharding_spec(proto_tuple, aval, logical_mesh_shape)\n            for (proto_tuple, aval) in zip(input_shardings, avals)\n        ]\n        output_shardings = hlo_module.spmd_output_sharding()\n        output_sharding_specs = hlo_sharding_to_sharding_spec(\n            output_shardings, out_avals, logical_mesh_shape)\n\n        num_donated = np.count_nonzero(self.donated_invars)\n        donatable_outvars = OrderedSet(self.outvars[num_donated:])\n        donated_invars = []\n        donated_outvars = []\n        var_indices = dict(zip(self.outvars, range(len(self.outvars))))\n        var_indices.update(dict(zip(self.invars, range(len(self.invars)))))\n        for idx, invar in enumerate(self.invars):\n            if invar not in donatable:\n                # not donatable\n                continue\n            if self.donated_invars[idx]:\n                # already donated\n                continue\n            for outvar in donatable_outvars:\n                if (invar.aval.shape == outvar.aval.shape and\n                        
input_sharding_specs[var_indices[invar]]\n                        == output_sharding_specs[var_indices[outvar]]):\n                    donated_invars.append(invar)\n                    donated_outvars.append(outvar)\n                    donatable_outvars.discard(outvar)\n                    break\n        # set alias\n        for invar, outvar in zip(donated_invars, donated_outvars):\n            invar_idx, outvar_idx = var_indices[invar], var_indices[outvar]\n            computation.setup_alias((outvar_idx,), invar_idx, ())\n        for invar in donated_invars:\n            self.donated_invars[var_indices[invar]] = True\n\n    def get_spmd_partitioned(self):\n        \"\"\"Run spmd partitioner to get the input/output sharding specs after\n        partitioning.\"\"\"\n        if self.hlo.is_spmd_partitioned():\n            return self.hlo\n\n        stage_plan = self.stage_plan\n        logical_mesh_shape = stage_plan.logical_mesh_shape\n        setup_computation_alias(self.hlo, self.donated_invars)\n\n        num_devices = np.prod(logical_mesh_shape)\n        rewrite_for_grad_acc = len(self.output_acc_grad_indices) > 0\n        hlo = run_spmd_partitioner_pass(\n            self.hlo,\n            num_devices,\n            rewrite_for_grad_acc=rewrite_for_grad_acc,\n            rewrite_grad_acc_indices=self.output_acc_grad_indices)\n\n        avals = [var.aval for var in self.invars]\n        out_avals = [var.aval for var in self.outvars]\n        input_sharding_specs, output_sharding_specs = (\n            get_input_output_sharding_specs(hlo.get_module(), avals, out_avals,\n                                            num_devices,\n                                            stage_plan.logical_mesh_shape))\n        self.input_sharding_specs = input_sharding_specs\n        self.output_sharding_specs = output_sharding_specs\n        # The run_spmd_partitioner_pass modifies hlo module in-place,\n        # so the old hlo module cannot be accessed anymore\n        return 
hlo\n\n    def get_runnable(self, mesh=None):\n        \"\"\"Return a callable of the pipeline computation.\"\"\"\n        if not mesh:\n            raise RuntimeError(\n                \"`XlaShardedPipelineComputation` requires a mesh.\")\n        hlo = self.get_spmd_partitioned()\n\n        avals = [var.aval for var in self.invars]\n        out_avals = [var.aval for var in self.outvars]\n        mesh_executable = PartialGradAccMeshDriverExecutable(\n            mesh, hlo, self.stage_plan, avals, out_avals, self.donated_invars)\n        return mesh_executable.get_driver_callable()\n\n    def get_hlo_text(self):\n        \"\"\"Get the HLO text.\"\"\"\n        assert self.hlo.is_sharding_annotated()\n        return self.hlo.to_string()\n\n\ndef slice_closed_jaxpr_by_full_pipeline_marks(\n        closed_jaxpr: ClosedJaxpr) -> Sequence[JaxPipelineComputation]:\n    \"\"\"Slice a closed jaxpr into multiple JaxPipelineComputation by full\n    pipeline markers.\"\"\"\n    global_consts_dir = dict(\n        zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))\n\n    result_computations = []\n    current_computation = None\n\n    for eqn in closed_jaxpr.jaxpr.eqns:\n        if eqn.primitive is pipeline_p and eqn.params[\"mark_type\"] == \"start\":\n            assert current_computation is None, (\n                \"Defining a pipeline computation \"\n                \"inside a pipeline computation is \"\n                \"not allowed.\")\n            current_computation = JaxPipelineComputation(\n                name=eqn.params[\"name\"])\n            for var in eqn.invars:\n                if isinstance(var, Literal):\n                    pass\n                elif var in global_consts_dir:\n                    current_computation.consts_dir[var] = global_consts_dir[var]\n                else:\n                    current_computation.invars.append(var)\n\n        for var in eqn.invars:\n            if not isinstance(var, Literal) and var in global_consts_dir:\n        
        current_computation.consts_dir[var] = global_consts_dir[var]\n\n        assert current_computation is not None\n        current_computation.eqns.append(eqn)\n\n        if eqn.primitive is pipeline_p and eqn.params[\"mark_type\"] == \"end\":\n            assert current_computation is not None, (\n                \"Ending a pipeline computation before its start.\")\n            assert current_computation.name == eqn.params[\"name\"], (\n                \"Ending a pipeline computation different from its start.\")\n            for var in eqn.outvars:\n                current_computation.outvars.append(var)\n            result_computations.append(current_computation)\n            current_computation = None\n\n    return result_computations\n\n\ndef mark_missing_vars_in_backward_computation_pipeline_marks(\n        computations: Sequence[JaxPipelineComputation], global_invars,\n        global_outvars, gensym_func):\n    \"\"\"\n    Fix missing vars generated by jax.grad and alpa.grad.\n\n    Fix missing input variables in pipeline markers of stages generated by\n    jax.grad or alpa.grad. 
Also remove unused variables in the pipeline\n    markers.\n    \"\"\"\n    assert len(computations) % 2 == 0.\n    num_forward_computations = len(computations) // 2\n\n    var_computation_id = {}\n    for var in global_invars:\n        if not isinstance(var, Literal):\n            var_computation_id[var] = -1\n\n    computation_marked_to_unmarked_invars = [{} for _ in computations]\n    computation_weight_invars = [{} for _ in computations]\n    computation_additional_invars = [OrderedSet() for _ in computations]\n    computation_additional_outvars = [OrderedSet() for _ in computations]\n    for computation_id, computation in enumerate(computations):\n        for eqn in computation.eqns:\n            if eqn.primitive == pipeline_p and eqn.params[\n                    \"mark_type\"] == \"start\":\n                for invar, outvar in zip(eqn.invars, eqn.outvars):\n                    computation_marked_to_unmarked_invars[computation_id][\n                        outvar] = invar\n            for var in eqn.invars:\n                if (not isinstance(var, Literal) and\n                        var not in computation.consts_dir and\n                        var not in computation.invars):\n                    source_computation_id = var_computation_id[var]\n                    if source_computation_id != computation_id:\n                        # Special case for the model weights. 
If a backward\n                        # computation is using an invar of a forward\n                        # computation, do not let the invar go into the stage.\n                        # Instead, we can directly use the original invar.\n                        if (computation_id >= num_forward_computations and\n                                source_computation_id\n                                == 2 * num_forward_computations -\n                                computation_id - 1 and\n                                var in computation_marked_to_unmarked_invars[\n                                    source_computation_id]):\n                            computation_weight_invars[computation_id][var] = (\n                                computation_marked_to_unmarked_invars[\n                                    source_computation_id][var])\n                            continue\n                        # Mark all the variables in the backward computation\n                        # that are not currently defined in pipeline markers.\n                        if (source_computation_id != -1 and var not in\n                                computations[source_computation_id].outvars):\n                            computation_additional_outvars[\n                                source_computation_id].add(var)\n                        computation_additional_invars[computation_id].add(var)\n            for var in eqn.outvars:\n                var_computation_id[var] = computation_id\n\n    for var in global_outvars:\n        source_computation_id = var_computation_id[var]\n        if source_computation_id != -1 and var not in computations[\n                source_computation_id].outvars:\n            computation_additional_outvars[source_computation_id].add(var)\n\n    new_computations = []\n\n    for i, computation in enumerate(computations):\n        assert (computation.eqns[0].primitive is pipeline_p and\n                computation.eqns[0].params[\"mark_type\"] == 
\"start\")\n        assert (computation.eqns[-1].primitive is pipeline_p and\n                computation.eqns[-1].params[\"mark_type\"] == \"end\")\n        new_computation = JaxPipelineComputation(\n            computation.name, consts_dir=computation.consts_dir)\n\n        computation_var_mapping = {\n            var: gensym_func(var.aval)\n            for var in computation_additional_invars[i] |\n            computation_additional_outvars[i] |\n            computation_weight_invars[i].keys()\n        }\n        pipeline_start_invars = list(computation.eqns[0].invars)\n        pipeline_start_outvars = [\n            get_var_mapping(computation_var_mapping, var)\n            for var in computation.eqns[0].outvars\n        ]\n        new_computation.invars = list(computation.invars)\n        for var in computation_additional_invars[i]:\n            pipeline_start_invars.append(var)\n            pipeline_start_outvars.append(computation_var_mapping[var])\n        for marked_var, unmarked_var in computation_weight_invars[i].items():\n            pipeline_start_invars.append(unmarked_var)\n            pipeline_start_outvars.append(computation_var_mapping[marked_var])\n        pipeline_start_invars_without_literal = []\n        pipeline_start_outvars_without_literal = []\n        for invar, outvar in zip(pipeline_start_invars, pipeline_start_outvars):\n            if isinstance(invar, Literal):\n                computation_var_mapping[outvar] = invar\n            else:\n                pipeline_start_invars_without_literal.append(invar)\n                pipeline_start_outvars_without_literal.append(outvar)\n        new_computation.invars = list(pipeline_start_invars_without_literal)\n        new_computation.eqns.append(computation.eqns[0]._replace(\n            invars=pipeline_start_invars_without_literal,\n            outvars=pipeline_start_outvars_without_literal))\n\n        for eqn in computation.eqns[1:-1]:\n            invars = [\n                
get_var_mapping(computation_var_mapping, var)\n                for var in eqn.invars\n            ]\n            outvars = [\n                get_var_mapping(computation_var_mapping, var)\n                for var in eqn.outvars\n            ]\n            new_computation.eqns.append(\n                eqn._replace(invars=invars, outvars=outvars))\n\n        pipeline_end_invars = [\n            get_var_mapping(computation_var_mapping, var)\n            for var in computation.eqns[-1].invars\n        ]\n        pipeline_end_outvars = list(computation.eqns[-1].outvars)\n        for var in computation_additional_outvars[i]:\n            pipeline_end_invars.append(computation_var_mapping[var])\n            pipeline_end_outvars.append(var)\n        pipeline_end_invars_without_dropvar = []\n        pipeline_end_outvars_without_dropvar = []\n        for invar, outvar in zip(pipeline_end_invars, pipeline_end_outvars):\n            if not isinstance(outvar, DropVar):\n                pipeline_end_invars_without_dropvar.append(invar)\n                pipeline_end_outvars_without_dropvar.append(outvar)\n        new_computation.outvars = list(pipeline_end_outvars_without_dropvar)\n        new_computation.eqns.append(computation.eqns[-1]._replace(\n            invars=pipeline_end_invars_without_dropvar,\n            outvars=pipeline_end_outvars_without_dropvar))\n        new_computations.append(new_computation)\n\n    return new_computations\n\n\ndef pipeline_dce(jax_pipeline_computations: Sequence[JaxPipelineComputation],\n                 global_outvars):\n    \"\"\"\n    Clear unused vars cross pipeline computations.\n\n    This function removes grad and only keeps accumulated grad.\n    \"\"\"\n\n    def dce_pipe_marker(marker: JaxprEqn, used_set):\n        kept_indices = [\n            i for i, var in enumerate(marker.outvars) if var in used_set\n        ]\n        new_marker = mark_pipeline_jaxpreqn(\n            [marker.invars[i] for i in kept_indices],\n            
[marker.outvars[i] for i in kept_indices], marker.params[\"name\"],\n            marker.params[\"mark_type\"])\n        return new_marker\n\n    global_used = OrderedSet(global_outvars)\n    new_computations = []\n    for computation in reversed(jax_pipeline_computations):\n        new_eqns = []\n        # handle pipe end\n        pipe_end = computation.eqns[-1]\n        assert (pipe_end.primitive is pipeline_p and\n                pipe_end.params[\"mark_type\"]\n                == \"end\"), \"computation not ended by a pipeline marker\"\n        new_pipe_end = dce_pipe_marker(pipe_end, global_used)\n        new_eqns.append(new_pipe_end)\n        # handle normal instructions\n        local_used = OrderedSet(new_pipe_end.invars)\n        for eqn in reversed(computation.eqns[1:-1]):\n            for outvar in eqn.outvars:\n                if not isinstance(outvar, DropVar) and outvar in local_used:\n                    new_eqns.append(eqn)\n                    local_used.update([\n                        invar for invar in eqn.invars if isinstance(invar, Var)\n                    ])\n                    break\n        # handle pipe start\n        pipe_start = computation.eqns[0]\n        assert (pipe_start.primitive is pipeline_p and\n                pipe_start.params[\"mark_type\"]\n                == \"start\"), \"computation not started by a pipeline marker\"\n        new_pipe_start = dce_pipe_marker(pipe_start, local_used)\n        new_eqns.append(new_pipe_start)\n        global_used.update(new_pipe_start.invars)\n\n        new_eqns = list(reversed(new_eqns))\n        new_computation = JaxPipelineComputation(\n            computation.name,\n            invars=new_pipe_start.invars,\n            outvars=new_pipe_end.outvars,\n            eqns=new_eqns,\n            consts_dir=computation.consts_dir)\n        new_computations.append(new_computation)\n    new_computations = list(reversed(new_computations))\n    return new_computations\n\n\ndef 
rearrange_vars(invars,\n                   selected: Sequence[Var],\n                   pipe_marker=None,\n                   is_input=True):\n    \"\"\"\n    Rearrange vars to let those in selected be first.\n\n    If the pipe_marker is given, rearrange invars and outvars in pipemarker as\n    well.\n\n    Args:\n        invars (Sequence[Var]): all vars to be rearranged.\n        selected (Sequence[Var]): vars selected to be prior.\n        pipe_marker (JaxprEqn): pipe marker corresponding to vars\n        is_input (bool): the var is input of pipe_marker, if False, it is output\n    \"\"\"\n    new_vars = list(selected)\n    selected = OrderedSet(selected)\n    for var in invars:\n        if var not in selected:\n            new_vars.append(var)\n\n    if pipe_marker is None:\n        return new_vars\n\n    if is_input:\n        new_invars = list(new_vars)\n        var_set = set(new_vars)\n        # the pipeline start marker also include constvars\n        for v in pipe_marker.invars:\n            if v not in var_set:\n                new_invars.append(v)\n        invar_idx = {v: idx for idx, v in enumerate(pipe_marker.invars)}\n        new_outvars = [\n            pipe_marker.outvars[invar_idx[var]] for var in new_invars\n        ]\n    else:\n        new_outvars = list(new_vars)\n        outvar_idx = {v: idx for idx, v in enumerate(pipe_marker.outvars)}\n        new_invars = [\n            pipe_marker.invars[outvar_idx[var]] for var in new_outvars\n        ]\n    new_marker = clone_jaxpr_eqn(pipe_marker, new_invars, new_outvars)\n    return new_vars, new_marker\n\n\ndef generate_computations_from_modules(\n        jax_computations, computation_names, computation_hlos, donate_invars,\n        donatable_lists, acc_grad_outvars,\n        stage_plan) -> Sequence[XlaShardedPipelineComputation]:\n    \"\"\"Generate pipeline computation from HLO modules.\"\"\"\n    module_dict = dict(zip(computation_names, computation_hlos))\n    computations = [\n        
XlaShardedPipelineComputation.from_auto_sharded_computation(\n            sharding_annotated_hlo=module_dict[computation.name],\n            jax_pipeline_computation=computation,\n            stage_plan=stage_plan,\n            donated_invars=donate_invars,\n            acc_grad_outvars=acc_grad_outvars,\n            donatables=donatables)\n        for computation, donate_invars, donatables in zip(\n            jax_computations, donate_invars, donatable_lists)\n    ]\n    return computations\n\n\ndef generate_sharded_xla_computations_arguments(\n        name: str, jax_computations: Sequence[JaxPipelineComputation],\n        computation_donate_invars: Sequence[bool],\n        input_sharding_dict: Dict[Var, pxla.ShardingSpec],\n        output_sharding_dict: Dict[Var, pxla.ShardingSpec],\n        stage_input_sharding: Optional[Sequence[pxla.ShardingSpec]]):\n    \"\"\"\n    Generates the arguments for distributed compilation.\n\n    Similar to generate_sharded_xla_computations but only generate arguments.\n    \"\"\"\n    invars = OrderedSet()\n    outvars = OrderedSet()\n    donation_mapping = {}\n    eqns = []\n    consts_dir = {}\n    for computation, donation in zip(jax_computations,\n                                     computation_donate_invars):\n        consts_dir.update(computation.consts_dir)\n        # Do not add local invars into the invars\n        invars.update([var for var in computation.invars if var not in outvars])\n        outvars.update(computation.outvars)\n        for idx, var in enumerate(computation.invars):\n            if not donation[idx] or var not in invars:\n                continue\n            donation_mapping[computation.invars[idx]] = computation.outvars[idx]\n        eqns += computation.eqns\n    invars = rearrange_vars(invars, donation_mapping.keys())\n    outvars = rearrange_vars(outvars, donation_mapping.values())\n    jaxpr = Jaxpr(\n        constvars=list(consts_dir.keys()),\n        invars=list(invars),\n        
outvars=list(outvars),\n        eqns=eqns,\n    )\n\n    donation_num = len(donation_mapping)\n    dummy_donated_invars = (True,) * donation_num + (False,) * (len(invars) -\n                                                                donation_num)\n    closed_jaxpr = ClosedJaxpr(jaxpr, consts_dir.values())\n    hlo = jaxpr_to_hlo(name, closed_jaxpr, dummy_donated_invars)\n\n    if input_sharding_dict:\n        sharding_protos = []\n        for x in invars:\n            spec = input_sharding_dict.get(x, None)\n            if spec is None:\n                sharding_protos.append(undefined_sharding_spec_proto())\n            else:\n                sharding_protos.append(spec.sharding_proto())\n        hlo.set_input_shardings(sharding_protos)\n\n    if output_sharding_dict:\n        sharding_protos = []\n        for x in outvars:\n            spec = output_sharding_dict.get(x, None)\n            if spec is None:\n                sharding_protos.append(replicated_sharding_spec_proto())\n            else:\n                sharding_protos.append(spec.sharding_proto())\n        hlo.set_output_shardings(sharding_protos)\n\n    if stage_input_sharding:\n        sharding_protos = [\n            sharding_spec.sharding_proto()\n            for sharding_spec in stage_input_sharding\n        ]\n        hlo.set_input_shardings(sharding_protos)\n\n    flops = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())\n    return hlo, flops\n\n\ndef generate_sharded_xla_computations(\n        name: str, jax_computations: Sequence[JaxPipelineComputation],\n        computation_donate_invars, donatable_lists, acc_grad_outvars,\n        num_micro_batches, logical_mesh, autosharding_option,\n        input_sharding_dict, output_sharding_dict, stage_input_sharding):\n    \"\"\"\n    Generate sharded XLA computations.\n\n    It runs the auto-sharding pass on the given JaxPipelineComputations.\n    Note: we merge the co-located forward and backward computation and compile\n    them 
together to get a sharding strategy config.\n    \"\"\"\n    hlo, flops = generate_sharded_xla_computations_arguments(\n        name, jax_computations, computation_donate_invars, input_sharding_dict,\n        output_sharding_dict, stage_input_sharding)\n\n    #  pylint: disable=unbalanced-tuple-unpacking\n    (computation_names, computation_hlos,\n     stage_plan) = run_auto_sharding_pass(hlo, logical_mesh, \"stages\",\n                                          num_micro_batches,\n                                          autosharding_option)\n\n    computations = generate_computations_from_modules(\n        jax_computations, computation_names, computation_hlos,\n        computation_donate_invars, donatable_lists, acc_grad_outvars,\n        stage_plan)\n    return computations, flops\n\n\ndef rewrite_hook(eqns, gensym_fn):\n    \"\"\" (Deprecated because we now profile forward and backward separately)\n    Rewrite the hook marker to include the intermediate variables.\n\n    Assume there is a special \"hook\" marker eqn in eqns that divides the\n    eqns into two parts. 
This function rewrites the hook to capture all the\n    variables that are passed between the two parts.\n    \"\"\"\n    for idx, eqn in enumerate(eqns):\n        eqn: JaxprEqn\n        if (\"mark_type\" in eqn.params and eqn.params[\"mark_type\"] == \"hook\"):\n            used_vars = OrderedSet()\n            defined_vars = OrderedSet()\n            for e in eqns[0:idx]:\n                defined_vars.update(\n                    [v for v in e.outvars if not isinstance(v, DropVar)])\n            for e in eqns[idx + 1:]:\n                used_vars.update([v for v in e.invars if isinstance(v, Var)])\n            marked = used_vars.intersection(defined_vars)\n            hooked = list(marked)\n            new_hook = mark_hook_jaxpreqn(hooked,\n                                          [gensym_fn(v.aval) for v in hooked])\n            rewrite_dict = dict(zip(hooked, new_hook.outvars))\n            eqns[idx] = new_hook\n            for i in range(idx + 1, len(eqns)):\n                e = eqns[i]\n                eqns[i] = clone_jaxpr_eqn(\n                    e, [get_var_mapping(rewrite_dict, v) for v in e.invars])\n            return new_hook\n    return None\n\n\ndef _wrap_with_call(closed_jaxpr: ClosedJaxpr, invars, outvars, name):\n    new_invars = closed_jaxpr.jaxpr.invars + closed_jaxpr.jaxpr.constvars\n    jaxpr = clone_jaxpr(closed_jaxpr, new_invars, constvars=[], consts=[]).jaxpr\n    params = dict(name=name, call_jaxpr=jaxpr)\n    return new_jaxpr_eqn(invars + closed_jaxpr.jaxpr.constvars, outvars,\n                         named_call_p, params)\n\n\ndef _rearrange_in_out_for_donation(invars, outvars, donation_map):\n    outvar_set = set(outvars)\n    donated_invars = [\n        var for var in invars\n        if (var in donation_map and donation_map[var] in outvar_set)\n    ]\n    donated_outvars = [donation_map[var] for var in donated_invars]\n    invars = rearrange_vars(invars, donated_invars)\n    outvars = rearrange_vars(outvars, donated_outvars)\n    
num_donated = len(donated_invars)\n    return invars, outvars, num_donated\n\n\ndef merge_unmarked_with_call(jaxprs: Sequence[ClosedJaxpr],\n                             names: Sequence[str],\n                             outvars,\n                             donation_map=None):\n    \"\"\"Merge a sequence of jaxprs (no pipeline marker) using named call.\"\"\"\n    gensym_fn = gensym([closed_jaxpr.jaxpr for closed_jaxpr in jaxprs])\n    eqns = []\n    invars = OrderedSet()\n    intermediates = OrderedSet()\n    const_dir = {}\n    for stage_name, closed_jaxpr in zip(names, jaxprs):\n        invars.update(closed_jaxpr.jaxpr.invars)\n        intermediates.update(closed_jaxpr.jaxpr.outvars)\n        const_dir.update(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))\n        jaxpr = closed_jaxpr.jaxpr\n\n        sym_invars = [gensym_fn(var.aval) for var in jaxpr.invars]\n        sym_outvars = [gensym_fn(var.aval) for var in jaxpr.outvars]\n        eqns.append(\n            mark_pipeline_jaxpreqn(jaxpr.invars, sym_invars, stage_name,\n                                   \"start\"))\n        eqns.append(\n            _wrap_with_call(closed_jaxpr, sym_invars, sym_outvars, stage_name))\n        eqns.append(\n            mark_pipeline_jaxpreqn(sym_outvars, jaxpr.outvars, stage_name,\n                                   \"end\"))\n    invars.difference_update(intermediates)\n    # handle donation\n    num_donated = 0\n    if donation_map:\n        (invars, outvars,\n         num_donated) = _rearrange_in_out_for_donation(invars, outvars,\n                                                       donation_map)\n    is_donated = [True] * num_donated + [False] * (len(invars) - num_donated)\n    jaxpr = Jaxpr(const_dir.keys(), invars, outvars, eqns)\n    closed_jaxpr = ClosedJaxpr(jaxpr, const_dir.values())\n    return closed_jaxpr, is_donated\n\n\ndef _wrap_by_marker(jaxpr: Jaxpr, name, gensym_fn):\n    eqns = []\n    new_invars = list(jaxpr.invars)\n    new_outvars = 
list(jaxpr.outvars)\n    sym_invars = [gensym_fn(var.aval) for var in new_invars]\n    sym_outvars = [gensym_fn(var.aval) for var in new_outvars]\n    eqns.append(mark_pipeline_jaxpreqn(new_invars, sym_invars, name, \"start\"))\n    params = dict(name=name,\n                  call_jaxpr=Jaxpr([], new_invars + jaxpr.constvars,\n                                   new_outvars, jaxpr.eqns))\n    eqns.append(\n        new_jaxpr_eqn(sym_invars + jaxpr.constvars, sym_outvars, named_call_p,\n                      params))\n    eqns.append(mark_pipeline_jaxpreqn(sym_outvars, new_outvars, name, \"end\"))\n    return Jaxpr(list(jaxpr.constvars), list(jaxpr.invars), new_outvars, eqns)\n\n\ndef merge_marked_jaxprs_with_named_call(jaxprs: Sequence[ClosedJaxpr],\n                                        may_outvars: OrderedSet[Var],\n                                        donation_map=None,\n                                        prefix=None,\n                                        wrap_with_marker=False,\n                                        gensym_fn=None) -> ClosedJaxpr:\n    \"\"\"\n    Merge continuous jaxprs and remove pipe markers.\n\n    Args:\n        jaxprs: jaxprs to be merged.\n        may_outvars: outvars of the merged jaxpr.\n        donation_map: donation map of merged jaxpr, may have redundant items.\n        prefix: name of pipeline marker for merged jaxpr\n        insert_hook_after: index of a layer to insert a hook after it.\n            The hook records sharding specs of all tensors cross it.\n        wrap_with_marker: Whether the returned jaxpr has pipeline marker\n\n    Returns:\n        The merged ClosedJaxpr. 
If insert_hook_after is not None, it returns\n        invars of the hook as well.\n    \"\"\"\n\n    def unwrap_with_call(jaxpr, name):\n        assert jaxpr.eqns[0].primitive == pipeline_p\n        assert jaxpr.eqns[-1].primitive == pipeline_p\n        used_var = OrderedSet()\n        for eqn in jaxpr.eqns[1:-1]:\n            used_var.update([var for var in eqn.invars if isinstance(var, Var)])\n        used_var.intersection_update(jaxpr.eqns[0].outvars)\n        new_invars = {}\n        for invar, outvar in zip(jaxpr.eqns[0].invars, jaxpr.eqns[0].outvars):\n            if outvar in used_var:\n                new_invars[outvar] = invar\n        new_jaxpr = clone_jaxpr(jaxpr, new_invars.keys(), jaxpr.eqns[-1].invars,\n                                jaxpr.eqns[1:-1])\n        return _wrap_with_call(new_jaxpr, list(new_invars.values()),\n                               jaxpr.eqns[-1].outvars, name)\n\n    def has_output(jaxpr):\n        return len([v for v in jaxpr.outvars if not isinstance(v, DropVar)])\n\n    name_prefix = prefix or \"\"\n    new_eqns = []\n    invars = []\n    env = OrderedSet()\n    const_dir = {}\n    outvars = OrderedSet()\n    gensym_fn = gensym_fn or gensym([j.jaxpr for j in jaxprs])\n    # Merge everything together\n    for i, jaxpr in enumerate(jaxprs):\n        const_dir.update(zip(jaxpr.jaxpr.constvars, jaxpr.consts))\n        env.update(jaxpr.jaxpr.constvars)\n        if has_output(jaxpr.jaxpr):\n            call_eqn = unwrap_with_call(jaxpr, name_prefix + str(i))\n            new_eqns.append(call_eqn)\n            invars.extend(OrderedSet(call_eqn.invars).difference(env))\n            env.update(call_eqn.invars + call_eqn.outvars)\n        outvars.update(jaxpr.jaxpr.outvars)\n    outvars.intersection_update(may_outvars)\n\n    # handle donation\n    if donation_map:\n        invars, outvars, _ = _rearrange_in_out_for_donation(\n            invars, outvars, donation_map)\n    # wrap with marker\n    jaxpr = Jaxpr(const_dir.keys(), invars, 
outvars, new_eqns)\n    if wrap_with_marker:\n        jaxpr = _wrap_by_marker(jaxpr, prefix, gensym_fn)\n    closed_jaxpr = ClosedJaxpr(jaxpr, const_dir.values())\n\n    return closed_jaxpr\n\n\ndef create_donation_mapping(initial_mapping, donated_invars, invars, outvars):\n    \"\"\"Infer donation of global invar-outvars.\"\"\"\n    donation_mapping = dict(initial_mapping)\n    donated_outvars = OrderedSet()\n\n    for donate, invar in zip(donated_invars, invars):\n        if not donate:\n            continue\n        for outvar in outvars:\n            if outvar in donated_outvars:\n                continue\n            if invar.aval.shape != outvar.aval.shape:\n                continue\n            donated_outvars.add(outvar)\n            donation_mapping[invar] = outvar\n            break\n        if invar not in donation_mapping:\n            logger.warning(\n                f\"{invar} is marked donated but no match outvar for it\")\n    return donation_mapping\n\n\ndef get_local_donation_mapping_and_add_missing_invars(computation,\n                                                      reversed_donation_mapping,\n                                                      gensym_fn):\n    \"\"\"Get the local donation mapping of selected computation and add missing\n    input variables of the donated output variables.\n\n    If an outvar is donated from an invar not in the current computation, the\n    function add the invar and create a new computation and corresponding to\n    the donation mapping.\n    \"\"\"\n    invars = OrderedSet(computation.invars)\n    donation_mapping = {}\n    appended_invars = OrderedSet()\n    for var in computation.outvars:\n        if var not in reversed_donation_mapping:\n            continue\n        invar = reversed_donation_mapping[var]\n        assert invar.aval.shape == var.aval.shape\n        donation_mapping[invar] = var\n        if invar not in invars:\n            appended_invars.add(invar)\n    if not donation_mapping:\n     
   return donation_mapping, computation\n    # append invars for donation\n    new_invars = list(computation.invars)\n    new_outvars = list(computation.outvars)\n    new_eqns = list(computation.eqns)\n    appended_invars = list(appended_invars)\n    if appended_invars:\n        new_invars = new_invars + appended_invars\n        pipe_start = new_eqns[0]\n        new_eqns[0] = mark_pipeline_jaxpreqn(\n            pipe_start.invars + appended_invars, pipe_start.outvars +\n            list(map(lambda v: gensym_fn(v.aval), appended_invars)),\n            pipe_start.params[\"name\"], pipe_start.params[\"mark_type\"])\n    # rearrange to keep donated invars and outvars have same index\n    new_invars, new_pipe_start = rearrange_vars(new_invars,\n                                                list(donation_mapping.keys()),\n                                                new_eqns[0], True)\n    new_outvars, new_pipe_end = rearrange_vars(new_outvars,\n                                               list(donation_mapping.values()),\n                                               new_eqns[-1], False)\n    new_eqns[0] = new_pipe_start\n    new_eqns[-1] = new_pipe_end\n    new_computation = JaxPipelineComputation(computation.name, new_invars,\n                                             new_outvars, new_eqns,\n                                             computation.consts_dir)\n    return donation_mapping, new_computation\n\n\ndef split_donate_invars(donation_mapping,\n                        stages: Sequence[JaxPipelineComputation], gensym_fn):\n    \"\"\"\n    Split donated invars for sliced jaxprs, then rewrite stages.\n\n    Currently, we only donate:\n    1. global invars that can be donated(set by users);\n    2. buffers for accumulated gradients.\n    But if auto-sharding supports, we can add:\n    1. local invars not used later in this mesh, not main copy\n    2. 
local invars not used later in all meshes, main copy\n\n    Args:\n        donation_mapping (Dict[Var, Var]): known mapping of donations, including\n            global invar-outvar and accumulate gradients.\n        stages: slices in topology order of execution.\n\n    Returns:\n        donate_invars_dict:Sequence[Sequence[bool]]: donate_invars for each\n            stage.\n    \"\"\"\n    reversed_donation_mapping = {v: k for k, v in donation_mapping.items()}\n\n    ans = [None for _ in range(len(stages))]\n    new_stages = []\n\n    for stage_idx, stage in enumerate(stages):\n        # find donation mapping of the stage\n        donation_mapping, new_stage = (\n            get_local_donation_mapping_and_add_missing_invars(\n                stage, reversed_donation_mapping, gensym_fn))\n        donated_num = len(donation_mapping)\n        ans[stage_idx] = (True,) * donated_num + (False,) * (\n            len(new_stage.invars) - donated_num)\n        new_stages.append(new_stage)\n\n    return ans, new_stages\n\n\ndef get_donatable_intermediate(stages: Sequence[JaxPipelineComputation],\n                               worker_stage_mapping, global_invars):\n    \"\"\"\n    Get donatable invars of each stage.\n\n    A donatable invar is:\n    1. An intermediate;\n    2. 
Either a main copy never used, or not a main copy.\n\n    Args:\n        stages (Sequence[JaxPipelineStage]): all stages.\n        worker_stage_mapping (Dict[int, OrderedSet[int]]): indices of stages in\n            each mesh.\n        global_invars (Sequence[Var] | OrderedSet[Var]): global input variables.\n\n    Returns:\n        donatable_list (Sequence[OrderedSet[Var]]): donatable invars of each\n            stage.\n    \"\"\"\n    global_invars = OrderedSet(global_invars)\n    main_copy_at = {}\n    stage_at = {}\n    for mesh_idx, stage_indices in worker_stage_mapping.items():\n        for stage_idx in stage_indices:\n            stage = stages[stage_idx]\n            for outvar in stage.outvars:\n                main_copy_at[outvar] = mesh_idx\n            stage_at[stage_idx] = mesh_idx\n\n    donatable_list = []\n    used = OrderedSet()\n    for stage_idx in reversed(range(len(stages))):\n        stage = stages[stage_idx]\n        donatable = OrderedSet()\n        for invar in stage.invars:\n            if invar in global_invars:\n                continue  # do not consider global inputs\n            if main_copy_at[invar] != stage_at[stage_idx]:\n                donatable.add(invar)  # not a main copy\n            if invar not in used:\n                donatable.add(invar)  # is a main copy never used\n        used.update(stage.invars)\n        donatable_list.append(donatable)\n    donatable_list = list(reversed(donatable_list))\n    return donatable_list\n"
  },
  {
    "path": "alpa/pipeline_parallel/cross_mesh_resharding.py",
    "content": "\"\"\"Cross mesh resharding for pipeline parallelism.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom collections import namedtuple\nimport logging\nimport math\nimport random\nimport time\nfrom typing import List, Any\n\nfrom jax.interpreters import pxla\nimport numpy as np\nimport ray\n\nimport alpa.collective as col\nfrom alpa.device_mesh import (DistributedArray, RemoteArrayRef,\n                              ReshardingRecvSpec, ReshardingSendSpec,\n                              ReshardingTileSpec, ReshardingBroadcastSpec,\n                              _device_mesh_put_dummy, device_id_to_str)\nfrom alpa.global_env import global_config\nfrom alpa.mesh_executable import (UtilMeshWorkerExecutable,\n                                  next_mesh_executable_uuid)\nfrom alpa.pipeline_parallel.computation import XlaShardedPipelineComputation\nfrom alpa.pipeline_parallel.resharding_tensor import (VirtualDistributedArray,\n                                                      TileSlice,\n                                                      unflatten_tile_index)\nfrom alpa.util import OrderedSet, compile_allgather\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\nresharding_task_counter = 0\n\n\ndef next_resharding_task_uuid():\n    \"\"\"Generate the next resharding task uuid.\"\"\"\n    global resharding_task_counter\n    resharding_task_counter = (resharding_task_counter + 1) % (1 << 60)\n    return resharding_task_counter\n\n\ndef _get_chunk_value(spec):\n    if isinstance(spec, pxla.Chunked):\n        return int(np.prod(spec.chunks))\n    return 1\n\n\ndef _add_chunk(spec, chunk):\n    if isinstance(spec, pxla.Chunked):\n        return pxla.Chunked(spec.chunks + [chunk])\n    return pxla.Chunked([chunk])\n\n\ndef _get_chunk_prefixsum(shardings):\n    chunk_cnt = 0\n    chunk_prefixsum = []\n    for dim_sharding in shardings:\n        chunk_prefixsum.append(chunk_cnt)\n        if isinstance(dim_sharding, pxla.Chunked):\n       
     chunk_cnt += len(dim_sharding.chunks)\n    return chunk_prefixsum\n\n\ndef _get_mesh_mapping(shardings, init_mesh_mapping, squeezed_mesh_mapping):\n    chunk_prefixsum = _get_chunk_prefixsum(shardings)\n    mesh_mapping = []\n    for mesh_dim, mapping in enumerate(squeezed_mesh_mapping):\n        prev_mapping = init_mesh_mapping[mesh_dim]\n        if mapping is None:\n            mesh_mapping.append(prev_mapping)\n            continue\n        replicas = 1\n        if isinstance(prev_mapping, pxla.Replicated):\n            replicas = prev_mapping.replicas\n        for (tensor_dim, chunk_idx) in mapping:\n            mesh_mapping.append(\n                pxla.ShardedAxis(chunk_prefixsum[tensor_dim] + chunk_idx))\n            replicas //= shardings[tensor_dim].chunks[chunk_idx]\n        if replicas > 1:\n            mesh_mapping.append(pxla.Replicated(replicas))\n    return mesh_mapping\n\n\nclass ReshardingTask:\n    \"\"\"\n    A task that addresses cross-mesh resharding between two meshes.\n\n    Args:\n        task_spec (ReshardingTaskSpec): the task spec of this task.\n        collective_group (CollectiveGroup): the collective group information.\n        src_mesh (PhysicalMesh): the source mesh to send.\n        dst_mesh (PhysicalMesh): the destination mesh to receive.\n    \"\"\"\n\n    def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):\n        self.task_spec: ReshardingTaskSpec = task_spec\n        self.collective_group = collective_group\n        self.src_mesh = src_mesh\n        self.dst_mesh = dst_mesh\n\n    @property\n    def is_local_allgather_task(self):\n        \"\"\"If this task involves a post scatter-allgather task.\"\"\"\n        return self.task_spec.strategy.is_local_allgather\n\n\nclass EagerReshardingTask(ReshardingTask):\n    \"\"\"An eager resharding task.\n\n    It does not put task info into remote workers. 
Instead, it provides\n    a do() interface to execute the task immediately.\n    \"\"\"\n\n    def do(self, src_array):\n        \"\"\"According to the task_spec, launch send/recv operations eagerly.\n\n        Used in centralized distributed runtime.\n\n        Args:\n            src_array (DistributedArray): the source array to be resharded.\n        \"\"\"\n        if src_array.device_mesh != self.src_mesh:\n            raise RuntimeError(f\"The src array locates on a different \"\n                               f\"mesh `{src_array.device_mesh}` than \"\n                               f\"self.src_mesh `{self.src_mesh}`.\")\n\n        remote_ref = _device_mesh_put_dummy(src_array.aval, self.dst_mesh,\n                                            self.task_spec.dst_indices, 1)  # pylint: disable=protected-access\n        for i, (dst_tile, src_tiles, indices_in_dst_tiles) in enumerate(\n                self.task_spec.dst_tile_to_src_tiles_map):\n            # Loop over each dst tile for this shard\n            s = self.task_spec.strategy[i]\n            # strategy is len(dst_tile.device_strs) by len(src_tiles)\n            for replica_index, receiver in enumerate(\n                    dst_tile.replica_device_strs):\n                # loop over this replica (hence a specific destination gpu\n                # device)\n                senders = [\n                    s[replica_index][src_tile_index]\n                    for src_tile_index, src_tile in enumerate(src_tiles)\n                ]\n                self.same_destination_group_send_recv(src_array, senders,\n                                                      src_tiles,\n                                                      indices_in_dst_tiles,\n                                                      receiver, remote_ref.uuid)\n\n        # Now construct the distributed array\n        dst_array = DistributedArray(self.dst_mesh, src_array.aval,\n                                     
self.task_spec.dst_sharding_spec,\n                                     remote_ref, self.task_spec.dst_indices)\n        return dst_array\n\n    def same_destination_group_send_recv(self, src_array, senders, src_tiles,\n                                         indices_in_dst_tiles, receiver, uuid):\n        \"\"\"P2P Communication accounting for multiple senders and one receiver\n        (a destination tile).\"\"\"\n        receiver_device_id = self.collective_group.device_str_to_device_id_map[\n            receiver]\n        receiver_worker = self.collective_group.device_str_to_mesh_worker_map[\n            receiver]\n        # Put an empty buffer first.\n        receiver_rank, receiver_gpu_idx = (\n            self.collective_group.device_str_to_rank_map[receiver])\n        for i, sender in enumerate(senders):\n            # send is a device_str in src_mesh\n            # we need to find out its mesh_worker, and the corresponded sender\n            # remotebuf (uuid-indexed).\n            sender_worker = self.collective_group.device_str_to_mesh_worker_map[\n                sender]\n            # assert sender_buf.device_id == i\n            sender_rank, sender_gpu_idx = (\n                self.collective_group.device_str_to_rank_map[sender])\n            # launch NCCL send/recv\n            tile = src_tiles[i]\n            indices_in_dst_tile = indices_in_dst_tiles[i]\n            send_done_ref = sender_worker.send_tile.remote(\n                src_array.remote_ref.uuid, tile.offset, receiver_rank,\n                receiver_gpu_idx, self.collective_group.group_name)\n            recv_done_ref = receiver_worker.recv_tile.remote(\n                uuid, receiver_device_id, indices_in_dst_tile, sender_rank,\n                sender_gpu_idx, self.collective_group.group_name)\n            ray.get([send_done_ref, recv_done_ref])\n\n\nclass SymbolicReshardingTask(ReshardingTask):\n    \"\"\"A symbolic resharding task that puts task info in remote workers.\"\"\"\n\n    def 
__init__(self, task_spec, collective_group, src_mesh, dst_mesh):\n        super().__init__(task_spec, collective_group, src_mesh, dst_mesh)\n        # Dict of worker -> ((offset, rank, gpu index))\n        self._sender_tasks = {w: [] for w in self.src_mesh.workers}\n        # Dict of worker -> ((indices, rank, gpu index))\n        self._receiver_tasks = {w: [] for w in self.dst_mesh.workers}\n        self.allgather_uuid = None\n\n        self.send_worker_task_ids = {}\n        self.recv_worker_task_ids = {}\n\n        # generate the above states\n        self._compile()\n        # print(self.__str__()+\"\\n\")\n\n    @property\n    def sender_tasks(self):\n        \"\"\"Return sender sub-tasks.\"\"\"\n        return self._sender_tasks\n\n    @property\n    def receiver_tasks(self):\n        \"\"\"Return receiver sub-tasks.\"\"\"\n        return self._receiver_tasks\n\n    def _compile(self):\n        \"\"\"\n        Generate all send, recv, and allgather tasks.\n\n        This function does the following:\n        (1) generate send, recv, and allgather tasks (if needed),\n        (2) put all tasks to their corresponding MeshHostWorkers.\n        (3) pre-generate NCCL communicators for those tasks.\n        \"\"\"\n        self._compile_send_recv_tasks()\n\n        if not global_config.debug_with_pipeshard_runtime:\n            self.put_all_tasks()\n\n    def put_all_tasks(self):\n        \"\"\"\n        Put all send, recv and allgather tasks to their MeshHostWorkers\n        \"\"\"\n        # put send and recv tasks\n        task_dones = []\n        for worker, task in self.sender_tasks.items():\n            uuid = next_resharding_task_uuid()\n            self.send_worker_task_ids[worker] = uuid\n            task_dones.append(\n                worker.put_resharding_send_task.remote(\n                    uuid, task, self.collective_group.group_name))\n        for worker, task in self.receiver_tasks.items():\n            uuid = next_resharding_task_uuid()\n           
 self.recv_worker_task_ids[worker] = uuid\n            task_dones.append(\n                worker.put_resharding_recv_task.remote(\n                    uuid, task, self.collective_group.group_name))\n        ray.get(task_dones)\n\n        # put allgather tasks\n        task_dones = []\n        if self.is_local_allgather_task:\n            self.allgather_uuid = uuid = next_mesh_executable_uuid()\n            task_spec = self.task_spec\n            hlo = compile_allgather(task_spec.aval.shape, task_spec.aval.dtype,\n                                    task_spec.dst_sharding_spec,\n                                    task_spec.final_dst_spec,\n                                    np.prod(self.dst_mesh.shape))\n            for worker in self.dst_mesh.workers:\n                task_dones.append(\n                    worker.put_executable.remote(uuid, UtilMeshWorkerExecutable,\n                                                 hlo))\n        ray.get(task_dones)\n\n    def create_resharding_communicators(self):\n        \"\"\"Create the NCCL communicators in advance.\"\"\"\n        communicator_params = set()\n        for worker, recv_tasks in self.receiver_tasks.items():\n            dst_rank = self.collective_group.worker_to_rank_map[worker]\n            for recv_task in recv_tasks:\n                dst_gpu_idx = recv_task.device_id\n                tile_specs = recv_task.tile_specs\n                for tile_spec in tile_specs:\n                    src_rank = tile_spec.rank\n                    src_gpu_idx = tile_spec.gpu_idx\n                    param = (src_rank, src_gpu_idx, dst_rank, dst_gpu_idx)\n                    if param not in communicator_params:\n                        communicator_params.add(param)\n\n        # now init the communicators\n        group_name = self.collective_group.group_name\n        task_dones = []\n        for param in communicator_params:\n            src_rank, src_gpu_idx, dst_rank, dst_gpu_idx = param\n            src_worker = 
self.collective_group.mesh_workers[src_rank]\n            dst_worker = self.collective_group.mesh_workers[dst_rank]\n            nccl_uid = ray.get(src_worker.generate_nccl_uid.remote(group_name))\n            task_dones.append(\n                src_worker.init_p2p_communicator.remote(group_name, src_rank,\n                                                        src_gpu_idx, dst_rank,\n                                                        dst_gpu_idx, nccl_uid))\n            task_dones.append(\n                dst_worker.init_p2p_communicator.remote(group_name, dst_rank,\n                                                        dst_gpu_idx, src_rank,\n                                                        src_gpu_idx, nccl_uid))\n        ray.get(task_dones)\n\n    def _compile_send_recv_tasks(self):\n        \"\"\"Generate all send/recv tasks.\"\"\"\n        dtype = self.task_spec.src.aval.dtype\n\n        # print(\"order: \", self.task_spec.strategy.order)\n        for i, k, j in self.task_spec.strategy.order:\n            spec_plan = self.task_spec.strategy.per_spec_plans[i]\n            dst_tile, src_tiles, indices_in_dst_tiles = (\n                self.task_spec.dst_tile_to_src_tiles_map[i])\n            replica_index, receiver = k, dst_tile.replica_device_strs[k]\n            _, _, indices_in_dst_tile = (j, src_tiles[j],\n                                         indices_in_dst_tiles[j])\n\n            # Get args for an empty buffer\n            receiver_device_id = (\n                self.collective_group.device_str_to_device_id_map[receiver])\n            receiver_worker = (\n                self.collective_group.device_str_to_mesh_worker_map[receiver])\n            dtype = self.task_spec.src.aval.dtype\n            # Get args for send/recv\n            senders = [\n                spec_plan[replica_index][src_tile_index]\n                for src_tile_index, _ in enumerate(src_tiles)\n            ]\n            receiver_rank, receiver_gpu_idx = (\n          
      self.collective_group.device_str_to_rank_map[receiver])\n            recv_tile_specs = []\n            for sender_idx, sender in enumerate(senders):\n                # Sender's task\n                sender_worker = (\n                    self.collective_group.device_str_to_mesh_worker_map[sender])\n                src_device_id = (\n                    self.collective_group.device_str_to_device_id_map[sender])\n                self._sender_tasks[sender_worker].append(\n                    ReshardingSendSpec(\n                        src_device_id,\n                        ReshardingTileSpec(src_tiles[sender_idx].offset,\n                                           receiver_rank, receiver_gpu_idx)))\n                # Receiver's task\n                sender_rank, sender_gpu_idx = \\\n                    self.collective_group.device_str_to_rank_map[sender]\n                indices_in_dst_tile = indices_in_dst_tiles[sender_idx]\n                recv_tile_specs.append(\n                    ReshardingTileSpec(indices_in_dst_tile, sender_rank,\n                                       sender_gpu_idx))\n            receiver_task = ReshardingRecvSpec(receiver_device_id,\n                                               dst_tile.tile_shape, dtype,\n                                               recv_tile_specs)\n            self._receiver_tasks[receiver_worker].append(receiver_task)\n\n    # FIXME(Hao): test the function below; it might be buggy.\n    def do_prepared(self, src_array, profiling=False):\n        \"\"\"Execute a task which has been put in the remote workers.\"\"\"\n\n        result_ref = RemoteArrayRef(self.dst_mesh)\n\n        results = []\n        if profiling:\n            for worker, uuid in self.send_worker_task_ids.items():\n                results.append(\n                    worker.profile_resharding_send_task.remote(\n                        uuid, src_array.remote_ref.uuid))\n            for worker, uuid in self.recv_worker_task_ids.items():\n        
        results.append(\n                    worker.profile_resharding_recv_task.remote(\n                        uuid, result_ref.uuid))\n        else:\n            for worker, uuid in self.send_worker_task_ids.items():\n                results.append(\n                    worker.run_resharding_send_task.remote(\n                        uuid, src_array.remote_ref.uuid))\n            for worker, uuid in self.recv_worker_task_ids.items():\n                results.append(\n                    worker.run_resharding_recv_task.remote(\n                        uuid, result_ref.uuid))\n            logger.debug(\"Precompiled tasks launched.\")\n            ray.get(results)\n        # Now construct the distributed array\n        dst_array = DistributedArray(self.dst_mesh, src_array.aval,\n                                     self.task_spec.dst_sharding_spec,\n                                     result_ref, self.task_spec.dst_indices)\n        if profiling:\n            return results\n        return dst_array\n\n    def __str__(self):\n        return (f\"ReshardingTask(shape: {self.task_spec.aval.shape}, \"\n                f\"mesh_id: {self.src_mesh.mesh_id}->{self.dst_mesh.mesh_id},\\n\"\n                f\"{self.task_spec.src_sharding_spec} ->\\n\"\n                f\"{self.task_spec.dst_sharding_spec})\")\n\n\nclass CommunicatorConfig:\n    \"\"\"Config used to initilize broadcast communicator.\"\"\"\n\n    def __init__(self, comm_key):\n        self.comm_key = comm_key\n        self.workers = []\n        self.device_ids = []\n\n    def add(self, worker, device_id):\n        self.workers.append(worker)\n        self.device_ids.append(device_id)\n\n    def __hash__(self):\n        return hash(\n            (self.comm_key, tuple(self.workers), tuple(self.device_ids)))\n\n    def __eq__(self, other):\n        if not isinstance(other, CommunicatorConfig):\n            return False\n        elif self.comm_key != other.comm_key:\n            return False\n        elif 
len(self.workers) != len(other.workers):\n            return False\n\n        for i in range(len(self.workers)):\n            if (self.workers[i] != other.workers[i] or\n                    self.device_ids[i] != other.device_ids[i]):\n                return False\n\n        return True\n\n\nclass SymbolicBroadcastReshardingTask(ReshardingTask):\n    \"\"\"A Broadcast based symbolic resharding task that puts task info in remote\n    workers.\"\"\"\n\n    def __init__(self, task_spec, collective_group, src_mesh, dst_mesh):\n        super().__init__(task_spec, collective_group, src_mesh, dst_mesh)\n        # task is a dict: (i, src_tile_index)->ReshardingBroadcastSpec\n        self._broadcast_tasks = {\n            host: {} for host in self.src_mesh.workers + self.dst_mesh.workers\n        }\n        self.broadcast_worker_task_ids = {}\n        self.communicator_configs = set()\n\n        # generate the above states\n        self._compile()\n        # print(self.__str__()+\"\\n\")\n\n    @property\n    def broadcast_tasks(self):\n        \"\"\"Return broadcast sub-tasks.\"\"\"\n        return self._broadcast_tasks\n\n    def _compile(self):\n        \"\"\"\n        Generate all broadcast tasks.\n\n        This function does the following:\n        (1) generate broadcast tasks (if needed),\n        (2) put all tasks to their corresponding MeshHostWorkers.\n        (3) pre-generate NCCL communicators for those tasks.\n        \"\"\"\n        self._compile_broadcast_tasks()\n\n        if not global_config.debug_with_pipeshard_runtime:\n            self.put_all_tasks()\n\n    def put_all_tasks(self):\n        \"\"\"Put all tasks to their corresponding MeshHostWorkers.\"\"\"\n        task_dones = []\n        for worker, task in self._broadcast_tasks.items():\n            uuid = next_resharding_task_uuid()\n            self.broadcast_worker_task_ids[worker] = uuid\n            # print(worker, uuid, task)\n            task_dones.append(\n                
worker.put_resharding_broadcast_task.remote(\n                    uuid, task, self.collective_group.group_name))\n        ray.get(task_dones)\n\n    def _compile_broadcast_tasks(self):\n        \"\"\"Compile broadcast tasks.\"\"\"\n        dtype = self.task_spec.src.aval.dtype\n\n        # print(\"order: \", self.task_spec.strategy.order)\n        for i, j in self.task_spec.strategy.order:\n            spec_plan = self.task_spec.strategy.per_spec_plans[i]\n            dst_tile, src_tiles, indices_in_dst_tiles = (\n                self.task_spec.dst_tile_to_src_tiles_map[i])\n            src_tile, indices_in_dst_tile = (src_tiles[j],\n                                             indices_in_dst_tiles[j])\n\n            sender = spec_plan[j]\n            sender_worker = (\n                self.collective_group.device_str_to_mesh_worker_map[sender])\n            broadcast_group = (i, j)\n            devices = [sender] + dst_tile.replica_device_strs\n            comm_key = \"$\".join(devices)\n            world_size = len(devices)\n\n            comm_config = CommunicatorConfig(comm_key)\n\n            group_spec = self._broadcast_tasks[sender_worker].setdefault(\n                broadcast_group,\n                ReshardingBroadcastSpec(comm_key=comm_key,\n                                        world_size=world_size,\n                                        devices_ids=[\n                                            self.collective_group.\n                                            device_str_to_device_id_map[sender]\n                                        ],\n                                        devices_global_rank=[0],\n                                        tensor_slices=[src_tile.offset],\n                                        recv_tile_shape=src_tile.tile_shape,\n                                        dtype=dtype))\n            comm_config.add(\n                sender_worker,\n                self.collective_group.device_str_to_device_id_map[sender])\n\n   
         for replica_index, receiver in enumerate(\n                    dst_tile.replica_device_strs):\n                receiver_worker = (self.collective_group.\n                                   device_str_to_mesh_worker_map[receiver])\n                group_spec = self._broadcast_tasks[receiver_worker].setdefault(\n                    broadcast_group,\n                    ReshardingBroadcastSpec(comm_key=comm_key,\n                                            world_size=world_size,\n                                            devices_ids=[],\n                                            devices_global_rank=[],\n                                            tensor_slices=[],\n                                            recv_tile_shape=dst_tile.tile_shape,\n                                            dtype=dtype))\n\n                group_spec.devices_ids.append(\n                    self.collective_group.device_str_to_device_id_map[receiver])\n                group_spec.devices_global_rank.append(1 + replica_index)\n                group_spec.tensor_slices.append(indices_in_dst_tile)\n                comm_config.add(\n                    receiver_worker,\n                    self.collective_group.device_str_to_device_id_map[receiver])\n\n            self.communicator_configs.add(comm_config)\n\n        return self._broadcast_tasks\n\n    def create_resharding_communicators(self):\n        \"\"\"Create the NCCL communicators for broadcast in advance.\"\"\"\n        group_name = self.collective_group.group_name\n        for config in self.communicator_configs:\n            task_dones = []\n            worker_to_devices_and_global_ranks = {}\n            world_size = len(config.workers)\n            for global_rank, (worker, device_id) in enumerate(\n                    zip(config.workers, config.device_ids)):\n                if worker not in worker_to_devices_and_global_ranks:\n                    worker_to_devices_and_global_ranks[worker] = {\n                       
 \"device_ids\": [],\n                        \"global_ranks\": []\n                    }\n                worker_to_devices_and_global_ranks[worker][\"device_ids\"].append(\n                    device_id)\n                worker_to_devices_and_global_ranks[worker][\n                    \"global_ranks\"].append(global_rank)\n\n            sender_worker = config.workers[0]\n            nccl_uid = ray.get(\n                sender_worker.generate_nccl_uid.remote(group_name))\n\n            for worker, devices_info in (\n                    worker_to_devices_and_global_ranks.items()):\n                task_dones.append(\n                    worker.init_broadcast_communicator.remote(\n                        group_name, config.comm_key, world_size,\n                        devices_info[\"device_ids\"],\n                        devices_info[\"global_ranks\"], nccl_uid))\n            ray.get(task_dones)\n\n    def __str__(self):\n        return (f\"B-ReshardingTask(shape: {self.task_spec.aval.shape}, \"\n                f\"mesh_id: {self.src_mesh.mesh_id}->{self.dst_mesh.mesh_id},\\n\"\n                f\"{self.task_spec.src_sharding_spec} ->\\n\"\n                f\"{self.task_spec.dst_sharding_spec})\")\n\n\nclass CollectiveGroup:\n    \"\"\"\n    A class for setting up real NCCL groups.\n\n    Args:\n        device_strs (List[str]): list of device strs in this group.\n        src_mesh (PhysicalDeviceMesh): the source physical mesh.\n        dst_mesh (PhysicalDeviceMesh): the destination physical mesh.\n    \"\"\"\n\n    def __init__(self, device_strs, src_mesh, dst_mesh):\n        self.instantiated = False\n        self.device_strs = device_strs\n        self.src_mesh = src_mesh\n        self.dst_mesh = dst_mesh\n\n        # generate a group name\n        self.group_name = \",\".join(self.device_strs)\n\n        # construct a device str -> rank: (process_rank, gpu_index) map\n        self.device_str_to_rank_map = {}\n        self.device_str_to_mesh_worker_map = {}\n    
    self.device_str_to_host_id_map = {}\n        self.device_str_to_device_id_map = {}\n        self.worker_to_rank_map = {}\n\n        # arranged following the rank order\n        num_host = len(self.src_mesh.host_ips) + len(self.dst_mesh.host_ips)\n        self.mesh_workers: List[Any] = [None] * num_host\n        for i, _ in enumerate(src_mesh.host_ips):\n            self.mesh_workers[i] = self.src_mesh.workers[i]\n            for j in range(src_mesh.num_devices_per_host):\n                device_str = self.src_mesh.device_strs[\n                    i * src_mesh.num_devices_per_host + j]\n                self.device_str_to_rank_map[device_str] = (i, j)\n                self.device_str_to_mesh_worker_map[\n                    device_str] = self.src_mesh.workers[i]\n                self.device_str_to_host_id_map[device_str] = i\n                self.device_str_to_device_id_map[device_str] = j\n        for i, _ in enumerate(dst_mesh.host_ips):\n            self.mesh_workers[\n                i + len(self.src_mesh.host_ips)] = self.dst_mesh.workers[i]\n            for j in range(dst_mesh.num_devices_per_host):\n                device_str = self.dst_mesh.device_strs[\n                    i * dst_mesh.num_devices_per_host + j]\n                self.device_str_to_rank_map[device_str] = (\n                    i + len(src_mesh.host_ips), j)\n                self.device_str_to_mesh_worker_map[\n                    device_str] = self.dst_mesh.workers[i]\n                self.device_str_to_host_id_map[device_str] = i\n                self.device_str_to_device_id_map[device_str] = j\n\n        self.worker_to_rank_map = {\n            worker: r for r, worker in enumerate(self.mesh_workers)\n        }\n\n    def instantiate(self):\n        \"\"\"Instantiate the collective group in Ray lazily.\"\"\"\n        if self.instantiated:\n            return\n        options = {\n            \"group_name\": self.group_name,\n            \"world_size\": len(self.mesh_workers),\n           
 \"ranks\": [i for i, _ in enumerate(self.mesh_workers)],\n            \"backend\": \"nccl\"\n        }\n        col.create_collective_group(self.mesh_workers, **options)\n        self.instantiated = True\n\n    def instantiate_now(self):\n        \"\"\"Instantiate the collective group eagerly (but not communicators).\"\"\"\n        if self.instantiated:\n            return\n        world_size = len(self.mesh_workers)\n        task_dones = []\n        logger.debug(\n            \"Trying to create ray.collective groups among participants.\")\n        for rank, worker in enumerate(self.mesh_workers):\n            task_dones.append(\n                worker.init_collective_group.remote(world_size, rank, \"nccl\",\n                                                    self.group_name))\n        ray.get(task_dones)\n        logger.debug(f\"The group {self.group_name} has been created.\")\n        self.instantiated = True\n\n    def destroy(self):\n        \"\"\"Destroy the NCCL collective group at exit.\"\"\"\n        logger.debug(f\"Recycling the collective group: {self.group_name}.\")\n        for worker in self.mesh_workers:\n            # This remote call will remove ray named actors (hence it is\n            # necessary)\n            ray.get(worker.destroy_collective_group.remote(self.group_name))\n        # Destroy the declared named actor in ray\n        self._destroy_info_actor()\n        self.instantiated = False\n\n    def _destroy_info_actor(self):\n        name = \"info_\" + self.group_name\n        try:\n            store = ray.get_actor(name)\n            ray.kill(store)\n        except ValueError:\n            pass\n\n\nclass ReshardingTaskSpec:\n    \"\"\"\n    A helper class specifies how to perform cross-mesh resharding for two\n    arrays.\n\n    Args:\n        src_array (VirtualDistributedArray): the source VirtualDistributedArray.\n        dst_array (VirtualDistributedArray): the destination\n            VirtualDistributedArray.\n    \"\"\"\n\n    def 
__init__(self, src_array, dst_array, final_dst_spec):\n        self.src = src_array\n        self.dst = dst_array\n        self._dst_tile_to_src_tiles_map = None\n        self._strategy = None\n        self.final_dst_spec = final_dst_spec\n\n    @property\n    def src_sharding_spec(self):\n        \"\"\"Return the sharding spec of the source array.\"\"\"\n        return self.src.sharding_spec\n\n    @property\n    def dst_sharding_spec(self):\n        \"\"\"Return the sharding spec of the destination array.\"\"\"\n        return self.dst.sharding_spec\n\n    @property\n    def aval(self):\n        \"\"\"Return the abstract value of the array.\"\"\"\n        return self.src.aval\n\n    @property\n    def src_indices(self):\n        \"\"\"Return the sharding (flattened) indices of the source array.\"\"\"\n        return self.src.indices\n\n    @property\n    def dst_indices(self):\n        \"\"\"Return the sharding (flattened) indices of the destination array.\"\"\"\n        return self.dst.indices\n\n    @property\n    def dst_tile_to_src_tiles_map(self):\n        \"\"\"\n        Map from dst_tile to all corresponding src TileSlices.\n\n        It is a list of length len(dst.tiles), each element is a 3-element tuple\n        (dst_tile, src_tile_slices, indices_in_dst_tile):\n        - dst_tile: a tile from dst.tiles\n        - src_tile_slices: a list of TileSlice objects from src, corresponding\n            to this dst_tile\n        - indices_in_dst_tile: a list of slicers. 
Each slicer is a list of slice\n            objects, corresponding to\n            a TileSlice in src_tile_slices, representing the indices of this\n            TileSlice in dst_tile.\n        \"\"\"\n        if not self._dst_tile_to_src_tiles_map:\n            self._dst_tile_to_src_tiles_map = self.generate_src_dst_map()\n        return self._dst_tile_to_src_tiles_map\n\n    def generate_src_dst_map(self):\n        \"\"\"\n        Analyzes the src and dst array and generate the\n        dst_tile_to_src_tiles_map.\n\n        It aims to tell the needed collective group and communication pattern.\n\n        Returns:\n            dst_tile_to_src_tiles_map (tuple[tile, tileslices, indices]):\n                see the docstring of `dst_tile_to_src_tiles_map`.\n        \"\"\"\n        dst_tile_to_src_tiles_map = []\n        for tile in self.dst.tiles.flatten():\n            # loop over each tile\n            src_tile_slices, indices_in_dst_tile = (\n                self._look_up_dst_tile_from_src(tile))\n            dst_tile_to_src_tiles_map.append(\n                (tile, src_tile_slices, indices_in_dst_tile))\n        return dst_tile_to_src_tiles_map\n\n    def _look_up_dst_tile_from_src(self, tile):\n        \"\"\"\n        Look up all related tiles from the source array for a given destination\n        tile.\n\n        See the docstring in dst_tile_to_src_tiles_map() for more details.\n        \"\"\"\n        # For each dim in the dst tile, find all the related tiles, and ragged\n        # values on that dim in src_tiles.\n        # To record that, for each dim, we make a tuple containing the first and\n        # last index of tiles in src array that intersects with the dst tile:\n        # Shards between [start, end) are involved; Left included, right not\n        # included.\n        related_tile_start_end = [tuple()] * self.src.tensor_rank\n\n        # Meanwhile, for each dim, for the first and end tile, we make a tuple\n        # recording the slicing offset:\n    
    # - start_shard_offset: [start_shard_offset: ] on that dim is activated.\n        # - end_shard_offset: [:end_sharding_offset] on that dim is activated.\n        related_tile_offset = [tuple()] * self.src.tensor_rank\n\n        for i, dim in enumerate(self.src.tensor_shape):\n            tile_length, ragged = divmod(dim, self.src.tile_shape[i])\n            assert not ragged\n            start_tile, start_tile_offset = divmod(tile.indices[i].start,\n                                                   tile_length)\n            end_tile, end_tile_offset = divmod(tile.indices[i].stop,\n                                               tile_length)\n            # if falling on the middle a src tile, increase the index of the\n            # final tile by 1.\n            if end_tile_offset:\n                end_tile = end_tile + 1\n            # if falling on the end of a src tile, the offset should be\n            # [0: tile_length]\n            if end_tile_offset == 0:\n                end_tile_offset = tile_length\n            related_tile_start_end[i] = (start_tile, end_tile)\n            related_tile_offset[i] = (start_tile_offset, end_tile_offset)\n\n        # count the number of tile slices\n        num_src_tileslices = 1\n        for start, end in related_tile_start_end:\n            num_src_tileslices = num_src_tileslices * (end - start)\n\n        src_tileslices = []\n        indices_in_dst_tile = []\n        for tileslice_index in range(num_src_tileslices):\n            tile_index_relative = unflatten_tile_index(\n                tileslice_index,\n                [end - start for start, end in related_tile_start_end])\n            tile_index_absolute = [\n                start + tile_index_relative[dim_index]\n                for dim_index, (start, end) in enumerate(related_tile_start_end)\n            ]\n            # depending on its index, calculate a slice for it\n            offsets = []\n            indices = []\n            # loop over each dimension\n  
          for i, r in enumerate(tile_index_absolute):\n                start, end = related_tile_start_end[i]\n                tile_length_on_this_dim = self.src.tiles[tuple(\n                    tile_index_absolute)].tile_shape[i]\n                if r == start and r == end - 1:\n                    # the dst tile is smaller or equal to the src tile\n                    left_offset = related_tile_offset[i][0]\n                    right_offset = related_tile_offset[i][1]\n                    offsets.append(slice(left_offset, right_offset))\n                    indices.append(slice(0, tile.tile_shape[i]))  # all included\n                elif r == start:\n                    # meaning it is the first involved tile, and not the last\n                    offset = related_tile_offset[i][0]\n                    offsets.append(slice(offset, tile_length_on_this_dim))\n                    indices.append(slice(0, tile_length_on_this_dim - offset))\n                elif r == end - 1:\n                    # meaning it is the last involved tile, and not the first\n                    offset = related_tile_offset[i][1]\n                    offsets.append(slice(0, offset))\n                    indices.append(\n                        slice(tile.tile_shape[i] - offset, tile.tile_shape[i]))\n                else:\n                    # meaning it is a fully involved tile\n                    offset = related_tile_offset[i][0]\n                    offsets.append(slice(0, tile_length_on_this_dim))\n                    left_in_dst_tile = (\n                        tile_length_on_this_dim - offset +\n                        (tile_index_relative[i] - 1) * tile_length_on_this_dim)\n                    right_in_dst_tile = (left_in_dst_tile +\n                                         tile_length_on_this_dim)\n                    indices.append(slice(left_in_dst_tile, right_in_dst_tile))\n            # construct a new tile slice\n            this_tileslice = TileSlice(\n                
self.src.tiles[tuple(tile_index_absolute)], offset=offsets)\n            src_tileslices.append(this_tileslice)\n            indices_in_dst_tile.append(indices)\n        return src_tileslices, indices_in_dst_tile\n\n    def set_resharding_strategy(self, strategy):\n        \"\"\"Now the strategy is np.array(dtype=str) to specify connections\n        between src tiles and dst tile.\"\"\"\n        self._strategy = strategy\n\n    @property\n    def strategy(self):\n        \"\"\"Return the communication strategy for this resharding task spec.\"\"\"\n        if not self._strategy:\n            raise RuntimeError(\n                \"Generate and set strategy in the cross-mesh communicator \"\n                \"first.\")\n        return self._strategy\n\n    def generate_naive_order(self, mode):\n        \"\"\"Return the naive order to submit resharding tasks.\"\"\"\n\n        order = []\n        if mode == \"sendrecv\":\n            for i, (dst_tile, src_tiles,\n                    _) in enumerate(self.dst_tile_to_src_tiles_map):\n                for k, _ in enumerate(dst_tile.replica_device_strs):\n                    for j, _ in enumerate(src_tiles):\n                        order.append((i, k, j))\n        elif mode == \"broadcast\":\n            for i, (_, src_tiles,\n                    _) in enumerate(self.dst_tile_to_src_tiles_map):\n                for j, _ in enumerate(src_tiles):\n                    order.append((i, j))\n        else:\n            raise NotImplementedError\n\n        return order\n\n    def get_participant_device_strs(self):\n        \"\"\"Identify all participant device strs (for NCCL setup) in this task\n        spec.\"\"\"\n        if not self._strategy:\n            raise RuntimeError(\"Generate and set strategy first.\")\n        device_strs = OrderedSet()\n        # senders\n        for tile_strategy in self.strategy.per_spec_plans:\n            device_strs = device_strs | OrderedSet(\n                tile_strategy.flatten().tolist())\n 
       # receivers\n        for tile in self.dst.tiles.flatten():\n            device_strs = device_strs | OrderedSet(tile.replica_device_strs)\n        return device_strs\n\n    def __str__(self):\n        ret_str = \"\"\n        ret_str += f\"{self.src_sharding_spec} -> {self.dst_sharding_spec}\"\n        if self.final_dst_spec != self.dst_sharding_spec:\n            ret_str += f\" -(allgather)-> {self.final_dst_spec}\"\n        ret_str += \";\"\n        return ret_str\n\n\nclass ReshardingStrategy:\n    \"\"\"A data class for storing resharding communication information.\n\n    Args:\n        mode (str): Two choices:[\"sendrecv\", \"broadcast\"].\n        per_spec_plans (List[np.ndarray]): `per_spec_plan` is a list a np array,\n            with length as len(spec.dst_tile_to_src_tiles_map), each array is\n            with shape [len(dst_tile.devices), len(src_tiles)]; it specifies for\n            each replica of a dst tile, how it should get the data from\n            src_tiles (src tile replicas).\n        order (List[Tuple(int, ...)]): in which order we should submit\n            these nccl communication operation into cuda stream. 
When mode\n            is \"sendrecv\", order is of type List[Tuple(int, int)];\n            Otherwise, order is of type List[Tuple(int, int, int)].\n        is_local_allgather (bool): if this strategy involves post allgather\n            operations.\n    \"\"\"\n\n    def __init__(self, mode, per_spec_plans, order, is_local_allgather):\n        self.mode = mode\n        self.per_spec_plans = per_spec_plans\n        self.order = order\n        self.is_local_allgather = is_local_allgather\n\n\nclass CrossMeshCommunicator:\n    \"\"\"\n    Communicator for cross-mesh resharding.\n\n    Given the pipeline schedule and stages, the class analyzes them and\n    generates:\n    - resharding specs (see docstring of `ReshardingTaskSpec`),\n    - resharding strategies (see docstring of `ReshardingStrategy`).\n    This communicator only takes care of compilation-time work, and does not\n    get involved with physical meshes, buffer creations, or other runtime work.\n\n    Args:\n        sharded_stages (Sequence[XlaShardedPipelineComputation]): list of stages\n            to form the pipeline.\n        schedule (Any): the pipelining schedule for these stages.\n    \"\"\"\n\n    def __init__(self, sharded_stages, schedule):\n        if not isinstance(sharded_stages, list):\n            raise RuntimeError(\"Require a list of stages.\")\n        for s in sharded_stages:\n            if not isinstance(s, XlaShardedPipelineComputation):\n                raise RuntimeError(\"Require a list of sharded stages.\")\n        # Do not mutate\n        self._sharded_stages = sharded_stages\n        self._schedule = schedule\n        self.resharding_specs = None\n\n        # Loads for load balancing.\n        self._sender_loads = {\n            device_str: 0 for mesh in self._schedule.meshes\n            for device_str in mesh.device_strs\n        }\n        self._receiver_loads = {\n            device_str: 0 for mesh in self._schedule.meshes\n            for device_str in mesh.device_strs\n 
       }\n\n        # Initialize all resharding specs\n        self._create_resharding_specs()\n        # Generate a send/recv strategies for all resharding tasks by looking\n        # at their load.\n        for src_mesh_idx, dst_mesh_idx, var_spec_map in self.task_spec_iter():\n            for _, spec in var_spec_map.items():\n                if global_config.resharding_mode == \"send_recv\":\n                    strategy = (self._generate_send_recv_resharding_strategy(\n                        spec, self._schedule.meshes[src_mesh_idx],\n                        self._schedule.meshes[dst_mesh_idx]))\n                else:\n                    strategy = (self._generate_broadcast_resharding_strategy(\n                        spec, self._schedule.meshes[src_mesh_idx],\n                        self._schedule.meshes[dst_mesh_idx]))\n                spec.set_resharding_strategy(strategy)\n\n    @property\n    def num_mesh(self):\n        \"\"\"Number of meshes in the schedule.\"\"\"\n        return self._schedule.num_mesh\n\n    @staticmethod\n    def _rewrite_allgather_spec(sharding_spec, dst_num_hosts, var_shape):\n        \"\"\"\n        Given a sharding spec, if use_local_allgather is on and the tensor\n        corresponding to the spec is not fully sharded, the function rewrite the\n        spec to a fully-sharded one, and return info of added chunks.\n\n        The rewrite is by steps below:\n        1. Iterate all logical mesh dimensions(m_dim) along which the tensor is\n        replicated;\n        2. Iterate all tensor dimensions(t_dim). If the length of the tensor on\n        t_dim and the number of replicas on m_dim have a common divisor greater\n        than 1, an extra chunk is appended on t_dim;\n        3. 
When there is no replicas on m_dim, the iteration terminates.\n        \"\"\"\n\n        if not global_config.use_local_allgather:\n            return sharding_spec\n        # check whether the tensor is fully sharded.\n        replicated_mesh_dim = []\n        mesh_dim_to_chunk_axis = {}\n        for mesh_dim, dim_mapping in enumerate(sharding_spec.mesh_mapping):\n            if isinstance(dim_mapping, pxla.Replicated):\n                replicated_mesh_dim.append((mesh_dim, dim_mapping.replicas))\n            else:\n                dim_mapping: pxla.ShardedAxis\n                mesh_dim_to_chunk_axis[mesh_dim] = dim_mapping.axis\n        if len(replicated_mesh_dim) == 0:\n            return sharding_spec\n        assert len(replicated_mesh_dim) == 1, \"Only support 1D and 2D mesh\"\n\n        # create chunk axis to tensor dim mapping\n        chunk_axis_to_tensor_dim = []\n        for tensor_dim, dim_spec in enumerate(sharding_spec.sharding):\n            if isinstance(dim_spec, pxla.Chunked):\n                for chunk_idx in range(len(dim_spec.chunks)):\n                    chunk_axis_to_tensor_dim.append((tensor_dim, chunk_idx))\n\n        # TODO(yonghao): add a global config for wheter cross-node allgather is\n        # allowed\n        node_mesh_mapping = sharding_spec.mesh_mapping[0]\n        node_chunk = 1\n        if isinstance(node_mesh_mapping, pxla.ShardedAxis):\n            tensor_dim, _ = chunk_axis_to_tensor_dim[node_mesh_mapping.axis]\n            node_chunk = _get_chunk_value(sharding_spec.sharding[tensor_dim])\n        if node_chunk < dst_num_hosts:\n            return sharding_spec\n\n        sharding = list(sharding_spec.sharding)\n        squeezed_mesh_mapping = [\n            None if isinstance(dim_mapping, pxla.Replicated) else\n            [chunk_axis_to_tensor_dim[dim_mapping.axis]]\n            for dim_mapping in sharding_spec.mesh_mapping\n        ]\n        for (mesh_dim, replica) in replicated_mesh_dim:\n            dim_local_mapping = 
[]\n            for tensor_dim, dim_sharding in enumerate(sharding):\n                chunked_value = _get_chunk_value(dim_sharding)\n                chunked_len = var_shape[tensor_dim] // chunked_value\n                new_chunk = math.gcd(replica, chunked_len)\n                if new_chunk == 1:\n                    continue\n                sharding[tensor_dim] = _add_chunk(dim_sharding, new_chunk)\n                chunk_idx = len(sharding[tensor_dim].chunks) - 1\n                dim_local_mapping.append((tensor_dim, chunk_idx))\n\n                replica //= new_chunk\n                if replica == 1:\n                    break\n            if replica != 1:\n                logger.warning(\n                    \"ReshardingTask is not fully sharded, this causes \"\n                    \"redundant communication.\")\n            if len(dim_local_mapping) != 0:\n                squeezed_mesh_mapping[mesh_dim] = dim_local_mapping\n\n        mesh_mapping = _get_mesh_mapping(sharding, sharding_spec.mesh_mapping,\n                                         squeezed_mesh_mapping)\n        new_sharding_spec = pxla.ShardingSpec(sharding, mesh_mapping)\n        # sorted by (tensor dim, chunk idx, mesh dim)\n        return new_sharding_spec\n\n    def _create_resharding_specs(self):\n        stages = self._sharded_stages\n        meshes = self._schedule.meshes\n        num_stage = len(self._sharded_stages)\n        stage_placements = [\n            list(self._schedule.stage_placement(i))[0] for i in range(num_stage)\n        ]\n        deps = self._schedule.dependency\n        assert deps.shape[0] == num_stage\n        assert deps.shape[1] == num_stage\n\n        # Note(Hao): resharding_specs is num_mesh x num_mesh matrix\n        # Each element is a dict: the name of variables are keys, ReshardingSpec\n        # are values.\n        self.resharding_specs = [\n            [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)\n        ]\n\n        # find stages that 
will communicate\n        pairs = np.argwhere(deps > 0)\n        for i in range(pairs.shape[0]):\n            # for each pair of stages that are dependent,\n            src_stage_index = pairs[i][1]\n            src_stage = stages[src_stage_index]\n            dst_stage_index = pairs[i][0]\n            dst_stage = stages[dst_stage_index]\n            src_mesh_index = stage_placements[src_stage_index]\n            dst_mesh_index = stage_placements[dst_stage_index]\n            src_mesh = meshes[src_mesh_index]\n            dst_mesh = meshes[dst_mesh_index]\n\n            # we only take care of cross-mesh sharding.\n            if src_mesh_index == dst_mesh_index:\n                continue\n\n            # find out variables that need resharding, and get their\n            # (1) out_sharding_spec in the src stage\n            # (2) in_sharding_spec in the destination stage.\n            resharding_vars, out_var_indices, in_var_indices = (\n                self._args_between(src_stage, dst_stage))\n            out_sharding_specs = src_stage.output_sharding_specs\n            in_sharding_specs = dst_stage.input_sharding_specs\n\n            # Make a ReshardSpec for each VirtualDistributedArray\n            for var, out_var_index, in_var_index in zip(resharding_vars,\n                                                        out_var_indices,\n                                                        in_var_indices):\n                src_sharding_spec = out_sharding_specs[out_var_index]\n                dst_sharding_spec = in_sharding_specs[in_var_index]\n\n                final_dst_spec = dst_sharding_spec\n                if global_config.resharding_mode == \"send_recv\":\n                    dst_sharding_spec = self._rewrite_allgather_spec(\n                        dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)\n\n                src_array = VirtualDistributedArray(\n                    device_mesh=src_mesh,\n                    aval=var.aval,\n                    
sharding_spec=src_sharding_spec)\n                dst_array = VirtualDistributedArray(\n                    device_mesh=dst_mesh,\n                    aval=var.aval,\n                    sharding_spec=dst_sharding_spec)\n                task_spec = ReshardingTaskSpec(src_array, dst_array,\n                                               final_dst_spec)\n                self.resharding_specs[src_mesh_index][dst_mesh_index][\n                    var] = task_spec\n\n    def task_spec_iter(self):\n        \"\"\"A convenient iterator over all activated task specs.\"\"\"\n        for i in range(self.num_mesh):\n            for j in range(self.num_mesh):\n                if not self.resharding_specs[i][j]:\n                    continue\n                yield i, j, self.resharding_specs[i][j]\n\n    @staticmethod\n    def get_resources_info_in_mesh(mesh):\n        device_strs = []\n        device_host_map = {}\n        nic_constraints = []\n\n        for i in range(mesh.num_hosts):\n            ip = mesh.host_info[i][\"NodeManagerAddress\"]\n            one_nic_constraint = []\n            for device in mesh.devices[i]:\n                device_str = device_id_to_str(ip, device)\n                device_strs.append(device_str)\n                one_nic_constraint.append(device_str)\n                #TODO: Here we assume there is only one NIC in one host.\n                device_host_map[device_str] = ip\n            nic_constraints.append(one_nic_constraint)\n        return device_strs, device_host_map, nic_constraints\n\n    @staticmethod\n    def _get_hardware_info_for_loadbalance(src_mesh, dst_mesh):\n        src_mesh_devices, src_device_host_map, src_nic_constraints = (\n            CrossMeshCommunicator.get_resources_info_in_mesh(src_mesh))\n        dst_mesh_devices, dst_device_host_map, dst_nic_constraints = (\n            CrossMeshCommunicator.get_resources_info_in_mesh(dst_mesh))\n        device_host_map = {**src_device_host_map, **dst_device_host_map}\n        
nic_constraints = src_nic_constraints + dst_nic_constraints\n        return (src_mesh_devices, dst_mesh_devices, device_host_map,\n                nic_constraints)\n\n    @staticmethod\n    def _generate_send_recv_resharding_strategy_by_loads(\n            spec: ReshardingTaskSpec, src_loads, dst_loads):\n        \"\"\"Generate the resharding strategy by balancing loads.\"\"\"\n        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec\n        per_spec_plans = []\n        for dst_tile, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:\n            # plan is a 2D array\n            per_spec_plan = np.empty(\n                (len(dst_tile.replica_device_strs), len(src_tileslices)),\n                dtype=object)\n            for receiver_idx, receiver in enumerate(\n                    dst_tile.replica_device_strs):\n                for src_tileslice_idx, src_tileslice in enumerate(\n                        src_tileslices):\n                    loads = {\n                        sender: src_loads[sender]\n                        for sender in src_tileslice.replica_device_strs\n                    }\n                    sender = min(loads, key=loads.get)\n                    per_spec_plan[receiver_idx][src_tileslice_idx] = sender\n                    # upload load on-the-fly\n                    src_loads[sender] += src_tileslice.slice_size\n                    dst_loads[receiver] += src_tileslice.slice_size\n            per_spec_plans.append(per_spec_plan)\n\n        strategy = ReshardingStrategy(\"sendrecv\", per_spec_plans,\n                                      spec.generate_naive_order(\"sendrecv\"),\n                                      is_local_allgather)\n        return strategy\n\n    def _generate_send_recv_resharding_strategy(self, spec: ReshardingTaskSpec,\n                                                src_mesh, dst_mesh):\n        if global_config.resharding_loadbalance_mode == \"normal\":\n            strategy = 
(self._generate_send_recv_resharding_strategy_by_loads(\n                spec, self._sender_loads, self._receiver_loads))\n        elif global_config.resharding_loadbalance_mode == \"no_loadbalance\":\n            strategy = (\n                self._generate_send_recv_resharding_strategy_by_no_load(spec))\n        elif global_config.resharding_loadbalance_mode in ([\n                \"loadbalance_size\", \"loadbalance_order\"\n        ]):\n            strategy = self.\\\n            _generate_send_recv_resharding_strategy_by_loadbalance(\n                spec, src_mesh, dst_mesh)\n        else:\n            raise NotImplementedError()\n        return strategy\n\n    def _generate_broadcast_resharding_strategy(self, spec: ReshardingTaskSpec,\n                                                src_mesh, dst_mesh):\n        if global_config.resharding_loadbalance_mode == \"normal\":\n            strategy = (self._generate_broadcast_resharding_strategy_by_loads(\n                spec, self._sender_loads, self._receiver_loads))\n        elif global_config.resharding_loadbalance_mode == \"no_loadbalance\":\n            strategy = (\n                self._generate_broadcast_resharding_strategy_by_no_load(spec))\n        elif global_config.resharding_loadbalance_mode in [\n                \"loadbalance_size\", \"loadbalance_order\"\n        ]:\n            strategy = (\n                self._generate_broadcast_resharding_strategy_by_loadbalance(\n                    spec, src_mesh, dst_mesh))\n        else:\n            raise NotImplementedError()\n        return strategy\n\n    @staticmethod\n    def _generate_send_recv_resharding_strategy_by_no_load(\n            spec: ReshardingTaskSpec):\n        \"\"\"Generate the resharding strategy by balancing loads.\"\"\"\n        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec\n        per_spec_plans = []\n        for dst_tile, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:\n            # plan is a 2D 
array\n            per_spec_plan = np.empty(\n                (len(dst_tile.replica_device_strs), len(src_tileslices)),\n                dtype=object)\n            for receiver_idx, _ in enumerate(dst_tile.replica_device_strs):\n                for src_tileslice_idx, src_tileslice in enumerate(\n                        src_tileslices):\n                    sender = src_tileslice.replica_device_strs[0]\n                    # Choose an arbitrary sender without considering loads\n                    per_spec_plan[receiver_idx][src_tileslice_idx] = sender\n            per_spec_plans.append(per_spec_plan)\n\n        strategy = ReshardingStrategy(\"sendrecv\", per_spec_plans,\n                                      spec.generate_naive_order(\"sendrecv\"),\n                                      is_local_allgather)\n        return strategy\n\n    @staticmethod\n    def _generate_send_recv_resharding_strategy_by_loadbalance(\n            spec, src_mesh, dst_mesh):\n        \"\"\"\n            Generate the send/recv-based resharding strategy by balancing\n            loads and along time.\n        \"\"\"\n\n        # pre-process\n        src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = (\n            CrossMeshCommunicator._get_hardware_info_for_loadbalance(\n                src_mesh, dst_mesh))\n\n        works = []\n        for i, (dst_tile, src_tileslices,\n                _) in enumerate(spec.dst_tile_to_src_tiles_map):\n            for receiver in dst_tile.replica_device_strs:\n                for j, src_tileslice in enumerate(src_tileslices):\n                    senders = src_tileslice.replica_device_strs\n                    data_size = src_tileslice.tile_size\n                    works.append(\n                        SingleReshardingLoadBalancingWork(\n                            senders, [receiver], data_size))\n\n        # solve and get solution\n        task = ReshardingLoadBalancingTaskSolver(src_mesh_devices,\n                               
                  dst_mesh_devices,\n                                                 device_host_map, works,\n                                                 nic_constraints)\n\n        sol_assigned_sender, sol_order = task.solve()\n\n        # post-process\n        per_spec_plans = []\n        rank_to_idx = []\n        cnt = 0\n        for i, (dst_tile, src_tileslices,\n                _) in enumerate(spec.dst_tile_to_src_tiles_map):\n            per_spec_plan = np.empty(\n                (len(dst_tile.replica_device_strs), len(src_tileslices)),\n                dtype=object)\n            for k, receiver in enumerate(dst_tile.replica_device_strs):\n                for j, src_tileslice in enumerate(src_tileslices):\n                    sender = sol_assigned_sender[cnt]\n                    per_spec_plan[k][j] = sender\n                    rank_to_idx.append((i, k, j))\n                    cnt += 1\n            per_spec_plans.append(per_spec_plan)\n\n        order = [rank_to_idx[i] for i in sol_order]\n        is_local_allgather = spec.final_dst_spec != spec.dst_sharding_spec\n        strategy = ReshardingStrategy(\"sendrecv\", per_spec_plans, order,\n                                      is_local_allgather)\n        return strategy\n\n    @staticmethod\n    def _generate_broadcast_resharding_strategy_by_no_load(\n            spec: ReshardingTaskSpec):\n        \"\"\"\n            Generate the broadcast-based resharding strategy by balancing\n            loads. 
For each tile, I not only allow one source to provide\n            the tile.\n        \"\"\"\n        # pylint: disable=unused-argument\n        per_spec_plans = []\n        for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:\n            per_spec_plan = np.empty((len(src_tileslices),), dtype=object)\n\n            for src_tileslice_idx, src_tileslice in enumerate(src_tileslices):\n                per_spec_plan[\n                    src_tileslice_idx] = src_tileslice.replica_device_strs[0]\n            per_spec_plans.append(per_spec_plan)\n        strategy = ReshardingStrategy(\"broadcast\", per_spec_plans,\n                                      spec.generate_naive_order(\"broadcast\"),\n                                      None)\n        return strategy\n\n    @staticmethod\n    def _generate_broadcast_resharding_strategy_by_loadbalance(\n            spec, src_mesh, dst_mesh):\n        \"\"\"\n            Generate the broadcast-based resharding strategy by balancing\n            loads and along time.\n        \"\"\"\n\n        # pre-process\n        src_mesh_devices, dst_mesh_devices, device_host_map, nic_constraints = (\n            CrossMeshCommunicator._get_hardware_info_for_loadbalance(\n                src_mesh, dst_mesh))\n\n        works = []\n        for i, (dst_tile, src_tileslices,\n                _) in enumerate(spec.dst_tile_to_src_tiles_map):\n            for j, src_tileslice in enumerate(src_tileslices):\n                senders = src_tileslice.replica_device_strs\n                receivers = dst_tile.replica_device_strs\n                data_size = src_tileslice.tile_size\n                works.append(\n                    SingleReshardingLoadBalancingWork(senders, receivers,\n                                                      data_size))\n\n        # solve and get solution\n        task = ReshardingLoadBalancingTaskSolver(src_mesh_devices,\n                                                 dst_mesh_devices,\n                             
                    device_host_map, works,\n                                                 nic_constraints)\n\n        sol_assigned_sender, sol_order = task.solve()\n\n        # post-process\n        per_spec_plans = []\n        rank_to_idx = []\n        cnt = 0\n        for i, (dst_tile, src_tileslices,\n                _) in enumerate(spec.dst_tile_to_src_tiles_map):\n            per_spec_plan = np.empty((len(src_tileslices),), dtype=object)\n            for j, src_tileslice in enumerate(src_tileslices):\n                sender = sol_assigned_sender[cnt]\n                per_spec_plan[j] = sender\n                rank_to_idx.append((i, j))\n                cnt += 1\n            per_spec_plans.append(per_spec_plan)\n\n        order = [rank_to_idx[i] for i in sol_order]\n        strategy = ReshardingStrategy(\"broadcast\", per_spec_plans, order, None)\n        return strategy\n\n    @staticmethod\n    def _generate_broadcast_resharding_strategy_by_loads(\n            spec, src_loads, dst_loads):\n        \"\"\"\n            Generate the broadcast-based resharding strategy by balancing loads.\n            For each tile, I not only allow one source to provide the tile.\n        \"\"\"\n        # pylint: disable=unused-argument\n        per_spec_plans = []\n        dst_loads = None\n        for _, src_tileslices, _ in spec.dst_tile_to_src_tiles_map:\n            per_spec_plan = np.empty((len(src_tileslices),), dtype=object)\n\n            for src_tileslice_idx, src_tileslice in enumerate(src_tileslices):\n                loads = {\n                    sender: src_loads[sender]\n                    for sender in src_tileslice.replica_device_strs\n                }\n                sender = min(loads, key=loads.get)\n\n                per_spec_plan[src_tileslice_idx] = sender\n                src_loads[sender] += src_tileslice.slice_size\n            per_spec_plans.append(per_spec_plan)\n        strategy = ReshardingStrategy(\"broadcast\", per_spec_plans,\n           
                           spec.generate_naive_order(\"broadcast\"),\n                                      None)\n        return strategy\n\n    @staticmethod\n    def _args_between(src_stage, dst_stage):\n        \"\"\"Find the variable exchanged between stages.\"\"\"\n        resharding_vars = []\n        src_indices = []\n        dst_indices = []\n        for i, var in enumerate(src_stage.outvars):\n            if var in dst_stage.invars:\n                resharding_vars.append(var)\n                src_indices.append(i)\n                dst_indices.append(dst_stage.invars.index(var))\n        return resharding_vars, src_indices, dst_indices\n\n\nSingleReshardingLoadBalancingWork = namedtuple(\n    \"SingleReshardingLoadBalancingWork\", [\"senders\", \"receivers\", \"data_size\"])\nSingleAbstractedLoadBalancingWork = namedtuple(\n    \"SingleAbstractedLoadBalancingWork\",\n    [\"sender_ids\", \"receiver_ids\", \"duration\"])\n\n\nclass ReshardingLoadBalancingTaskSolver:\n    \"\"\"This is class of solver for load balancing problem\"\"\"\n\n    def __init__(self,\n                 src_mesh_devices,\n                 dst_mesh_devices,\n                 device_host_map,\n                 works,\n                 nic_contraints,\n                 host_bridge_contraints=None):\n        \"\"\"We define the load balancing problem in resharding problem.\n        Here both send_recv and broadcast based implementation could\n        be formulated in this way.\n\n        Args:\n            src_mesh_devices: All gpus in src mesh.\n            dst_mesh_devices: All gpus in dst mesh.\n            device_host_map: a map from device to its corresponding host.\n            works (List[SingleReshardingLoadBalancingWork]): all works to\n                be scheduled in this task.\n            nic_contraints (List[List[device]]): each list[device] contains\n                a set of devices that competes for the same NIC.\n                Now I assmue sender and receiver do not 
share NIC.\n                The assumption is met in nic_contraints.\n                I assume these constraints are disjoint sets.\n        \"\"\"\n        self.src_mesh_devices = src_mesh_devices\n        self.dst_mesh_devices = dst_mesh_devices\n        self.all_devices = list(\n            set(src_mesh_devices).union(set(dst_mesh_devices)))\n        self.device_host_map = device_host_map\n        self.works = works\n        self.nic_contraints = nic_contraints\n        self.host_bridge_contraints = host_bridge_contraints\n\n        # self.print_task()\n\n    def solve(self):\n        \"\"\"\n            Return two data\n            1. The first List[device] represents which sender to choose\n               for each work.\n            2. The second List[int] represents the order to execute\n               these works.\n        \"\"\"\n\n        # Deal with the case when a src device share the same NIC with a tar\n        # device. Now I assmue they do not share NIC. The assumption is met\n        # in nic_contraints so we do not need to deal with it in this method.\n\n        tmp_device_to_worker_id_map = {\n            device: idx for idx, device in enumerate(self.all_devices)\n        }\n        for nic_contraint in self.nic_contraints:\n            min_id = min(\n                tmp_device_to_worker_id_map[device] for device in nic_contraint)\n            for device in nic_contraint:\n                tmp_device_to_worker_id_map[device] = min_id\n\n        device_to_worker_id_map = {}\n        worker_id_to_devices = {}\n        n_workers = 0\n        for idx, device in enumerate(self.all_devices):\n            if tmp_device_to_worker_id_map[device] == idx:\n                device_to_worker_id_map[device] = n_workers\n                worker_id_to_devices[n_workers] = [device]\n                n_workers += 1\n            else:\n                group_head_device = self.all_devices[\n                    tmp_device_to_worker_id_map[device]]\n                
worker_id = device_to_worker_id_map[group_head_device]\n                device_to_worker_id_map[device] = worker_id\n                worker_id_to_devices[worker_id].append(device)\n\n        abstract_works = []\n        for work in self.works:\n            sender_ids = set()\n            for sender in work.senders:\n                sender_ids.add(device_to_worker_id_map[sender])\n            sender_ids = list(sender_ids)\n            sender_ids.sort()\n            receiver_ids = set()\n            for receiver in work.receivers:\n                receiver_ids.add(device_to_worker_id_map[receiver])\n            receiver_ids = list(receiver_ids)\n            receiver_ids.sort()\n            time_spent = work.data_size\n\n            abstract_works.append(\n                SingleAbstractedLoadBalancingWork(sender_ids, receiver_ids,\n                                                  time_spent))\n\n        if global_config.resharding_loadbalance_mode == \"loadbalance_size\":\n            task = LoadBalancingOverSizeTaskSolver(n_workers, abstract_works)\n        else:\n            if global_config.loadbalance_order_algo == \"search\":\n                task = LoadBalancingTaskSolverSearchAlgo(\n                    n_workers, abstract_works)\n            else:\n                task = LoadBalancingTaskSolverGreedyAlgo(\n                    n_workers, abstract_works)\n\n        sol_assigned_sender_id, sol_order = task.solve()\n\n        sol_assigned_sender = []\n        for work, worker_id in zip(self.works, sol_assigned_sender_id):\n            selected_sender = None\n            for sender in work.senders:\n                if device_to_worker_id_map[sender] == worker_id:\n                    selected_sender = sender\n                    break\n            assert selected_sender is not None\n            sol_assigned_sender.append(selected_sender)\n        return sol_assigned_sender, sol_order\n\n    def print_task(self):\n        print(\"\\nTask[START]\")\n        
print(f\"src_mesh_devices: {self.src_mesh_devices}\")\n        print(f\"dst_mesh_devices: {self.dst_mesh_devices}\")\n        print(f\"device_host_map: {self.device_host_map}\")\n        print(\"works:\")\n        for work in self.works:\n            print(work)\n        print(\"nic_contraints:\")\n        for contraint in self.nic_contraints:\n            print(contraint)\n        print(\"Task[END]\\n\")\n\n\nclass AbstractedLoadBalancingTaskSolver(ABC):\n    \"\"\"This is class of solver for abstracted load balancing problem\"\"\"\n\n    def __init__(self, n_workers, works):\n        \"\"\"We abstract the load balancing problem into this mathematically\n        clear form.\n\n        Args:\n            n_workers (int): The total number of single threaded\n                workers in this loadbalancing task.\n            works (List[SingleAbstractedLoadBalancingWork]): all works to\n                be scheduled in this task.\n        \"\"\"\n        self.n_workers = n_workers\n        self.n_works = len(works)\n        self.works = works\n        self.loads = [0 for _ in range(n_workers)]\n\n        # self.print_task()\n\n    @abstractmethod\n    def solve(self):\n        \"\"\"\n            Return two list[int] of length n_works\n            1. The first represents which sender to choose for each work.\n            2. 
The second represents the order to execute these works.\n        \"\"\"\n        raise NotImplementedError\n\n    def print_task(self):\n        print(\"AbstractedTask[START]\")\n        print(f\"n_workers: {self.n_workers}\")\n        print(\"works:\")\n        for work in self.works:\n            print(work)\n        print(\"AbstractedTask[END]\")\n\n\nclass LoadBalancingTaskSolverGreedyAlgo(AbstractedLoadBalancingTaskSolver):\n    \"\"\"Implementation of load balance: use randomized greedy algorithm\"\"\"\n\n    def find_one_random_concurrent_set_of_works(self, works_ids):\n        \"\"\"This method finds one set of works that could be run\n           concurrently.\n\n        Args:\n            works_ids (List[int]): The ids of works that could be\n                selected.\n\n        Returns:\n            one_concurrent_works_ids (list[int]): The ids of works\n                selected in this method.\n            one_concurrent_selected_senders (list[int]): The assigned\n                senders for the selected works.\n        \"\"\"\n\n        def probability_of_being_selected(loads):\n            # these weights could be more carefully tuned.\n            max_weight = max(loads)\n            weights = [max_weight - weight + 1 for weight in loads]\n            return weights\n\n        used = [False for _ in range(self.n_workers)]\n        perm = np.random.permutation(np.array(works_ids))\n        one_concurrent_works_ids = []\n        one_concurrent_selected_senders = []\n        for i in perm:\n            work = self.works[i]\n            receivers_availability = True\n            for receiver in work.receiver_ids:\n                if used[receiver]:\n                    receivers_availability = False\n                    break\n            if not receivers_availability:\n                continue\n\n            available_senders = []\n            for sender in work.sender_ids:\n                if not used[sender]:\n                    
available_senders.append(sender)\n            if not available_senders:\n                continue\n\n            weights = probability_of_being_selected(\n                [self.loads[sender] for sender in available_senders])\n            selected_sender = random.choices(available_senders,\n                                             weights=weights)[0]\n\n            used[selected_sender] = True\n            for receiver in work.receiver_ids:\n                used[receiver] = True\n\n            one_concurrent_works_ids.append(i)\n            one_concurrent_selected_senders.append(selected_sender)\n        return one_concurrent_works_ids, one_concurrent_selected_senders\n\n    def find_best_concurrent_set_of_works(self, works_ids, n_rounds=100):\n        \"\"\"\n            One simple strategy is that everytime we choose the maximum number\n            of works and minimize std and put them into the sequence.\n            The simple logic behind is to maximize concurrency.\n\n        Args:\n            works_ids (List[int]): All available works waiting for running.\n            n_rounds (int, optional): The number of rounds to run for finding\n                the best set of works. 
Defaults to 100.\n        \"\"\"\n\n        def calc_std(data):\n            ave = sum(data) / len(data)\n            std = (sum((x - ave)**2 for x in data) / len(data))**0.5\n            return std\n\n        # def calc_max(A):\n        #     return max(A)\n\n        max_num = 0\n        min_std = None\n        best_concurrent_works_ids = []\n        best_concurrent_selected_senders = []\n        for _ in range(n_rounds):\n            one_concurrent_works_ids, one_concurrent_selected_senders = \\\n                self.find_one_random_concurrent_set_of_works(works_ids)\n            num = len(one_concurrent_works_ids)\n            if num < max_num:\n                continue\n\n            loads = list(self.loads)\n            for work_id, selected_sender in zip(\n                    one_concurrent_works_ids, one_concurrent_selected_senders):\n                loads[selected_sender] += self.works[work_id].duration\n\n            # here we could use different criterions\n            std = calc_std(loads)  # calc_max(loads)\n            # std = calc_std(\n            # [self.works[i].duration for i in range(one_concurrent_works_ids)]\n            # )\n\n            if num > max_num or (num == max_num and std < min_std):\n                max_num = num\n                min_std = std\n                best_concurrent_works_ids = one_concurrent_works_ids\n                best_concurrent_selected_senders = (\n                    one_concurrent_selected_senders)\n        return best_concurrent_works_ids, best_concurrent_selected_senders\n\n    def solve(self):\n        sol_assigned_sender_id = [None for _ in range(len(self.works))]\n        sol_order = []\n        while True:\n            available_works_ids = [\n                i for i in range(len(self.works)) if i not in sol_order\n            ]\n            best_concurrent_works_ids, best_concurrent_selected_senders = \\\n                self.find_best_concurrent_set_of_works(available_works_ids)\n\n            for 
work_id, sender_id in zip(best_concurrent_works_ids,\n                                          best_concurrent_selected_senders):\n                sol_order.append(work_id)\n                sol_assigned_sender_id[work_id] = sender_id\n                self.loads[sender_id] += self.works[work_id].duration\n\n            if len(sol_order) == len(self.works):\n                break\n\n        assert None not in sol_assigned_sender_id\n\n        return sol_assigned_sender_id, sol_order\n\n\nclass LoadBalancingTaskSolverSearchAlgo(AbstractedLoadBalancingTaskSolver):\n    \"\"\"Implementation of load balance: use search algorithm with pruning\"\"\"\n\n    def __init__(self, n_workers, works):\n        super().__init__(n_workers, works)\n\n        self.sol_assigned_sender_id = [None for _ in range(len(self.works))]\n        self.sol_order = []\n        self.minimal_finish_time = None\n\n        self.cur_assigned_sender_id = [None for _ in range(len(self.works))]\n        self.cur_order = []\n\n        self.start_time = time.time()\n        self.search_time_threshold = 1\n\n    def evaluate_one_solution(self, assigned_sender_id, order):\n        \"\"\"Given current task-sender assigment and order to submit\n           these tasks, this method return the finishing time of each\n           receiver for the current schedule as solution.\n           To get the finishing time, this method just simulates the\n           whole process.\n\n        Args:\n            assigned_sender_id: This variable contains idx of sender\n                for each task.\n            order: The order to submit different tasks.\n\n        Returns:\n            current_time (list[int]): the time for each receiver\n                after finishing all the tasks assigned to it.\n        \"\"\"\n        current_time = [0 for _ in range(self.n_workers)]\n\n        for i in order:\n            work = self.works[i]\n            sender_id = assigned_sender_id[i]\n            mx_time = 
max([current_time[sender_id]] + [\n                current_time[receiver_id] for receiver_id in work.receiver_ids\n            ])\n            current_time[sender_id] = mx_time + work.duration\n            for receiver_id in work.receiver_ids:\n                current_time[receiver_id] = mx_time + work.duration\n        return current_time\n\n    def heuristic(self, current_time, remained_work_ids):\n        \"\"\" Given the current time for each receiver to finish\n            its assigned works, and the remained work to be\n            assigned, this method estimate the minimal amount\n            of time to finish all works. If the minimal amount\n            of time to finish all works is still longer than\n            current best solution, then we could prune the current\n            search branch.\n\n        Args:\n            current_time (list[int]): the time for each receiver\n                after finishing all the tasks assigned to it.\n            remained_work_ids (list[int]): The ids of works remained\n                to be assigned to workers.\n\n        Returns:\n            int: the minimal amount of time to finish all works\n                with current assignment and order schedule.\n        \"\"\"\n        remained_time_lowerbound = [0 for _ in range(self.n_workers)]\n        for i in remained_work_ids:\n            work = self.works[i]\n            sender_id_with_mintime = -1\n            for sender_id in work.sender_ids:\n                if sender_id_with_mintime == -1:\n                    sender_id_with_mintime = sender_id\n                elif (remained_time_lowerbound[sender_id] +\n                      current_time[sender_id] <\n                      remained_time_lowerbound[sender_id_with_mintime] +\n                      current_time[sender_id_with_mintime]):\n                    sender_id_with_mintime = sender_id\n            # heuristic function could be continuely improved.\n            
remained_time_lowerbound[sender_id_with_mintime] += work.duration\n            for receiver_id in work.receiver_ids:\n                remained_time_lowerbound[receiver_id] += work.duration\n\n        max_time = max(\n            x + y for x, y in zip(remained_time_lowerbound, current_time))\n        return max_time\n\n    def dfs(self, depth):\n        \"\"\"This is the Depth First Search function\n           to search the order of submitting works\n           and sender for each work.\n\n        Args:\n            depth (int): The depth of the DFS; In other\n            words, we are deciding the depth_th task in\n            order array.\n        \"\"\"\n        if time.time() - self.start_time > self.search_time_threshold:\n            return\n\n        current_time = self.evaluate_one_solution(self.cur_assigned_sender_id,\n                                                  self.cur_order)\n\n        if depth == len(self.works):\n            finish_time = max(current_time)\n            if (self.minimal_finish_time is None or\n                    finish_time < self.minimal_finish_time):\n                self.minimal_finish_time = finish_time\n                self.sol_assigned_sender_id = list(self.cur_assigned_sender_id)\n                self.sol_order = list(self.cur_order)\n            return\n\n        remained_work_ids = [\n            i for i in range(len(self.works)) if i not in self.cur_order\n        ]\n\n        heuristic = self.heuristic(current_time, remained_work_ids)\n        if (self.minimal_finish_time is not None and\n                heuristic > self.minimal_finish_time):\n            return\n\n        for i in remained_work_ids:\n            self.cur_order.append(i)\n            work = self.works[i]\n            for sender_id in work.sender_ids:\n                self.cur_assigned_sender_id[i] = sender_id\n                self.dfs(depth + 1)\n            self.cur_assigned_sender_id[i] = None\n            self.cur_order.pop()\n\n    def 
solve(self):\n\n        self.dfs(depth=0)\n\n        assert None not in self.sol_assigned_sender_id\n\n        return self.sol_assigned_sender_id, self.sol_order\n\n\nclass LoadBalancingOverSizeTaskSolver(AbstractedLoadBalancingTaskSolver):\n    \"\"\"Implementation of load balance: only consider workers' workloads\"\"\"\n\n    def __init__(self, n_workers, works):\n        super().__init__(n_workers, works)\n\n        self.sol_assigned_sender_id = [None for _ in range(len(self.works))]\n        self.sol_order = []\n\n    def solve(self):\n        for i, work in enumerate(self.works):\n            loads = {sender: self.loads[sender] for sender in work.sender_ids}\n            sender = min(loads, key=loads.get)\n            self.sol_assigned_sender_id[i] = sender\n            self.loads[sender] += work.duration\n            self.sol_order.append(i)\n\n        assert None not in self.sol_assigned_sender_id\n\n        return self.sol_assigned_sender_id, self.sol_order\n"
  },
  {
    "path": "alpa/pipeline_parallel/layer_construction.py",
    "content": "\"\"\"Group small ops into layers and rematerialize at layer boundary.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom functools import partial, wraps\nimport logging\nfrom typing import Callable, Iterable, Optional, Sequence, Union\n\nimport numpy as np\nfrom jax import lax\nfrom jax.tree_util import tree_flatten, tree_unflatten\nfrom jax._src.api import _check_callable, make_jaxpr\nfrom jax._src.ad_checkpoint import remat_p\nfrom jax.core import (Var, Jaxpr, ClosedJaxpr, DropVar, Literal, jaxpr_as_fun,\n                      gensym)\n\nfrom alpa.global_env import global_config\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.pipeline_parallel.layer_stats import (global_invar_size,\n                                                is_nontrivial, eqn_flops,\n                                                heavy_count,\n                                                log_layer_slicing_stats)\nfrom alpa.pipeline_parallel.primitive_def import (pipeline_p,\n                                                  mark_pipeline_jaxpreqn)\nfrom alpa.util import (clone_jaxpr, clone_jaxpr_eqn, slices_to_jaxpr,\n                       OrderedSet, get_var_mapping, maybe_numba_jit,\n                       new_jaxpr_eqn)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.DEBUG)\n\nLAYER_HEAVY_OP_LOWER_BOUND = 3\nDEFAULT_EPS = 0.5\nDEFAULT_COST_CRITERIA = \"flops\"\n\n\nclass LayerOption(ABC):\n    \"\"\"Options of grouping operators into layers.\"\"\"\n\n    def __init__(self):\n        pass\n\n    @abstractmethod\n    def transform(self, func):\n        raise NotImplementedError()\n\n\nclass ManualLayerOption(LayerOption):\n    \"\"\"\n    Manually specifying the boundaries of layers by using\n    alpa.mark_pipeline_boundary()\n\n    Args:\n      remat_layer: Whether to use gradient rematerialization for each layer.\n      static_argnums: The indices of static arguments of the\n        forward function.\n    \"\"\"\n\n    def __init__(self,\n    
             remat_layer: bool = False,\n                 static_argnums: Sequence[int] = ()):\n        self.remat_layer = remat_layer\n        self.static_argnums = static_argnums\n        super().__init__()\n\n    def transform(self, func):\n        return manual_layer_construction(func,\n                                         static_argnums=self.static_argnums,\n                                         remat_layer=self.remat_layer)\n\n\nclass AutoLayerOption(LayerOption):\n    \"\"\"\n    Use an algorithm to automatically group operators into\n    layers. The parameter `layer_num` specifies the number of\n    resulting layers. You can try a few values for this parameters.\n    The best choice of this value depends on the number of nodes in your\n    cluster and the number of repetitive blocks in your model.\n\n    Args:\n      layer_num: The number of layers to construct.\n      remat_mode: Whether to use automatic tensor rematerialization.\n        Possible choices:\n        {\"none\", \"fine_grained_remat\", \"coarse_grained_remat\"}.\n      fine_grained_remat_layer_num:\n        Only used for remat_mode == \"fine_grained_remat\".\n        The number of layers for auto_remat.\n      static_argnums: The indices of static arguments of the\n        forward function.\n      eps: The tolerance of inbalance of the costs of different layers.\n    \"\"\"\n\n    def __init__(self,\n                 layer_num: int,\n                 remat_mode: str = \"none\",\n                 fine_grained_remat_layer_num: Optional[int] = None,\n                 static_argnums: Sequence[int] = (),\n                 eps: float = DEFAULT_EPS):\n        super().__init__()\n        self.layer_num = layer_num\n        self.remat_mode = remat_mode\n        self.fine_grained_remat_layer_num = fine_grained_remat_layer_num\n        self.static_argnums = static_argnums\n        self.eps = eps\n\n    def transform(self, func):\n        if self.remat_mode == \"fine_grained_remat\":\n            
func = automatic_remat(func,\n                                   layer_num=self.fine_grained_remat_layer_num)\n            use_remat = False\n        elif self.remat_mode == \"coarse_grained_remat\":\n            use_remat = True\n        else:\n            use_remat = False\n\n        return automatic_layer_construction(func,\n                                            static_argnums=self.static_argnums,\n                                            layer_num=self.layer_num,\n                                            remat_layer=use_remat,\n                                            eps=self.eps)\n\n\nclass FollowLayerOption(LayerOption):\n    \"\"\"Follow given input placement specs to construct the layer.\n\n    Args:\n      input_placement_specs: The flatten placement specs of inputs.\n      static_argnums: The indices of static arguments of the\n        forward function.\n    \"\"\"\n\n    def __init__(self,\n                 input_placement_specs: Sequence[PlacementSpec],\n                 num_meshes: int,\n                 static_argnums: Sequence[int] = ()):\n        super().__init__()\n        self.placement_specs = input_placement_specs\n        self.num_meshes = num_meshes\n        self.static_argnums = static_argnums\n\n    def transform(self, func):\n        return follow_layer_construction(func, self.static_argnums,\n                                         self.placement_specs, self.num_meshes)\n\n\ndef slice_eqns_by_layer_boundary(closed_jaxpr: ClosedJaxpr):\n    \"\"\"Slices eqns by layer boundary markers.\"\"\"\n    sliced_eqns = []\n    current_computation_eqns = []\n\n    for eqn in closed_jaxpr.jaxpr.eqns:\n        if (eqn.primitive is pipeline_p and\n                eqn.params[\"mark_type\"] == \"boundary\"):\n            sliced_eqns.append(current_computation_eqns)\n            current_computation_eqns = []\n        else:\n            current_computation_eqns.append(eqn)\n    sliced_eqns.append(current_computation_eqns)\n    return 
sliced_eqns\n\n\ndef add_pipeline_marks_for_sliced_eqns(closed_jaxpr: ClosedJaxpr, sliced_eqns):\n    \"\"\"Adds pipeline marks for sliced equations.\"\"\"\n    layer_num = len(sliced_eqns)\n    layer_pipeline_invars = [OrderedSet() for _ in range(layer_num)]\n    layer_pipeline_outvars = [OrderedSet() for _ in range(layer_num)]\n    var_layer_dict = {}\n    var_mapping = {}\n\n    # build mapping dicts for global invars\n    for var in closed_jaxpr.jaxpr.invars:\n        var_layer_dict[var] = -1\n\n    # build mapping dicts for all eqns\n    for i, eqns in enumerate(sliced_eqns):\n        for eqn in eqns:\n            for var in eqn.invars:\n                if (not isinstance(var, Literal) and\n                        var not in closed_jaxpr.jaxpr.constvars and\n                        var_layer_dict[var] != i):\n                    layer_pipeline_invars[i].add(var)\n                    if var_layer_dict[var] == -1:\n                        continue\n                    layer_pipeline_outvars[var_layer_dict[var]].add(var)\n            for var in eqn.outvars:\n                if not isinstance(var, DropVar):\n                    var_layer_dict[var] = i\n\n    # build mapping dict for global outvars\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    literal_outvar_eqns = []\n    literal_outvar_marker_invars = []\n    literal_outvar_marker_outvars = []\n    for idx, var in enumerate(closed_jaxpr.jaxpr.outvars):\n        if isinstance(var, Literal):\n            # add a dummy equation to transform a Literal into a normal Var\n            if isinstance(var.val, np.ndarray):\n                val = np.zeros_like(var.val)\n            elif isinstance(var.val, Iterable):\n                raise NotImplementedError()\n            else:\n                val = type(var.val)(0)\n            zero_literal = Literal(val, var.aval)\n            new_var = gensym_func(var.aval)\n            new_eqn = new_jaxpr_eqn([var, zero_literal], [new_var], lax.add_p,\n                         
           {})\n            literal_outvar_eqns.append(new_eqn)\n            literal_outvar_marker_invars.append(new_var)\n            literal_outvar_marker_outvars.append(gensym_func(var.aval))\n            var_mapping[idx] = literal_outvar_marker_outvars[-1]\n        elif var in closed_jaxpr.jaxpr.constvars or var_layer_dict[var] == -1:\n            raise NotImplementedError(\n                \"Does not support this use case of output var.\")\n        else:\n            layer_pipeline_outvars[var_layer_dict[var]].add(var)\n\n    # build new equations\n    new_eqns = []\n    for i, eqns in enumerate(sliced_eqns):\n        # pipeline start eqn\n        computation_var_mapping = {}\n\n        pipeline_start_invars = []\n        pipeline_start_outvars = []\n        for var in layer_pipeline_invars[i]:\n            new_var = gensym_func(var.aval)\n            pipeline_start_invars.append(get_var_mapping(var_mapping, var))\n            pipeline_start_outvars.append(new_var)\n            computation_var_mapping[var] = new_var\n        new_eqns.append(\n            mark_pipeline_jaxpreqn(pipeline_start_invars,\n                                   pipeline_start_outvars, f\"layer_{i}\",\n                                   \"start\"))\n        # all other eqns\n        for eqn in (eqns + literal_outvar_eqns if i == 0 else eqns):\n            new_invars = [\n                get_var_mapping(computation_var_mapping, var)\n                for var in eqn.invars\n            ]\n            new_eqns.append(clone_jaxpr_eqn(eqn, new_invars))\n\n        # pipeline end eqn\n        pipeline_end_invars = list(\n            literal_outvar_marker_invars) if i == 0 else []\n        pipeline_end_outvars = list(\n            literal_outvar_marker_outvars) if i == 0 else []\n        for var in layer_pipeline_outvars[i]:\n            new_var = gensym_func(var.aval)\n            pipeline_end_invars.append(\n                get_var_mapping(computation_var_mapping, var))\n            
pipeline_end_outvars.append(new_var)\n            var_mapping[var] = new_var\n        new_eqns.append(\n            mark_pipeline_jaxpreqn(pipeline_end_invars, pipeline_end_outvars,\n                                   f\"layer_{i}\", \"end\"))\n\n    new_outvars = []\n    for idx, var in enumerate(closed_jaxpr.jaxpr.outvars):\n        if isinstance(var, Literal):\n            new_outvars.append(var_mapping[idx])\n        else:\n            new_outvars.append(get_var_mapping(var_mapping, var))\n\n    new_closed_jaxpr = clone_jaxpr(closed_jaxpr,\n                                   outvars=new_outvars,\n                                   eqns=new_eqns)\n    return new_closed_jaxpr\n\n\ndef remat_sliced_eqns(origin_jaxpr, sliced_eqns):\n    \"\"\"Add tensor rematerialization for sliced equations.\"\"\"\n    ret_eqns = []\n\n    sliced_jaxprs = slices_to_jaxpr(origin_jaxpr, sliced_eqns)\n    for jaxpr in sliced_jaxprs:\n        new_invars = jaxpr.jaxpr.invars + jaxpr.jaxpr.constvars\n        new_jaxpr = Jaxpr([], new_invars, jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)\n        ret_eqns.append([\n            new_jaxpr_eqn(\n                new_invars, new_jaxpr.outvars, remat_p,\n                dict(jaxpr=new_jaxpr,\n                     prevent_cse=True,\n                     differentiated=False,\n                     policy=None))\n        ])\n    return ret_eqns\n\n\ndef jaxpr_eqns_input_sizes(jaxpr) -> np.ndarray:\n    \"\"\"Return a list of input sizes for each equation in the jaxpr.\n\n    Args:\n        jaxpr: Jaxpr to get input sizes for.\n\n    Returns:\n        A #eqns * #eqns numpy array of input sizes. 
cost[l, r] represents the\n        input size of the l-th to (r - 1)-th equation in the jaxpr.\n    \"\"\"\n    length = len(jaxpr.eqns)\n    input_sizes = np.full((length + 1, length + 1), 0, dtype=np.float32)\n\n    outvars = OrderedSet()\n    for k in range(0, length + 1):\n        if k > 0:\n            outvars = outvars.union(jaxpr.eqns[k - 1].outvars)\n        invars = OrderedSet()\n        total_size = 0\n        for r in range(k + 1, length + 1):\n            for invar in jaxpr.eqns[r - 1].invars:\n                if (isinstance(invar, Var) and invar in outvars and\n                        invar not in invars):\n                    invars.add(invar)\n                    total_size += invar.aval.size * invar.aval.dtype.itemsize\n            input_sizes[k, r] = total_size\n    return input_sizes\n\n\ndef get_layer_construction_costs(jaxpr, cost_criteria=\"flops\"):\n    \"\"\"Gets the layer construction cost.\"\"\"\n    nontrivial = np.array([is_nontrivial(eqn) for eqn in jaxpr.eqns],\n                          dtype=np.int32)\n    input_sizes = jaxpr_eqns_input_sizes(jaxpr)\n    if cost_criteria == \"flops\":\n        compute_costs = np.array([\n            eqn_flops(eqn) if nt else 0\n            for nt, eqn in zip(nontrivial, jaxpr.eqns)\n        ],\n                                 dtype=np.float64)\n    elif cost_criteria == \"count\":\n        compute_costs = np.array([\n            heavy_count(eqn) if nt else 0\n            for nt, eqn in zip(nontrivial, jaxpr.eqns)\n        ],\n                                 dtype=np.float64)\n    elif cost_criteria == \"input_memory\":\n        cost_fn = partial(global_invar_size, set(jaxpr.jaxpr.invars))\n        compute_costs = np.array([cost_fn(eqn) for eqn in jaxpr.eqns],\n                                 dtype=np.float64)\n    else:\n        raise ValueError(f\"Unrecoginzed cost criteria {cost_criteria}\")\n    return nontrivial, input_sizes, compute_costs\n\n\ndef cluster_jaxpr_by_cost(jaxpr: Jaxpr, 
layer_num: int, eps: float, costs,\n                          cost_criteria):\n    \"\"\"Clusters the jaxpr by cost.\"\"\"\n    layer_num = int(layer_num)\n    length = len(jaxpr.eqns)\n    non_trivial, input_sizes, compute_costs = costs\n    compute_costs_avg = compute_costs.sum() / layer_num\n    if cost_criteria in (\"flops\", \"input_memory\"):\n        compute_costs_bound = compute_costs_avg * (1 + eps)\n    elif cost_criteria == \"count\":\n        compute_costs_bound = max(compute_costs_avg * (1 + eps),\n                                  compute_costs_avg + 5)\n    else:\n        raise ValueError(f\"Unrecoginzed cost criteria {cost_criteria}\")\n    layer_heavy_op_lower_bound = LAYER_HEAVY_OP_LOWER_BOUND\n    if sum(non_trivial) / layer_num < layer_heavy_op_lower_bound:\n        layer_heavy_op_lower_bound = int(sum(non_trivial) / layer_num)  # noqa\n        logger.warning(\n            \"Too few non-trivial ops (dot, conv), which may influence\"\n            \" auto-sharding performance\")\n\n    @maybe_numba_jit\n    def init():\n        blocked = np.full((length + 1, length + 1), np.inf, dtype=np.float32)\n        for left in range(1, length + 1):\n            cnt = 0\n            total_compute_cost = 0\n            for r in range(left, length + 1):\n                if non_trivial[r - 1]:\n                    cnt += 1\n                    total_compute_cost += compute_costs[r - 1]\n                if cnt < layer_heavy_op_lower_bound:\n                    if total_compute_cost >= compute_costs_bound:\n                        blocked[left, r] = 0\n                    continue\n                if (total_compute_cost >= compute_costs_bound and\n                        non_trivial[r - 1] and\n                        cnt > layer_heavy_op_lower_bound):\n                    break\n                blocked[left, r] = 0\n        return blocked\n\n    @maybe_numba_jit\n    def dp(input_sizes, blocked):\n        max_cost = np.full((length + 1, layer_num + 1),\n         
                  np.inf,\n                           dtype=np.float32)\n        sum_cost_under_max = np.full((length + 1, layer_num + 1),\n                                     np.inf,\n                                     dtype=np.float32)\n        max_cost_argmin = np.full((length + 1, layer_num + 1),\n                                  -1,\n                                  dtype=np.int32)\n        solution_imbalance = np.full((length + 1, layer_num + 1),\n                                     np.inf,\n                                     dtype=np.float32)\n        max_cost[0, 0] = 0\n        sum_cost_under_max[0, 0] = 0\n        # Currently use variance to measure imbalance\n        for r in range(0, length + 1):\n            solution_imbalance[r, 0] = 0\n\n        for q in range(1, layer_num + 1):\n            for r in range(1, length + 1):\n                for k in range(0, r):\n                    new_value = max(max_cost[k, q - 1],\n                                    blocked[k + 1, r] + input_sizes[k, r])\n                    new_sum = (sum_cost_under_max[k, q - 1] +\n                               blocked[k + 1, r] + input_sizes[k, r])\n                    new_imbalance = (solution_imbalance[k, q - 1] + k**2 / q -\n                                     r**2 / (q + 1) + (r - k)**2)\n                    if (new_value < max_cost[r, q] or\n                        (new_value <= max_cost[r, q] * (1 + 1e-4) and\n                         (new_sum < sum_cost_under_max[r, q] or\n                          (new_sum <= sum_cost_under_max[r, q] * (1 + 1e-4) and\n                           new_imbalance < solution_imbalance[r, q])))):\n                        max_cost[r, q] = new_value\n                        sum_cost_under_max[r, q] = new_sum\n                        max_cost_argmin[r, q] = k\n                        solution_imbalance[r, q] = new_imbalance\n        return max_cost_argmin, max_cost[length, layer_num]\n\n    blocked = init()\n    a_argmin, value = 
dp(input_sizes, blocked)\n\n    reversed_sliced_eqns = []\n\n    r = length\n    for q in range(layer_num, 0, -1):\n        k = a_argmin[r, q]\n        reversed_sliced_eqns.append(jaxpr.eqns[k:r])\n        r = k\n    assert r == 0, \"No solution for layer construction.\"\n    solution = list(reversed(reversed_sliced_eqns))\n\n    # print(\"dp solution\")\n    # for i, eqns in enumerate(solution):\n    #    invars = OrderedSet()\n    #    for eqn in eqns:\n    #        invars.update([var for var in eqn.invars if isinstance(var, Var)])\n    #    invars.intersection_update(jaxpr.jaxpr.invars)\n    #    print(f\"mesh: {i},  set_shapes: \"\n    #          f\"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}\")\n    #\n    #    invars = []\n    #    for eqn in eqns:\n    #        tmp_set = set([var for var in eqn.invars if isinstance(var, Var)])\n    #        tmp_set.intersection_update(jaxpr.jaxpr.invars)\n    #        invars.extend(list(tmp_set))\n    #    print(f\"mesh: {i}, list_shapes: \"\n    #          f\"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}\")\n\n    solution_info = {\n        \"total_cost\": value,\n    }\n    return solution, solution_info\n\n\ndef search_layer_num(jaxpr,\n                     eps,\n                     layer_eps=0,\n                     cost_criteria=DEFAULT_COST_CRITERIA):\n    \"\"\"TODO(zhuohan): docstring.\"\"\"\n    non_trivial, input_sizes, compute_costs = get_layer_construction_costs(\n        jaxpr)\n    layer_num = 2\n    r = int(non_trivial.sum() / 3) + 1\n    _, solution_info = cluster_jaxpr_by_cost(\n        jaxpr,\n        layer_num,\n        eps, (non_trivial, input_sizes, compute_costs),\n        cost_criteria=cost_criteria)\n    l_val = solution_info[\"total_cost\"]\n    while r - layer_num > 1:\n        mid = int((layer_num + r) / 2)\n        _, solution_info = cluster_jaxpr_by_cost(\n            jaxpr,\n            mid,\n            eps, (non_trivial, input_sizes, compute_costs),\n            
cost_criteria=cost_criteria)\n        mid_val = solution_info[\"total_cost\"]\n        if mid_val > l_val * (1 + layer_eps):\n            r = mid\n        else:\n            layer_num = mid\n    return layer_num\n\n\ndef layer_level_jaxpr_transformation(fn: Callable,\n                                     static_argnums: Sequence[int] = (),\n                                     remat: bool = False,\n                                     layer_construction: bool = False,\n                                     auto_layer_boundary: bool = False,\n                                     layer_num: Union[int, str] = None,\n                                     eps: float = DEFAULT_EPS,\n                                     cost_criteria: str = DEFAULT_COST_CRITERIA,\n                                     layer_eps: float = 0.0):\n    \"\"\"TODO(zhuohan): docstring.\"\"\"\n    if not remat and not layer_construction:\n        return fn\n\n    @wraps(fn)\n    def wrapped(*args):\n        jaxpr, out_shape_tree = make_jaxpr(fn,\n                                           static_argnums=static_argnums,\n                                           return_shape=True)(*args)\n        if auto_layer_boundary:\n            nonlocal layer_num\n            if layer_num == \"auto\":\n                layer_num = search_layer_num(jaxpr, eps, layer_eps)\n            costs = get_layer_construction_costs(jaxpr,\n                                                 cost_criteria=cost_criteria)\n            sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr,\n                                                   layer_num,\n                                                   eps,\n                                                   costs,\n                                                   cost_criteria=cost_criteria)\n        else:\n            sliced_eqns = slice_eqns_by_layer_boundary(jaxpr)\n\n        if global_config.print_auto_layer_stats:\n            log_layer_slicing_stats(jaxpr, sliced_eqns)\n\n        
if remat:\n            sliced_eqns = remat_sliced_eqns(jaxpr, sliced_eqns)\n\n        if layer_construction:\n            jaxpr = add_pipeline_marks_for_sliced_eqns(jaxpr, sliced_eqns)\n        else:\n            jaxpr = clone_jaxpr(jaxpr,\n                                eqns=[x for eqns in sliced_eqns for x in eqns])\n\n        flatten_args, _ = tree_flatten(args)\n        ans = jaxpr_as_fun(jaxpr)(*flatten_args)  # pylint: disable=not-callable\n        _, out_tree = tree_flatten(out_shape_tree)\n        return tree_unflatten(out_tree, ans)\n\n    return wrapped\n\n\ndef manual_remat(fun: Callable = None, *, static_argnums: Sequence[int] = ()):\n    \"\"\"Rematerialize an input function with manually selected layer boundaries.\n\n    Rematerialize each layer of an input function with manually selected layer\n    boundaries indicated by pipeline markers.\n\n    Args:\n        fun: the input function to rematerialize.\n        static_argnums: An optional int or collection of ints that specify\n          which positional arguments to treat as static (compile-time constant).\n          Same as in jax.jit\n    Returns:\n        A new function rematerializes each layer of the input function.\n    \"\"\"\n\n    def decorate_fun(fun):\n        return layer_level_jaxpr_transformation(fun,\n                                                static_argnums,\n                                                remat=True,\n                                                layer_construction=False,\n                                                auto_layer_boundary=False)\n\n    if fun is None:\n        return decorate_fun\n    else:\n        _check_callable(fun)\n        return decorate_fun(fun)\n\n\ndef automatic_remat(fun: Callable = None,\n                    *,\n                    static_argnums: Sequence[int] = (),\n                    layer_num: Union[int, str] = None,\n                    eps: float = DEFAULT_EPS,\n                    cost_criteria: str = 
DEFAULT_COST_CRITERIA,\n                    layer_eps: float = 0.0):\n    \"\"\"Rematerialize an input function with automatic boundaries.\n\n    Rematerialize each layer of an input function with automatically decided\n    layer boundaries.\n\n    Args:\n        fun: The input function to rematerialize.\n        static_argnums: An optional int or collection of ints that specify\n          which positional arguments to treat as static (compile-time constant).\n          Same as in jax.jit\n        layer_num: The number of layers to rematerialize. If set to \"auto\", the\n          number of layers will be automatically determined by a binary search.\n          The binary search might not work for complex input functions.\n        eps: The tolerance of inbalance of the costs of different layers.\n        cost_criteria: The cost criteria to use for deciding the layers.\n        layer_eps: A parameter for layer_num binary search.\n\n    Returns:\n        A new function rematerializes each layer of the input function.\n    \"\"\"\n\n    def decorate_fun(fun):\n        return layer_level_jaxpr_transformation(fun,\n                                                static_argnums,\n                                                remat=True,\n                                                layer_construction=False,\n                                                auto_layer_boundary=True,\n                                                layer_num=layer_num,\n                                                eps=eps,\n                                                cost_criteria=cost_criteria,\n                                                layer_eps=layer_eps)\n\n    if fun is None:\n        return decorate_fun\n    else:\n        _check_callable(fun)\n        return decorate_fun(fun)\n\n\ndef manual_layer_construction(fun: Callable = None,\n                              *,\n                              static_argnums: Sequence[int] = (),\n                              
remat_layer: bool = False):\n    \"\"\"Setup manually selected layer boundaries.\n\n    Add input variables of each layer to its start pipeline marker and output\n    variables of each layer to its end pipeline marker.\n\n    Args:\n        fun: the input function.\n        static_argnums: An optional int or collection of ints that specify\n          which positional arguments to treat as static (compile-time constant).\n          Same as in jax.jit\n        remat_layer: Whether to rematerialize each layer at layer boundaries.\n    Returns:\n        A new function with correctly setup pipeline markers.\n    \"\"\"\n\n    def decorate_fun(fun):\n        return layer_level_jaxpr_transformation(fun,\n                                                static_argnums,\n                                                remat=remat_layer,\n                                                layer_construction=True,\n                                                auto_layer_boundary=False)\n\n    if fun is None:\n        return decorate_fun\n    else:\n        _check_callable(fun)\n        return decorate_fun(fun)\n\n\ndef automatic_layer_construction(fun: Callable = None,\n                                 *,\n                                 static_argnums: Sequence[int] = (),\n                                 layer_num: int = None,\n                                 remat_layer: bool = False,\n                                 eps: float = DEFAULT_EPS,\n                                 cost_criteria: str = DEFAULT_COST_CRITERIA,\n                                 layer_eps: float = 0.0):\n    \"\"\"Automatically cluster the equations in a jaxpr into layers.\n    Automatically cluster the equations in a jaxpr into layers and add pipeline\n    markers at layer boundaries.\n    Args:\n        fun: the input function.\n        static_argnums: An optional int or collection of ints that specify\n          which positional arguments to treat as static (compile-time constant).\n          
Same as in jax.jit\n        layer_num: the number of layers to rematerialize. If set to \"auto\", the\n          number of layers will be automatically determined by a binary search.\n          The binary search might not work for complex input functions.\n        remat_layer: Whether to rematerialize each layer at layer boundaries.\n        eps: the tolerance of inbalance of the costs of different layers.\n        cost_criteria: the cost criteria to use for deciding the layers.\n        layer_eps: a parameter for layer_num binary search.\n    Returns:\n        A new function rematerializes each layer of the input function.\n    \"\"\"\n\n    def decorate_fun(fun):\n        return layer_level_jaxpr_transformation(fun,\n                                                static_argnums,\n                                                remat=remat_layer,\n                                                layer_construction=True,\n                                                auto_layer_boundary=True,\n                                                layer_num=layer_num,\n                                                eps=eps,\n                                                cost_criteria=cost_criteria,\n                                                layer_eps=layer_eps)\n\n    if fun is None:\n        return decorate_fun\n    else:\n        _check_callable(fun)\n        return decorate_fun(fun)\n\n\ndef follow_layer_construction(fun, static_argnums, input_placement_specs,\n                              num_meshes):\n    \"\"\"Follow given input placement specs to construct layers.\"\"\"\n    _check_callable(fun)\n\n    @wraps(fun)\n    def wrapped(*args):\n        jaxpr, out_shape_tree = make_jaxpr(fun,\n                                           static_argnums=static_argnums,\n                                           return_shape=True)(*args)\n\n        var2mesh = {}  # Dict[var -> mesh_idx]\n\n        for var, spec in zip(jaxpr.jaxpr.invars, 
input_placement_specs):\n            if spec is None:\n                # Assign input vars to mesh 0 by default\n                if isinstance(var, Var):\n                    var2mesh[var] = 0\n            else:\n                if isinstance(var, Var):\n                    var2mesh[var] = spec.mesh_ids[0]\n\n        sliced_eqns = slice_jaxpr_with_var_assignment(jaxpr, var2mesh,\n                                                      num_meshes)\n        jaxpr = add_pipeline_marks_for_sliced_eqns(jaxpr, sliced_eqns)\n\n        flatten_args, _ = tree_flatten(args)\n        ans = jaxpr_as_fun(jaxpr)(*flatten_args)  # pylint: disable=not-callable\n        _, out_tree = tree_flatten(out_shape_tree)\n        return tree_unflatten(out_tree, ans)\n\n    return wrapped\n\n\ndef slice_jaxpr_with_var_assignment(jaxpr, var2mesh, num_meshes):\n    mesh_begin = [None] * num_meshes\n    mesh_end = [None] * num_meshes\n\n    # Run a linear scan to find the begin and end equations of each mesh.\n    cur_mesh = 0\n    for idx, eqn in enumerate(jaxpr.eqns):\n        if eqn.primitive is pipeline_p:\n            continue\n        for var in eqn.invars:\n            if isinstance(var, Var) and var in var2mesh:\n                mesh_idx = var2mesh[var]\n\n                if mesh_idx > cur_mesh:\n                    cur_mesh = mesh_idx\n\n                if mesh_begin[cur_mesh] is None:\n                    mesh_begin[cur_mesh] = idx\n                mesh_end[cur_mesh] = idx\n\n    # Some boundary equations are not within the ranges detected above.\n    # Use DP algorithm to refine the boundary, so we can minimize the\n    # communication costs.\n    cost_criteria = \"flops\"\n    costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria)\n    _, _, compute_costs = costs\n\n    # To make the solution of DP algorithm respect our begin/end constraint.\n    # We assign begin, end equations a very large cost and run DP\n    # with a small eps.\n    max_cost = 
np.sum(compute_costs) * 10\n    for i in range(num_meshes):\n        assert mesh_begin[i] is not None and mesh_end[i] is not None\n        compute_costs[mesh_begin[i]] += max_cost\n        compute_costs[mesh_end[i]] += max_cost\n\n    sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr,\n                                           layer_num=num_meshes,\n                                           eps=0.1,\n                                           costs=costs,\n                                           cost_criteria=cost_criteria)\n    return sliced_eqns\n"
  },
  {
    "path": "alpa/pipeline_parallel/layer_stats.py",
    "content": "\"\"\"Functions related with computing the stats during layer construction.\"\"\"\nfrom typing import List, Set\n\nfrom jax import lax\nfrom jax.lib import xla_client as xc, xla_bridge as xb\nfrom jax.core import JaxprEqn, Var, DropVar, Jaxpr, ClosedJaxpr\nfrom alpa.util import OrderedSet, jaxpr_to_hlo\n\nnon_trivial_primitive = [lax.dot_general_p, lax.conv_general_dilated_p]\n\n\ndef eqn_flops(eqn: JaxprEqn) -> float:\n    \"\"\"Get the FLOP of a jaxpr equation.\"\"\"\n    if \"jaxpr\" in eqn.params:\n        return sum(eqn_flops(x) for x in eqn.params[\"jaxpr\"].eqns)\n\n    if eqn.primitive not in non_trivial_primitive:\n        return 0\n\n    new_inv = [inv for inv in eqn.invars if isinstance(inv, Var)]\n    jaxpr = Jaxpr([], new_inv, eqn.outvars, [eqn])\n    closed_jaxpr = ClosedJaxpr(jaxpr, [])\n    hlo_module = jaxpr_to_hlo(\"tmp\", closed_jaxpr, [\n        False,\n    ] * len(jaxpr.invars)).get_module()\n\n    backend = xb.get_backend(\"cpu\")\n    properties = xc._xla.hlo_module_cost_analysis(  # pylint: disable=protected-access\n        backend, hlo_module)\n    return properties[\"flops\"] if \"flops\" in properties else 0.0\n\n\ndef cluster_edges_cost(start: List[\"JaxprEqn\"], end: List[\"JaxprEqn\"]):\n    \"\"\"Calculates the cost of cluster edges.\"\"\"\n    out_tensors = OrderedSet()\n    for eqn in start:\n        out_tensors = out_tensors.union(OrderedSet(eqn.outvars))\n    in_tensors = OrderedSet()\n    for eqn in end:\n        for invar in eqn.invars:\n            if isinstance(invar, Var) and invar in out_tensors:\n                in_tensors.add(invar)\n    acc = 0\n    for in_tensor in in_tensors:\n        acc += in_tensor.aval.size * in_tensor.aval.dtype.itemsize\n    return acc\n\n\ndef heavy_count(eqn):\n    \"\"\"Check the number of heavy ops in the eqn.\"\"\"\n    if \"jaxpr\" in eqn.params:\n        return sum(heavy_count(x) for x in eqn.params[\"jaxpr\"].eqns)\n\n    if eqn.primitive not in non_trivial_primitive:\n     
   return 0\n    return 1\n\n\ndef is_nontrivial(eqn):\n    \"\"\"Check if the eqn is nontrivial.\"\"\"\n    return heavy_count(eqn) > 0\n\n\ndef get_cross_slice_vars(jaxpr, slices):\n    \"\"\"TODO(zhuohan):doscstring.\"\"\"\n    defined = {}\n    stage_invars = [OrderedSet() for _ in slices]\n    for invar in jaxpr.invars:\n        defined[invar] = -1\n    for invar in jaxpr.constvars:\n        defined[invar] = -1\n    for i, sliced in enumerate(slices):\n        for eqn in sliced:\n            for outvar in eqn.outvars:\n                if isinstance(outvar, DropVar):\n                    continue\n                defined[outvar] = i\n    for i, sliced in enumerate(slices):\n        for eqn in sliced:\n            for invar in eqn.invars:\n                if not isinstance(invar, Var):\n                    continue\n                if defined[invar] >= 0 and defined[invar] != i:\n                    stage_invars[i].add(invar)\n    for i, invars in enumerate(stage_invars):\n        print(f\"Layer {i} has inputs:\")\n        for invar in invars:\n            print(invar, invar.aval.shape, \"from layer\", defined[invar])\n\n\ndef log_layer_slicing_stats(origin_jaxpr, slices):\n    \"\"\"Print the layer slicing stats.\"\"\"\n    stage_flops = []\n    stage_heavy_ops = []\n    for eqns in slices:\n        stage_flops.append(sum(eqn_flops(eqn) for eqn in eqns))\n        stage_heavy_ops.append(sum(heavy_count(eqn) for eqn in eqns))\n\n    print(\"-\" * 20, \"Layer slicing stats\", \"-\" * 20)\n    print(f\"layer_num: {len(slices)}\")\n    print(\" - Number of Jaxpr eqns in each stage:\")\n    for i, s in enumerate(slices):\n        print(f\"Layer {i}: #eqns={len(s)},\"\n              f\" flop={stage_flops[i] / (1000 ** 4):.3f} TFlop,\"\n              f\" #heavy_ops={stage_heavy_ops[i]}\")\n    print(\" - Invars of each stage:\")\n    get_cross_slice_vars(origin_jaxpr.jaxpr, slices)\n    print(\"-\" * 61)\n\n\ndef global_invar_size(invars: Set[Var], eqn: JaxprEqn):\n    
input_vars = {v for v in eqn.invars if isinstance(v, Var)}\n    size = sum((var.aval.size * var.aval.dtype.itemsize)\n               for var in invars.intersection(input_vars))\n    return size\n"
  },
  {
    "path": "alpa/pipeline_parallel/local_pipeline.py",
    "content": "\"\"\"Pipeline parallel on a single device. This is only used for debugging.\"\"\"\nfrom typing import Sequence, Any, Dict\n\nimport jax\nfrom jax import linear_util as lu\nfrom jax.core import Var, ClosedJaxpr, Literal, gensym\nfrom jax.interpreters import partial_eval as pe\nfrom jax.interpreters.xla import DeviceArray\n\nfrom alpa.pipeline_parallel.computation import (\n    PipelineComputation, XlaPipelineComputation,\n    slice_closed_jaxpr_by_full_pipeline_marks,\n    mark_missing_vars_in_backward_computation_pipeline_marks)\n\n\nclass LocalPipelineRunner:\n    \"\"\"Single-device local pipeline runner.\"\"\"\n\n    def __init__(self, name: str, global_invals: Sequence[DeviceArray]):\n        self.name = name\n        self.env = {}\n        self.global_invals = global_invals\n\n    def run_stage(self, stage: PipelineComputation, invals: Dict[Var, Any]):\n        \"\"\"\n        Run a pipeline stage.\n\n        Args:\n            stage (PipelineComputation): The pipeline stage to run.\n            invals (Dict[Var, Any], optional): Input value dict.\n        \"\"\"\n        runnable = stage.get_runnable()\n        invals_list = []\n        for var in stage.invars:\n            invals_list.append(invals[var])\n        outvals_list = runnable(*invals_list)\n        outvals = dict(zip(stage.outvars, outvals_list))\n        self.env.update(outvals)\n\n    def get_val(self, var):\n        \"\"\"Get the value of a variable from the env.\"\"\"\n        return self.env[var]\n\n    def del_var(self, var):\n        \"\"\"Delete a variable from the env.\"\"\"\n        del self.env[var]\n\n\nclass LocalPipelineExecutable:\n    \"\"\"A pipeline parallel executable running on a single local device.\n\n    Args:\n        stages (Sequence[PipelineComputation]): the pipeline stages to be\n            executed.\n        global_invars (Sequence[Var]): Global input variables.\n        global_outvars (Sequence[Var]): Global output variables.\n    \"\"\"\n\n    def 
__init__(self, *, stages: Sequence[PipelineComputation],\n                 global_invars: Sequence[Var], global_outvars: Sequence[Var]):\n        self.stages = stages\n        self.global_invars = global_invars\n        self.global_outvars = global_outvars\n\n    def launch_on_driver(self, *args):\n        \"\"\"Run function.\"\"\"\n        global_invals = dict(zip(self.global_invars, args))\n        runners = {}\n\n        var_stage_mapping = {}\n        var_reference_count = {}\n\n        # Create variable dependency mapping.\n        for stage in self.stages:\n            for var in stage.invars:\n                if var not in global_invals:\n                    assert var in var_stage_mapping, (\n                        f\"referred to an unknown var {var}\")\n                    var_reference_count[var] = var_reference_count.get(var,\n                                                                       0) + 1\n            for var in stage.outvars:\n                var_stage_mapping[var] = stage.name\n\n        for var in self.global_outvars:\n            if not isinstance(var, Literal):\n                assert var in var_stage_mapping, (\n                    f\"referred to an unknown var {var}\")\n                var_reference_count[var] = var_reference_count.get(var, 0) + 1\n\n        for stage in self.stages:\n            stage_invals = {}\n            for var in stage.invars:\n                if var in global_invals:\n                    stage_invals[var] = global_invals[var]\n                else:\n                    assert var in var_stage_mapping, (\n                        f\"referred to an unknown var {var}\")\n                    sender_runner = runners[var_stage_mapping[var]]\n                    stage_invals[var] = sender_runner.get_val(var)\n                    var_reference_count[var] -= 1\n                    if var_reference_count[var] == 0:\n                        sender_runner.del_var(var)\n\n            if stage.name not in runners:\n      
          runners[stage.name] = LocalPipelineRunner(\n                    stage.name, global_invals)\n            runners[stage.name].run_stage(stage, stage_invals)\n\n        global_outvals_list = []\n        for var in self.global_outvars:\n            if isinstance(var, Literal):\n                global_outvals_list.append(var.val)\n            else:\n                assert var in var_stage_mapping, (\n                    f\"referred to an unknown var {var}\")\n                sender_runner = runners[var_stage_mapping[var]]\n                global_outvals_list.append(sender_runner.get_val(var))\n                var_reference_count[var] -= 1\n                if var_reference_count[var] == 0:\n                    sender_runner.del_var(var)\n        return global_outvals_list\n\n\ndef compile_local_pipeline_executable(fun: lu.WrappedFun, *avals):\n    \"\"\"Compile a local pipeline executable that only runs on a single device.\"\"\"\n    with jax.disable_jit():\n        jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, avals)\n    closed_jaxpr = ClosedJaxpr(jaxpr, consts)\n    global_invars = closed_jaxpr.jaxpr.invars\n    global_outvars = closed_jaxpr.jaxpr.outvars\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    jax_pipeline_stages = slice_closed_jaxpr_by_full_pipeline_marks(\n        closed_jaxpr)\n    jax_pipeline_stages = (\n        mark_missing_vars_in_backward_computation_pipeline_marks(\n            jax_pipeline_stages, global_invars, global_outvars, gensym_func))\n    xla_pipeline_stages = [\n        XlaPipelineComputation.from_jax_pipeline_computation(stage)\n        for stage in jax_pipeline_stages\n    ]\n\n    return LocalPipelineExecutable(stages=xla_pipeline_stages,\n                                   global_invars=global_invars,\n                                   global_outvars=global_outvars)\n"
  },
  {
    "path": "alpa/pipeline_parallel/pipeshard_executable.py",
    "content": "\"\"\"The driver part and worker part of a pipeshard executable.\"\"\"\nimport logging\nfrom functools import partial\nimport json\nimport os\nimport time\nfrom typing import Optional, Sequence\n\nfrom jax._src import traceback_util\nfrom jax._src.lib import xla_extension as xe\nfrom jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef\nimport numpy as np\nimport ray.exceptions\n\nfrom alpa.device_mesh import (\n    MeshHostWorker, RemoteArrayRef,\n    create_and_record_cross_mesh_collective_communicators, next_array_uuids)\nfrom alpa.global_env import global_config\nfrom alpa.device_mesh import PhysicalDeviceMeshGroup\nfrom alpa.mesh_executable import (AllocZeroBufferWorkerExecutable,\n                                  UtilMeshWorkerExecutable,\n                                  PartialGradAccMeshWorkerExecutable,\n                                  next_mesh_executable_uuid,\n                                  get_execution_timer_name)\nfrom alpa.parallel_plan import ClusterInfo, PipelinePlan, ParallelPlan\nfrom alpa.pipeline_parallel.layer_construction import LayerOption\nfrom alpa.pipeline_parallel.runtime_emitter import (\n    AllocateZeroWorkerExecutableConfig, ConcatWorkerExecutableConfig,\n    ExecutableConfig, PartialGradWorkerExecutableConfig, PipelineInstType,\n    PipelineInstruction, PipeshardConfig)\nfrom alpa.shard_parallel.auto_sharding import HloStatus\nfrom alpa.timer import timers, tracer\nfrom alpa.util import OrderedSet, mesh_ids_hash\n\ntraceback_util.register_exclusion(__file__)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\nclass PipeshardDriverExecutable:\n    \"\"\"The driver part of the executable for pipeshard parallel.\"\"\"\n\n    def __init__(self,\n                 mesh_group: PhysicalDeviceMeshGroup,\n                 pipeshard_config: PipeshardConfig,\n                 num_batch: int,\n                 layer_option: LayerOption,\n                 in_tree: PyTreeDef,\n  
               out_tree: Optional[PyTreeDef] = None,\n                 static_argnums: Optional[Sequence[int]] = None):\n        ##### Input arguments #####\n        self.mesh_group = mesh_group\n        self.num_mesh = len(mesh_group)\n        self.num_batch = num_batch\n        self.in_tree = in_tree\n        self.out_tree = out_tree\n        self.static_argnums = static_argnums\n\n        ##### For debugging and serialization #####\n        self.stages = pipeshard_config.xla_stages\n        self.schedule = pipeshard_config.schedule\n        self.flop_count = pipeshard_config.flop_count\n        self.stage_input_shard_specs = pipeshard_config.stage_input_shard_specs\n        self.input_placement_specs = pipeshard_config.input_placement_specs\n        self.output_placement_specs = pipeshard_config.output_placement_specs\n        # List[stage_idx -> str]\n        self.fully_optimized_hlo_texts = []\n        # List[stage_idx -> int]\n        self.stage_allocation_sizes = []\n        self.sharding_annotated_hlo_texts = (\n            pipeshard_config.sharding_annotated_hlo_texts)\n        # List[stage_idx -> executable_uuid]\n        self.executable_uuids = pipeshard_config.executable_uuids\n        self.default_auto_sharding_option = (\n            pipeshard_config.default_auto_sharding_option)\n        self.pipeline_plan = PipelinePlan(\n            self.schedule.name,\n            layer_option,\n            pipeshard_config.manual_stage_option,\n        )\n\n        ##### For handling inputs of the executable #####\n        # go to the definition of PipeshardInputConfig for more details.\n        input_config = pipeshard_config.input_config\n        self.donate_invars = input_config.donate_invars\n        self.mesh_arg_indices = input_config.mesh_arg_indices\n        self.input_shard_indices = input_config.input_shard_indices\n        self.delete_after_shard = input_config.delete_after_shard\n        self.batch_invars = input_config.batch_invars\n\n        ##### 
For handling outputs of the executable #####\n        self.output_local_uuid_list = pipeshard_config.output_local_uuid_list\n        self.outs_handler = pipeshard_config.outs_handler\n\n        ##### For cross-mesh resharding #####\n        self._instantiate_nccl_groups(pipeshard_config.device_str_groups)\n        self.resharding_tasks = pipeshard_config.resharding_tasks\n        for mesh_ids in pipeshard_config.allreduce_groups:\n            meshes = [self.mesh_group.meshes[idx] for idx in mesh_ids]\n            key = mesh_ids_hash(mesh_ids)\n            create_and_record_cross_mesh_collective_communicators(meshes, key)\n        if global_config.eagerly_create_communicators:\n            for task in self.resharding_tasks:\n                task.create_resharding_communicators()\n\n        self.exec_uuid = next_mesh_executable_uuid()\n        # Create a PipeshardMeshWorkerExecutable for each MeshHostWorker\n        for mesh_idx, physical_mesh in enumerate(self.mesh_group):\n            mesh_grad_uuids = pipeshard_config.grad_uuids[mesh_idx]\n            for worker in physical_mesh.workers:\n                acc_grad_local_uuids = []\n                if len(mesh_grad_uuids) > 0:\n                    acc_grad_local_uuids = mesh_grad_uuids\n                args = (pipeshard_config.instruction_lists[worker],\n                        input_config.input_local_uuid_lists[mesh_idx],\n                        self.output_local_uuid_list[mesh_idx],\n                        pipeshard_config.executable_configs[worker],\n                        acc_grad_local_uuids,\n                        pipeshard_config.reduced_var_uuid_lists[mesh_idx],\n                        self.donate_invars[mesh_idx])\n                worker.put_executable.remote(self.exec_uuid,\n                                             PipeshardMeshWorkerExecutable,\n                                             *args)\n\n    ##### Compilation Related Functions #####\n    def _instantiate_nccl_groups(self, 
device_str_groups):\n        \"\"\"\n        Instantiate NCCL groups between two physical meshes.\n\n        Args:\n            device_str_groups (List[List[set]]): a num_mesh x num_mesh matrix.\n                Only entries at device_str_groups[i][j] (i < j) are filled,\n                entries with i > j are None, because (spec[i][j], spec[j][i])\n                will share collective groups.\n        \"\"\"\n        start_time = time.time()\n        for i in range(self.num_mesh):\n            for j in range(i, self.num_mesh):\n                if device_str_groups[i][j]:\n                    self.mesh_group.instantiate_nccl_group(i, j)\n        end_time = time.time()\n        logger.debug(\n            f\"Initialize collective group takes {end_time - start_time:.2f}\")\n\n    ##### Execution Related Functions #####\n    def launch_on_driver(self, *args):\n        \"\"\"Launch the executable on the driver.\n\n        Args:\n            args: The original arguments of the parallelized function.\n        \"\"\"\n        input_bufs = [None for _ in range(self.num_mesh)]\n        output_bufs = [None for _ in range(self.num_mesh)]\n        output_uuids = [None for _ in range(self.num_mesh)]\n\n        num_outs = [\n            len(self.output_local_uuid_list[mesh_idx])\n            for mesh_idx in range(self.num_mesh)\n        ]\n\n        for mesh_idx, physical_mesh in enumerate(self.mesh_group):\n            # Shard inputs\n            mesh_args = [args[idx] for idx in self.mesh_arg_indices[mesh_idx]]\n            tmp_bufs = physical_mesh.shard_args_to_bufs(\n                self.input_shard_indices[mesh_idx],\n                self.delete_after_shard[mesh_idx], self.batch_invars[mesh_idx],\n                self.num_batch, mesh_args)\n\n            # Flatten the batch args in tmp_bufs\n            flatten_bufs = []\n            for i, is_batch_invar in enumerate(self.batch_invars[mesh_idx]):\n                if is_batch_invar:\n                    
flatten_bufs.extend(tmp_bufs[i])\n                else:\n                    flatten_bufs.append(tmp_bufs[i])\n            input_bufs[mesh_idx] = flatten_bufs\n\n            # Convert bufs to uuids\n            input_uuids = np.array([ref.uuid for ref in input_bufs[mesh_idx]])\n            output_uuids[mesh_idx] = next_array_uuids(num_outs[mesh_idx])\n\n            # Execute\n            for worker in physical_mesh.workers:\n                worker.run_executable.remote(\n                    self.exec_uuid,\n                    input_uuids,\n                    output_uuids[mesh_idx],\n                    sync_for_timer=global_config.pipeline_sync_for_timer,\n                    collect_trace=global_config.collect_trace)\n\n        # Handle donation\n        for mesh_idx in range(len(self.mesh_group)):\n            inputs = input_bufs[mesh_idx]\n            for ref, donate in zip(inputs, self.donate_invars[mesh_idx]):\n                if donate:\n                    ref.set_deleted_on_workers()\n\n        # Construct output_bufs\n        for mesh_idx, physical_mesh in enumerate(self.mesh_group):\n            output_uuid = output_uuids[mesh_idx]\n            output_bufs[mesh_idx] = np.empty((num_outs[mesh_idx],),\n                                             dtype=object)\n            for i in range(num_outs[mesh_idx]):\n                output_bufs[mesh_idx][i] = RemoteArrayRef(\n                    physical_mesh, output_uuid[i])\n\n        # Check if there is OOM\n        if global_config.pipeline_check_alive:\n            self._check_alive()\n\n        return self.outs_handler(self.mesh_group, output_bufs)\n\n    def get_input_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for input arguments.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the input pytree.\n        \"\"\"\n        return tree_unflatten(self.in_tree, self.input_placement_specs)\n\n    def 
get_output_placement_specs(self):\n        \"\"\"\n        Return the preferred placement specs for outputs.\n        The return value is a pytree of PlacementSpec\n        with the same structure as the output pytree.\n        \"\"\"\n        return tree_unflatten(self.out_tree, self.output_placement_specs)\n\n    def get_parallel_plan(self):\n        \"\"\"Get the overall parallel plan.\"\"\"\n        virtual_mesh = self.mesh_group.parent\n        cluster_info = ClusterInfo(virtual_mesh.num_hosts,\n                                   virtual_mesh.num_devices_per_host)\n        return ParallelPlan(cluster_info, self.num_batch,\n                            self.default_auto_sharding_option,\n                            self.pipeline_plan,\n                            tree_leaves(self.get_input_placement_specs()))\n\n    def __call__(self, *args):\n        \"\"\"Fast call without signature matching.\"\"\"\n        if self.static_argnums:\n            dyn_args = [\n                args[i]\n                for i in range(len(args))\n                if i not in self.static_argnums\n            ]\n        else:\n            dyn_args = args\n        args_flat, _ = tree_flatten(dyn_args)\n        out = self.launch_on_driver(*args_flat)\n        return tree_unflatten(self.out_tree, out)\n\n    ##### Profiling and Debugging Related Functions #####\n    def get_stage_execution_info(self):\n        \"\"\"Get the per-stage execution information of all invocations.\n           Return a list, where each element corresponds to a single stage.\n           Each element is a list of (start, stop, node_ids, devices) tuple,\n           where each tuple corresponds to one invocation.\n        \"\"\"\n        exec_timer_name = get_execution_timer_name(self.exec_uuid)\n        run_begin_event = exec_timer_name + \"-ins-run-begin\"\n        run_end_event = exec_timer_name + \"-ins-run-end\"\n\n        num_stages = len(self.stages)\n        stage_start = [[] for _ in range(num_stages)]\n    
    stage_end = [[] for _ in range(num_stages)]\n\n        # Extract events\n        for mesh in self.mesh_group:\n            mesh_tracer = mesh.get_remote_tracer()\n\n            for x in mesh_tracer.events:\n                if x.name == run_begin_event and \"stage\" in x.info:\n                    stage_id = int(x.info[6:])\n                    stage_start[stage_id].append(x.tstamp)\n                if x.name == run_end_event and \"stage\" in x.info:\n                    stage_id = int(x.info[6:])\n                    stage_end[stage_id].append(x.tstamp)\n\n        # Organize return values\n        all_stages_info_list = []\n        for i in range(num_stages):\n            mesh_idx = self.schedule.stage_placement(i)\n            assert len(mesh_idx) == 1\n            mesh_idx = list(mesh_idx)[0]\n            mesh = self.mesh_group[mesh_idx]\n            host_ids, devices = mesh.host_ids, mesh.devices\n            per_stage_info_list = []\n            for s, e in zip(stage_start[i], stage_end[i]):\n                per_stage_info_list.append((s, e, host_ids, devices))\n            all_stages_info_list.append(per_stage_info_list)\n        return all_stages_info_list\n\n    def get_execution_time_costs(self, timer_name=None, return_all_costs=False):\n        \"\"\"Get the execution time costs with internal timers.\"\"\"\n        assert timer_name is None  # TODO(lmzheng): support other timers later\n        timer_name = get_execution_timer_name(self.exec_uuid)\n        mesh_costs = []\n        for mesh in self.mesh_group:\n            mesh_costs.append(mesh.get_remote_timer(timer_name).costs)\n        if return_all_costs:\n            return mesh_costs\n\n        min_costs = [1.0e9] * len(mesh_costs[0])\n        max_costs = [0] * len(mesh_costs[0])\n        for mesh_cost in mesh_costs:\n            for i, cost in enumerate(mesh_cost):\n                if cost > max_costs[i]:\n                    max_costs[i] = cost\n                if cost < min_costs[i]:\n          
          min_costs[i] = cost\n        return max_costs\n\n    def get_shard_args_time_costs(self):\n        # TODO(lmzheng): implement this\n        raise NotImplementedError()\n\n    def get_hlo_text(self, status: HloStatus = HloStatus.FULLY_OPTIMIZED):\n        \"\"\"Return the HLO text for all stages.\"\"\"\n        if status == HloStatus.FULLY_OPTIMIZED:\n            if self.fully_optimized_hlo_texts:\n                return self.fully_optimized_hlo_texts\n\n            hlo_texts = []\n            for stage_idx in range(len(self.stages)):\n                mesh_idx = self.schedule.stage_placement(stage_idx)\n                assert len(mesh_idx) == 1\n                mesh_idx = list(mesh_idx)[0]\n                physical_mesh = self.mesh_group[mesh_idx]\n                hlo_text = physical_mesh.workers[0].get_exec_hlo_text.remote(\n                    self.executable_uuids[stage_idx])\n                hlo_texts.append(hlo_text)\n            self.fully_optimized_hlo_texts = ray.get(hlo_texts)\n            return self.fully_optimized_hlo_texts\n        else:\n            return self.sharding_annotated_hlo_texts\n\n    def get_stage_allocation_size(self):\n        \"\"\"Get the total memory allocation size in bytes of all stages.\"\"\"\n        if self.stage_allocation_sizes:\n            return self.stage_allocation_sizes\n\n        sizes = []\n        for stage_idx in range(len(self.stages)):\n            mesh_idx = self.schedule.stage_placement(stage_idx)\n            assert len(mesh_idx) == 1\n            mesh_idx = list(mesh_idx)[0]\n            physical_mesh = self.mesh_group[mesh_idx]\n            size = physical_mesh.workers[\n                0].get_exec_total_allocation_size.remote(\n                    self.executable_uuids[stage_idx])\n            sizes.append(size)\n        self.stage_allocation_sizes = ray.get(sizes)\n        return self.stage_allocation_sizes\n\n    def dump_debug_info(self, folder: str):\n        \"\"\"\n        Dump intermediate 
representations and other information for debugging.\n        \"\"\"\n        os.makedirs(folder, exist_ok=True)\n        name = self.stages[0].hlo.name\n        if \"pipeshard_parallel\" in name:\n            name = name[:name.index(\"pipeshard_parallel\") - 1]\n        elif \"create_state_parallel\" in name:\n            name = name[:name.index(\"create_state_parallel\") - 1]\n        prefix = os.path.join(folder, name)\n\n        fully_optimized_hlo_texts = self.get_hlo_text(HloStatus.FULLY_OPTIMIZED)\n        allocation_sizes = self.get_stage_allocation_size()\n        for stage_idx in range(len(self.stages)):\n            with open(f\"{prefix}_stage_{stage_idx}.hlo\", \"w\") as f:\n                f.write(fully_optimized_hlo_texts[stage_idx])\n\n            with open(f\"{prefix}_stage_{stage_idx}.mem_usage.txt\", \"w\") as f:\n                f.write(f\"total_allocation_size: \"\n                        f\"{allocation_sizes[stage_idx]/(1024**3):.3f} GB\\n\")\n\n        with open(f\"{prefix}_resharding_tasks.txt\", \"w\") as f:\n            for task in self.resharding_tasks:\n                f.write(str(task) + \"\\n\\n\")\n\n        with open(f\"{prefix}_input_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_input_placement_specs()))\n        with open(f\"{prefix}_output_placement_specs.txt\", \"w\") as f:\n            f.write(str(self.get_output_placement_specs()))\n\n    def dump_stage_execution_trace(self, filename: str):\n        exec_info = self.get_stage_execution_info()\n        dump_stage_execution_trace_internal(exec_info, filename)\n\n    def profile_all_executable_with_dummy_inputs(self):\n        \"\"\"Profile all stage executables with dummy inputs.\"\"\"\n        all_profiled_handles = []\n        for _, physical_mesh in enumerate(self.mesh_group):\n            all_worker_profiled = []\n            for _, worker in enumerate(physical_mesh.workers):\n                worker: MeshHostWorker\n                
all_worker_profiled.append(\n                    worker.profile_executable_with_dummy_inputs.remote(\n                        self.exec_uuid))\n            if len(all_worker_profiled) == 1:\n                all_worker_profiled = all_worker_profiled[0]\n            all_profiled_handles.append(all_worker_profiled)\n        all_profiled = [ray.get(handles) for handles in all_profiled_handles]\n        return all_profiled\n\n    ##### Other Functions #####\n    def sync(self):\n        \"\"\"Sync device activities on all workers.\"\"\"\n        self.mesh_group.sync_workers()\n\n    def sync_move_workers(self):\n        \"\"\"Sync move workers on all meshes.\"\"\"\n        self.mesh_group.sync_move_workers()\n\n    def _check_alive(self):\n        \"\"\"\n        Check whether all workers are alive.\n        Shutdown the runtime if any worker dies.\n        \"\"\"\n        try:\n            rets = [\n                worker.check_alive.remote()\n                for mesh in self.mesh_group\n                for worker in mesh.workers\n            ]\n            ray.get(rets)\n        except ray.exceptions.RayActorError:\n            self.mesh_group.exception_shutdown()\n\n    def __del__(self):\n        for mesh in self.mesh_group:\n            mesh.delete_remote_executable(self.exec_uuid)\n\n\nclass PipeshardMeshWorkerExecutable:\n    \"\"\"\n    An executable that executes static pipeline runtime instructions on a\n    worker.\n    \"\"\"\n\n    def __init__(self, worker: MeshHostWorker, uuid: int,\n                 instructions: Sequence[PipelineInstruction],\n                 input_local_uuids: Sequence[int],\n                 output_local_uuids: Sequence[int],\n                 executable_configs: Sequence[ExecutableConfig],\n                 acc_local_uuids: np.ndarray, acc_out_uuids: np.ndarray,\n                 donate_invars: Sequence[bool]):\n        # Instruction Lists\n        self.exec_uuid = uuid\n        self.exec_timer_name = get_execution_timer_name(uuid)\n 
       self.instructions = instructions\n        self.input_local_uuids = input_local_uuids\n        self.output_local_uuids = output_local_uuids\n\n        # Buffer management\n        self.worker = worker\n        self.global_buffers = worker.buffers\n        self.acc_in_uuids = acc_local_uuids\n        self.acc_out_uuids = acc_out_uuids\n        self.donate_invars = donate_invars\n\n        # Executable management\n        self._related_exec_uuids = []\n        self.partial_grad_exec_uuids = OrderedSet()\n\n        # Compile executables\n        for task_config in executable_configs:\n            self._related_exec_uuids.append(task_config.exec_uuid)\n            if isinstance(task_config, PartialGradWorkerExecutableConfig):\n                self.worker.put_executable(task_config.exec_uuid,\n                                           PartialGradAccMeshWorkerExecutable,\n                                           *task_config[1:])\n                self.partial_grad_exec_uuids.add(task_config.exec_uuid)\n            elif isinstance(task_config, AllocateZeroWorkerExecutableConfig):\n                self.worker.put_executable(task_config.exec_uuid,\n                                           AllocZeroBufferWorkerExecutable,\n                                           task_config.grad_shard_shapes,\n                                           task_config.grad_shard_dtypes)\n            elif isinstance(task_config, ConcatWorkerExecutableConfig):\n                self.worker.put_executable(task_config.exec_uuid,\n                                           UtilMeshWorkerExecutable,\n                                           *task_config[1:])\n            else:\n                raise ValueError(f\"Invalid task config {task_config}\")\n        self.partial_grad_exec_uuids = list(self.partial_grad_exec_uuids)\n\n    def execute_on_worker(self, input_global_uuids, output_global_uuids,\n                          sync_for_timer, collect_trace):\n        \"\"\"Execute on the 
mesh worker given input and output uuids.\"\"\"\n        # create a local buffer environment\n        assert len(self.input_local_uuids) == len(input_global_uuids)\n        buffers = {}\n        for local_id, global_id in zip(self.input_local_uuids,\n                                       input_global_uuids):\n            buffers[local_id] = self.global_buffers[global_id]\n        if global_config.enable_overlapping:\n            xe.reset_event_context(self.worker.backend)\n        # donate invars\n        for global_id, donate in zip(input_global_uuids, self.donate_invars):\n            if donate:\n                self.worker.delete_buffers(global_id)\n        # load the local env\n        self.worker.buffers = buffers\n        sync_func = self.worker.sync if sync_for_timer else None\n\n        # Setup tracer\n        if collect_trace:\n            log_run_begin = partial(tracer.log,\n                                    self.exec_timer_name + \"-ins-run-begin\")\n            log_run_end = partial(tracer.log,\n                                  self.exec_timer_name + \"-ins-run-end\")\n        else:\n\n            def log_run_begin(*_, **__):\n                pass\n\n            log_run_end = log_run_begin\n\n        # Execute\n        timers(self.exec_timer_name).start(sync_func=sync_func)\n\n        for instruction in self.instructions:\n            #self.worker.sync()\n            #print(f\"memory_allocated: \"\n            #      f\"{self.worker.get_memory_allocated()/1024**3:.3f} GB  \"\n            #      f\"max_memory_allocated: \"\n            #      f\"{self.worker.get_max_memory_allocated()/1024**3:.3f} GB \"\n            #      f\"next instruction: {instruction}\", flush=True)\n\n            if instruction.opcode == PipelineInstType.RUN:\n                log_run_begin(instruction.info, sync_func=sync_func)\n                self.worker.run_executable(instruction.task_uuid,\n                                           instruction.input_uuids,\n               
                            instruction.output_uuids,\n                                           **instruction.opaques[\"kwargs\"])\n                log_run_end(instruction.info, sync_func=sync_func)\n            elif instruction.opcode == PipelineInstType.SEND:\n                self.worker.run_resharding_send_task(instruction.task_uuid,\n                                                     instruction.input_uuids[0])\n            elif instruction.opcode == PipelineInstType.RECV:\n                self.worker.run_resharding_recv_task(\n                    instruction.task_uuid, instruction.output_uuids[0],\n                    instruction.opaques[\"set_empty_buffer\"])\n                # TODO(lmzheng): move this to run_resharding_recv_task\n                if instruction.opaques[\"allgather_uuid\"] is not None:\n                    task_uuid = instruction.opaques[\"allgather_uuid\"]\n                    ary_uuid = instruction.output_uuids[0]\n                    self.worker.run_executable(task_uuid, [ary_uuid],\n                                               [ary_uuid], False, False)\n            elif instruction.opcode == PipelineInstType.BROADCAST:\n                self.worker.run_resharding_broadcast_task(\n                    instruction.task_uuid,\n                    (instruction.input_uuids if instruction.input_uuids\n                     is not None else instruction.output_uuids)[0])\n            elif instruction.opcode == PipelineInstType.FREE:\n                self.worker.delete_buffers(instruction.input_uuids)\n\n        timers(self.exec_timer_name).stop(sync_func=sync_func)\n\n        # copy to global env\n        assert len(self.output_local_uuids) == len(output_global_uuids)\n        for local_id, global_id in zip(self.output_local_uuids,\n                                       output_global_uuids):\n            self.global_buffers[global_id] = buffers[local_id]\n        # restore global environment\n        self.worker.buffers = self.global_buffers\n 
       buffers.clear()\n        if global_config.enable_overlapping:\n            xe.reset_event_context(self.worker.backend)\n\n    def profile_with_dummy_inputs(self):\n        \"\"\"Profile the executable with dummy inputs.\"\"\"\n        self.worker.reset_memory_stats()\n        ret = {\n            exec_id:\n            (np.mean(\n                self.worker.profile_executable_with_dummy_inputs(\n                    exec_id, skip_grad_sync=False)),\n             self.worker.get_exec_total_allocation_size(exec_id) / 1024**3)\n            for exec_id in self.partial_grad_exec_uuids\n        }\n        self.worker.reset_memory_stats()\n        return ret\n\n    def __del__(self):\n        for exec_id in self._related_exec_uuids:\n            self.worker.delete_executable(exec_id)\n\n\ndef dump_stage_execution_trace_internal(stage_execution_info, filename: str):\n    \"\"\"Dump stage execution info as a chrome tracing file.\"\"\"\n\n    def get_color(i):\n        color_list = [\n            \"thread_state_uninterruptible\",\n            \"thread_state_iowait\",\n            \"thread_state_running\",\n            \"thread_state_runnable\",\n            \"thread_state_unknown\",\n            \"background_memory_dump\",\n            \"light_memory_dump\",\n            \"detailed_memory_dump\",\n            \"vsync_highlight_color\",\n            \"generic_work\",\n            \"good\",\n            \"bad\",\n            \"terrible\",\n            \"yellow\",\n            \"olive\",\n            \"rail_response\",\n            \"rail_animation\",\n            \"rail_idle\",\n            \"rail_load\",\n            \"startup\",\n            \"heap_dump_stack_frame\",\n            \"heap_dump_object_type\",\n            \"heap_dump_child_node_arrow\",\n            \"cq_build_running\",\n            \"cq_build_passed\",\n            \"cq_build_failed\",\n            \"cq_build_attempt_runnig\",\n            \"cq_build_attempt_passed\",\n            
\"cq_build_attempt_failed\",\n        ]\n        return color_list[i % len(color_list)]\n\n    slot_list = []\n    for request_id, request_timeline in enumerate(zip(*stage_execution_info)):\n        sorted_timeline = sorted(request_timeline, key=lambda x: x[0])\n\n        for stage_num, (s, e, node_ids, devices) in enumerate(sorted_timeline):\n            for node_id, devices_per_node in zip(node_ids, devices):\n                for device_id in devices_per_node:\n                    slot = {\n                        \"name\": f\"r{request_id}s{stage_num}\",\n                        \"cat\": f\"request {request_id}, stage {stage_num}\",\n                        \"ph\": \"X\",\n                        \"pid\": int(node_id),\n                        \"tid\": int(device_id),\n                        \"ts\": float(s) * 1e6,\n                        \"dur\": float(e - s) * 1e6,\n                        \"cname\": get_color(request_id)\n                    }\n                    slot_list.append(slot)\n\n    os.makedirs(os.path.dirname(filename), exist_ok=True)\n    with open(filename, \"w\") as fout:\n        fout.write(\n            json.dumps({\n                \"traceEvents\": slot_list,\n                \"displayTimeUnit\": \"ms\",\n            }))\n"
  },
  {
    "path": "alpa/pipeline_parallel/primitive_def.py",
    "content": "\"\"\"Define a new Jax primitive pipeline_marker to mark the boundary of pipeline\ncomputations.\"\"\"\nimport numpy as np\n\nfrom jax.core import Primitive\nfrom jax.interpreters import xla, ad\nfrom jax.lib import xla_client as xc\nfrom jax.tree_util import tree_flatten, tree_unflatten\n\nfrom alpa.util import new_jaxpr_eqn\n\n########## Public APIs ##########\n\n# Define a Jax primitive to mark start/end of a pipeline computation.\npipeline_p = Primitive(\"pipeline_marker\")\n\n\ndef mark_pipeline_boundary():\n    \"\"\"Mark the boundary of pipeline layers. We reuse pipeline_marker for this\n    functionality.\"\"\"\n    return pipeline_p.bind(name=\"boundary\", mark_type=\"boundary\")\n\n\ndef mark_gradient(grad):\n    \"\"\"Mark variables as gradients. We reuse pipeline_marker for this\n    functionality.\"\"\"\n    grad_flat, tree = tree_flatten(grad)\n    grad_flat = pipeline_p.bind(*grad_flat, name=\"grad\", mark_type=\"grad\")\n    grad = tree_unflatten(tree, grad_flat)\n    return grad\n\n\ndef mark_pipeline_jaxpreqn(invars, outvars, name: str, mark_type: str):\n    \"\"\"Make a new jaxpr equation.\"\"\"\n    if mark_type not in (\"start\", \"end\", \"jvp_start\", \"jvp_end\"):\n        raise ValueError(f\"Unknown mark type: {mark_type}\")\n    return new_jaxpr_eqn(invars, outvars, pipeline_p, {\n        \"name\": name,\n        \"mark_type\": mark_type\n    })\n\n\ndef mark_hook_jaxpreqn(invars, outvars):\n    \"\"\"Mark some variables in a hook. 
We then extract the information\n    of the variables in the hook.\"\"\"\n    return new_jaxpr_eqn(invars, outvars, pipeline_p, {\n        \"name\": \"hook\",\n        \"mark_type\": \"hook\"\n    })\n\n\n########## Internal Registration ##########\ndef flatten_shape_byte_sizes(shape):\n\n    def _flatten_shape_byte_sizes(shape):\n        if shape.is_tuple():\n            res = []\n            for sub_shape in shape.tuple_shapes():\n                res += _flatten_shape_byte_sizes(sub_shape)\n            return res\n        else:\n            return [shape.numpy_dtype().itemsize * np.prod(shape.dimensions())]\n\n    res = _flatten_shape_byte_sizes(shape)\n    return np.array(res, dtype=np.int64)\n\n\ndef xla_custom_call(c, call_name, op_name, *args):\n    input_params = xc.ops.Tuple(c, args)\n    input_shape = c.get_shape(input_params)\n    flattened_byte_sizes = flatten_shape_byte_sizes(input_shape)\n    op_metadata = xc.OpMetadata(op_name=op_name)\n    c.set_op_metadata(op_metadata)\n\n    if len(args) == 0:\n        # If the custom call is an empty marker, it cannot be annotated\n        # by sharding propagation, so we set a sharding for it.\n        sharding = xc.OpSharding()\n        sharding.type = sharding.type.REPLICATED\n        c.set_sharding(sharding)\n\n    if call_name == \"pipeline_marker\":\n        output_tuple = xc.ops.CustomCall(\n            c,\n            b\"pipeline_marker\",\n            operands=(input_params,),\n            shape=input_shape,\n            # Prevent the deletion of an empty marker\n            has_side_effect=True,\n            opaque=flattened_byte_sizes.tobytes())\n    elif call_name == \"optimization_barrier\":\n        output_tuple = xc.ops.OptimizationBarrier(input_params)\n    else:\n        raise ValueError(\"Invalid call_name: {call_name}\")\n\n    c.clear_op_metadata()\n    c.clear_sharding()\n    return output_tuple\n\n\ndef _pipeline_impl(*args, **kwargs):\n    # pylint: disable=unused-argument\n    # The 
pipeline marker acts as an identity function.\n    return args\n\n\ndef _pipeline_abstract_eval(*args, **kwargs):\n    # pylint: disable=unused-argument\n    # The pipeline marker acts as an identity function.\n    return args\n\n\ndef _pipeline_xla_translation(c, *args, **kwargs):\n    name = kwargs[\"name\"] + \"$\" + kwargs[\"mark_type\"]\n    if kwargs[\"name\"] == \"hook\":\n        call_name = \"optimization_barrier\"\n    else:\n        call_name = \"pipeline_marker\"\n\n    return xla_custom_call(c, call_name, name, *args)\n\n\ndef _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type):\n    primal_outs = pipeline_p.bind(*arg_values, name=name, mark_type=mark_type)\n    # TODO(zhuohan): Check the semantics here works for higher order gradients.\n    if mark_type in (\"start\", \"jvp_start\"):\n        tangent_mark_type = \"jvp_start\"\n    elif mark_type in (\"end\", \"jvp_end\"):\n        tangent_mark_type = \"jvp_end\"\n    else:\n        raise ValueError(\"Invalid mark_type\")\n\n    marker_inputs = []\n    tan_marker_id = []\n    for val, tan in zip(arg_values, arg_tangents):\n        if isinstance(tan, ad.Zero):\n            tan_marker_id.append(-1)\n        else:\n            tan_marker_id.append(len(marker_inputs))\n            marker_inputs.append(tan)\n    res = pipeline_p.bind(*marker_inputs,\n                          name=name,\n                          mark_type=tangent_mark_type)\n    tangent_outs = []\n    for i, (val, tan) in enumerate(zip(arg_values, arg_tangents)):\n        if tan_marker_id[i] == -1:\n            tangent_outs.append(ad.Zero(val.aval))\n        else:\n            tangent_outs.append(res[tan_marker_id[i]])\n\n    return primal_outs, tangent_outs\n\n\ndef _pipeline_transpose(ct, *args, name, mark_type):\n    # TODO(zhuohan): Check the semantics here works for higher order gradients.\n    if mark_type in (\"start\", \"jvp_start\"):\n        transposed_mark_type = \"end\"\n    elif mark_type in (\"end\", 
\"jvp_end\"):\n        transposed_mark_type = \"start\"\n    else:\n        raise ValueError(\"Invalid mark_type\")\n    marker_inputs = []\n    ctan_marker_id = []\n    for val, ctan in zip(args, ct):\n        if isinstance(ctan, ad.Zero):\n            ctan_marker_id.append(-1)\n        else:\n            ctan_marker_id.append(len(marker_inputs))\n            marker_inputs.append(ctan)\n    res = pipeline_p.bind(*marker_inputs,\n                          name=name + \"_backward\",\n                          mark_type=transposed_mark_type)\n    new_ct = []\n    for i, (val, ctan) in enumerate(zip(args, ct)):\n        if ctan_marker_id[i] == -1:\n            new_ct.append(ad.Zero(val.aval))\n        else:\n            new_ct.append(res[ctan_marker_id[i]])\n    return new_ct\n\n\npipeline_p.def_impl(_pipeline_impl)\npipeline_p.def_abstract_eval(_pipeline_abstract_eval)\npipeline_p.multiple_results = True\nxla.translations[pipeline_p] = _pipeline_xla_translation\nad.primitive_jvps[pipeline_p] = _pipeline_value_and_jvp\nad.primitive_transposes[pipeline_p] = _pipeline_transpose\n"
  },
  {
    "path": "alpa/pipeline_parallel/resharding_tensor.py",
    "content": "\"\"\"Tensor classes and utilities used for cross-mesh resharding.\"\"\"\nfrom collections.abc import Iterable\nfrom dataclasses import dataclass\nfrom typing import List, Any\n\nimport numpy as np\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import Replicated, ShardingSpec\n\nfrom alpa.device_mesh import VirtualPhysicalMesh\n\n\ndef unflatten_tile_index(index, shape):\n    \"\"\"Unroll a flattened index based on the given shape.\"\"\"\n    unflattened_index = []\n    reminder = index\n    for i in range(len(shape) - 1):\n        cur_index = int(reminder / np.prod(shape[i + 1:]))\n        unflattened_index.append(cur_index)\n        reminder = reminder - cur_index * np.prod(shape[i + 1:])\n    unflattened_index.append(reminder)\n    return unflattened_index\n\n\nclass VirtualDistributedArray:\n    \"\"\"\n    Distributed Array without allocating remote buffers.\n\n    VirtualDistributedArray wrapper differs from DistributedArray in that:\n    (1) it does not allocate a remote buffer at construction;\n    (2) its device_mesh attribute is a virtual mesh (not physical).\n\n    Args:\n        device_mesh (VirtualPhysicalMesh): the virtual mesh this\n            VirtualDistributedArray locates on.\n        aval (aval): shape information about the array.\n        sharding_spec (ShardingSpec): sharding spec of this array.\n    \"\"\"\n\n    def __init__(self, *, device_mesh: VirtualPhysicalMesh, aval,\n                 sharding_spec: ShardingSpec):\n        self.device_mesh = device_mesh\n        self.aval = aval\n        self.sharding_spec = sharding_spec\n\n        self._indices = None\n        self._one_replica_buffer_indices = None\n        self._tile_assignments = None\n        self._tiles = None\n\n        self._sharding_spec_proto = self.sharding_spec.sharding_proto()\n\n    @property\n    def tensor_shape(self):\n        \"\"\"Return the shape of the original tensor.\"\"\"\n        return self.aval.shape\n\n    @property\n    def 
tensor_rank(self):\n        \"\"\"Return the rank of the original tensor.\"\"\"\n        return len(self.tensor_shape)\n\n    @property\n    def indices(self):\n        \"\"\"Return the indices of the sharded tensor.\"\"\"\n        if not self._indices:\n            self._indices = pxla.spec_to_indices(self.tensor_shape,\n                                                 self.sharding_spec)\n        return self._indices\n\n    @property\n    def tile_assignments(self):\n        \"\"\"Return the device assignment of each tile.\"\"\"\n        if self._tile_assignments is None:\n            if self.replicated:\n                mesh_flat = np.arange(self.device_mesh.num_devices)\n                self._tile_assignments = np.reshape(\n                    mesh_flat, self.tile_shape + [self.device_mesh.num_devices])\n            else:\n                # Generate tile assignments using proto\n                proto = self._sharding_spec_proto\n                shape = proto.tile_assignment_dimensions\n                devices_flat = proto.tile_assignment_devices\n                self._tile_assignments = np.reshape(devices_flat, shape)\n        return self._tile_assignments\n\n    @property\n    def replicated_maxes(self):\n        \"\"\"Return the list of mesh axes for replication.\"\"\"\n        replicated_maxes = []\n        for maxis, assignment in enumerate(self.sharding_spec.mesh_mapping):\n            if isinstance(assignment, Replicated):\n                replicated_maxes.append(maxis)\n        return replicated_maxes\n\n    @property\n    def num_replicas(self):\n        \"\"\"Number of replicas if replicated or partially tiled.\"\"\"\n        if self.tiled:\n            return 1\n        else:\n            num_replicas = 1\n            for _, assignment in enumerate(self.sharding_spec.mesh_mapping):\n                if isinstance(assignment, Replicated):\n                    num_replicas = num_replicas * assignment.replicas\n            return num_replicas\n\n    
@property\n    def tiled(self):\n        \"\"\"Whether this distributed array is fully tiled.\"\"\"\n        if not self.replicated_maxes:\n            return True\n        return False\n\n    @property\n    def replicated(self):\n        \"\"\"Whether this distributed array is fully replicated.\"\"\"\n        if len(self.replicated_maxes) == len(self.sharding_spec.mesh_mapping):\n            return True\n        return False\n\n    @property\n    def partial_tiled(self):\n        \"\"\"Whether this distributed array is mixed sharded and replicated.\"\"\"\n        if (self.replicated_maxes and len(self.replicated_maxes) < len(\n                self.sharding_spec.mesh_mapping)):\n            return True\n        return False\n\n    @property\n    def tile_shape(self):\n        \"\"\"\n        Return the shape of the tiles.\n\n        Each dim of the tile_shape is an integer representing how many tiles are\n        along this dim.\n        \"\"\"\n        if self.tiled:\n            return self.tile_assignments.shape\n        elif self.partial_tiled:\n            return self.tile_assignments.shape[:-1]\n        else:\n            # when fully replicated, the tile shape should be\n            # [1, ..., 1, num_devices], with rank = rank(array) + 1\n            return [1] * len(self.sharding_spec.sharding)\n\n    @property\n    def num_tiles(self):\n        \"\"\"Return the number of tiles of the VirtualDistributedArray.\"\"\"\n        return np.prod(self.tile_shape)\n\n    @property\n    def tiles(self):\n        \"\"\"Return all the shards of the VirtualDistributedArray following their\n        orders.\"\"\"\n        if self._tiles is None:\n            # Below are for tiled or partial_tiled.\n            num_tiles = np.prod(self.tile_shape)\n            # unique tiles (not counting those replicated)\n            self._tiles = np.empty(self.tile_shape, dtype=object)\n            for tile_index_flat in range(num_tiles):\n                # get its index\n               
 tile_index = unflatten_tile_index(tile_index_flat,\n                                                  self.tile_shape)\n                indices: List[Any] = [None] * len(self.tensor_shape)\n                for i, dim in enumerate(self.tensor_shape):\n                    tile_size, ragged = divmod(dim, self.tile_shape[i])\n                    assert not ragged\n                    indices[i] = slice(tile_size * tile_index[i],\n                                       tile_size * (tile_index[i] + 1))\n                device_ids = self.tile_assignments[tuple(tile_index)]\n                if not isinstance(device_ids, Iterable):\n                    device_ids = [device_ids]\n                else:\n                    device_ids = list(device_ids)\n                device_strs = [\n                    self.device_mesh.device_strs[d] for d in device_ids\n                ]\n                dst_tile = Tile(index=tile_index,\n                                index_flat=tile_index_flat,\n                                replica_device_ids=device_ids,\n                                replica_device_strs=device_strs,\n                                indices=indices)\n                self._tiles[tuple(tile_index)] = dst_tile\n        return self._tiles\n\n    @property\n    def device_str_to_flat_index(self):\n        \"\"\"Maps a device_str to its index in the flattened .indices object.\"\"\"\n        device_str_to_flat_index_map = {}\n        for i, device_str in enumerate(self.device_mesh.device_strs):\n            device_str_to_flat_index_map[device_str] = i\n        return device_str_to_flat_index_map\n\n\n@dataclass\nclass Tile:\n    \"\"\"\n    Representing a full tile (shard) on the original distributed array.\n\n    Args:\n        index (List[int]): the index of this shard in the tile_assignments\n            matrix of the VirtualDistributedArray.\n        index_flat (int): flattened index, row-major.\n        replica_device_ids (List[int]): the device ids this shard is 
replicated\n            on.\n        replica_device_strs (List[str]): the device strs this shard is\n            replicated on.\n        indices (List[slice]): a list of slices that expresses its indices in\n            the original array.\n    \"\"\"\n\n    index: List[int]\n    index_flat: int\n    replica_device_ids: List[int]\n    replica_device_strs: List[str]\n    indices: List[slice]\n\n    @property\n    def tile_size(self):\n        \"\"\"Return the size (number of elements) of the tile.\"\"\"\n        size = 1\n        for s in self.indices:\n            size = size * (s.stop - s.start)\n        return size\n\n    @property\n    def tile_shape(self):\n        \"\"\"Return the shape of the tile.\"\"\"\n        return [s.stop - s.start for s in self.indices]\n\n\n@dataclass\nclass TileSlice(Tile):\n    \"\"\"\n    Representing a slice of a tile of the array using an offset.\n\n    TileSlice subsets Tile, and Tile subsets VirtualDistributedArray.\n\n    Args:\n        offset (List[slice]): a list of slice objects to represent the offset\n            made on the shard.\n    \"\"\"\n\n    offset: List[slice]\n\n    def __init__(self, tile, offset):\n        super().__init__(tile.index, tile.index_flat, tile.replica_device_ids,\n                         tile.replica_device_strs, tile.indices)\n        self.offset = offset\n\n    @property\n    def slice_size(self):\n        \"\"\"Return the size (number of elements) of this tile slice.\"\"\"\n        size = 1\n        for o in self.offset:\n            size = size * (o.stop - o.start)\n        return size\n"
  },
  {
    "path": "alpa/pipeline_parallel/runtime_emitter.py",
    "content": "\"\"\"Compile pipeline stages to runtime pipeline instructions.\"\"\"\nfrom collections import namedtuple, defaultdict\nfrom dataclasses import dataclass\nimport enum\nimport logging\nfrom typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union, Set\n\nfrom jax.core import Var\nfrom jax.interpreters import pxla\nimport numpy as np\n\nfrom alpa.global_env import global_config\nfrom alpa.device_mesh import (DistributedArray, PhysicalDeviceMeshGroup,\n                              ReplicatedDistributedArray)\nfrom alpa.mesh_executable import next_mesh_executable_uuid\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.pipeline_parallel.computation import XlaShardedPipelineComputation\nfrom alpa.pipeline_parallel.cross_mesh_resharding import (\n    CrossMeshCommunicator, SymbolicBroadcastReshardingTask,\n    SymbolicReshardingTask, ReshardingTask)\nfrom alpa.pipeline_parallel.schedules import PipelineSchedule\nfrom alpa.pipeline_parallel.stage_construction import ManualStageOption\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.util import (DisjointDict, OrderedSet, get_shard_shape,\n                       get_microbatch_sharding_spec, compile_concatenate)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\nclass PipelineInstType(enum.IntEnum):\n    \"\"\"Enum class for pipeline instruction types.\"\"\"\n\n    # Run an XLA executable\n    RUN = 0\n    # Run a sending task\n    SEND = 1\n    # Run a receiving task\n    RECV = 2\n    # Free tensors\n    FREE = 3\n    # Run a broadcast task\n    BROADCAST = 4\n\n\n@dataclass\nclass PipelineInstruction:\n    \"\"\"Base class for pipeline instructions.\"\"\"\n\n    opcode: PipelineInstType\n    task_uuid: Optional[int]\n    input_uuids: Optional[np.ndarray]\n    output_uuids: Optional[np.ndarray]\n    opaques: Optional[Dict[str, Any]]\n    info: str\n    print_uuids: bool = False\n\n    @classmethod\n    def run(cls, task_uuid, 
input_uuids, output_uuids, kwargs, info=\"\"):  # noqa\n        return cls(opcode=PipelineInstType.RUN,\n                   task_uuid=task_uuid,\n                   input_uuids=input_uuids,\n                   output_uuids=output_uuids,\n                   opaques={\"kwargs\": kwargs},\n                   info=info)\n\n    @classmethod\n    def send(cls, task_uuid, input_uuids, info=\"\"):  # noqa\n        return cls(opcode=PipelineInstType.SEND,\n                   task_uuid=task_uuid,\n                   input_uuids=input_uuids,\n                   output_uuids=None,\n                   opaques=None,\n                   info=info)\n\n    @classmethod\n    def recv(\n            cls,  # noqa\n            task_uuid,\n            output_uuids,\n            set_empty_buffer,\n            allgather_uuid=None,\n            info=\"\"):  # noqa\n        return cls(opcode=PipelineInstType.RECV,\n                   task_uuid=task_uuid,\n                   input_uuids=None,\n                   output_uuids=output_uuids,\n                   opaques={\n                       \"set_empty_buffer\": set_empty_buffer,\n                       \"allgather_uuid\": allgather_uuid\n                   },\n                   info=info)\n\n    @classmethod\n    def broadcast(\n            cls,  # noqa\n            task_uuid,\n            input_uuids,\n            output_uuids,\n            info=\"broadcast\"):  # noqa\n        return cls(opcode=PipelineInstType.BROADCAST,\n                   task_uuid=task_uuid,\n                   input_uuids=input_uuids,\n                   output_uuids=output_uuids,\n                   opaques=None,\n                   info=info)\n\n    @classmethod\n    def free(cls, input_uuids, info=\"\"):  # noqa\n        return cls(opcode=PipelineInstType.FREE,\n                   task_uuid=None,\n                   input_uuids=input_uuids,\n                   output_uuids=None,\n                   opaques=None,\n                   info=info,\n                   
print_uuids=False)\n\n    def __str__(self):\n        ret = \"\"\n        ret += \"Opcode: \" + str(self.opcode)[17:] + \", Task uuid: \" + str(\n            self.task_uuid)\n        if self.print_uuids:\n            ret += \", input uuids:\" + str(self.input_uuids)\n            ret += \", output uuids:\" + str(self.output_uuids)\n        ret += \", Info: \" + self.info\n        return ret\n\n\nAllocateZeroWorkerExecutableConfig = namedtuple(\n    \"AllocateZeroWorkerExecutableConfig\",\n    [\"exec_uuid\", \"grad_shard_shapes\", \"grad_shard_dtypes\"])\nConcatWorkerExecutableConfig = namedtuple(\"ConcatWorkerExecutableConfig\",\n                                          [\"exec_uuid\", \"hlo\"])\nPartialGradWorkerExecutableConfig = namedtuple(\n    \"PartialGradWorkerExecutableConfig\",\n    [\"exec_uuid\", \"hlo\", \"stage_plan\", \"donated_invars\"])\n\nExecutableConfig = Union[AllocateZeroWorkerExecutableConfig,\n                         PartialGradWorkerExecutableConfig,\n                         ConcatWorkerExecutableConfig]\n\n\ndef flatten_uuid_set(container):\n    \"\"\"Convert a nested array to an OrderedSet of elements in the array.\"\"\"\n    output = OrderedSet()\n    for e in container:\n        if isinstance(e, (np.ndarray, list)):\n            output.update(flatten_uuid_set(e))\n        else:\n            output.add(e)\n    return output\n\n\nclass PipelineInstEmitterHelper:\n    \"\"\"Environment for PipelineInstEmitter.\"\"\"\n\n    def __init__(self, global_invar_set: Set[Var],\n                 global_batch_invar_set: Set[Var],\n                 grad_dummy_invars: Dict[Var, Var], schedule: PipelineSchedule):\n        self.global_invar_set = global_invar_set\n        self.global_batch_invar_set = global_batch_invar_set\n        self.grad_dummy_invars = grad_dummy_invars\n        self.schedule = schedule\n        # Dict[var_key -> Dict[mesh_idx -> array_uuid]]\n        # The shape of the numpy array is [num_hosts, num_devices_per_host]\n        
self.env = {}\n\n    def _get_var_key(self, var, batch_idx):\n        if (var in self.global_invar_set and\n                var not in self.global_batch_invar_set):\n            key = (var, 0)\n        elif (var in self.grad_dummy_invars and\n              batch_idx != self.schedule.first_backward_batch_index):\n            key = (self.grad_dummy_invars[var],\n                   self.schedule.previous_backward_batch_index(batch_idx))\n        else:\n            key = (var, batch_idx)\n        return key\n\n    def get_var_with_accumulate(self, var, batch_idx):\n        if (var in self.grad_dummy_invars and\n                batch_idx != self.schedule.first_backward_batch_index):\n            return self.grad_dummy_invars[var]\n        else:\n            return var\n\n    def get_var_mesh_uuid(self, var, batch_idx, mesh_idx) -> int:\n        key = self._get_var_key(var, batch_idx)\n        return self.env[key][mesh_idx]\n\n    def get_var_meshes(self, var, batch_idx) -> Dict[int, int]:\n        key = self._get_var_key(var, batch_idx)\n        return self.env.setdefault(key, {})\n\n    def set_var_mesh_uuid(self, var, batch_idx, mesh_idx, uuid):\n        key = self._get_var_key(var, batch_idx)\n        self.env.setdefault(key, {})[mesh_idx] = uuid\n\n    def var_at(self, var, batch_idx, mesh_idx) -> bool:\n        key = self._get_var_key(var, batch_idx)\n        return mesh_idx in self.env.setdefault(key, {})\n\n\n@dataclass\nclass PipeshardInputConfig:\n    \"\"\"Configurations of the inputs for a Pipeshard executable.\"\"\"\n    # The local input uuids\n    # List[mesh_idx -> List[arg_uuid]]\n    input_local_uuid_lists: Sequence[Sequence[int]]\n    # Whether the var should be donated\n    # List[mesh_idx -> List[bool]]\n    donate_invars: Sequence[Sequence[bool]]\n    # List[mesh_idx -> List[arg_idx]]\n    mesh_arg_indices: Sequence[Sequence[int]]\n    # Cached sharding indices for input arguments\n    # List[mesh_idx -> List[sharding_indices]].\n    
input_shard_indices: Sequence[Sequence[Any]]\n    # Whether the argument should be deleted after shard\n    # List[mesh_idx -> List[bool]]\n    delete_after_shard: Sequence[Sequence[bool]]\n    # Whether the argument is a batch argument\n    # List[mesh_idx -> List[bool]]\n    batch_invars: Sequence[Sequence[bool]]\n\n\n# TODO(yonghao): use worker_idx as the dict's key\n@dataclass\nclass PipeshardConfig:\n    \"\"\"Configurations of a Pipeshard executable.\"\"\"\n    # Executable configs\n    instruction_lists: Dict[Any, Sequence[PipelineInstruction]]\n    xla_stages: Sequence[XlaShardedPipelineComputation]\n    # FIXME(yonghao): share this setting within a mesh\n    executable_configs: Dict[Any, Sequence[ExecutableConfig]]\n    executable_uuids: Sequence[int]\n    schedule: PipelineSchedule\n    # Resharding task configs\n    device_str_groups: Sequence[Sequence[OrderedSet]]\n    allreduce_groups: Tuple[Sequence[int], Var]\n    resharding_tasks: Sequence[ReshardingTask]\n    # Input configs\n    input_config: PipeshardInputConfig\n    grad_uuids: Sequence[np.ndarray]\n    reduced_var_uuid_lists: Sequence[np.ndarray]\n    # Output configs\n    output_local_uuid_list: Sequence[Sequence[int]]\n    outs_handler: Callable\n    # Others (debug info)\n    stage_input_shard_specs: Sequence[Sequence[pxla.ShardingSpec]]\n    input_placement_specs: Sequence[PlacementSpec]\n    output_placement_specs: Sequence[PlacementSpec]\n    default_auto_sharding_option: AutoShardingOption\n    manual_stage_option: ManualStageOption\n    sharding_annotated_hlo_texts: Sequence[str]\n    flop_count: int\n\n\nclass PipelineInstEmitter:\n    \"\"\"Pipeline Instruction Emitter.\"\"\"\n\n    def __init__(self, *, stages: Sequence[XlaShardedPipelineComputation],\n                 global_invars: Sequence[Var], grad_dummy_invars: Dict[Var,\n                                                                       Var],\n                 global_outvars: Sequence[Var], concat_vars_mapping: Dict[Var,\n 
                                                                         Var],\n                 mesh_group: PhysicalDeviceMeshGroup,\n                 schedule: PipelineSchedule, is_batch: Sequence[bool],\n                 num_batch: int,\n                 default_auto_sharding_option: AutoShardingOption,\n                 manual_stage_option: ManualStageOption, flop_count: int,\n                 allreduce_groups: Tuple[Sequence[int], Var]):\n        ##### Input arguments #####\n        self.stages = stages\n        self.global_invars = global_invars\n        self.grad_dummy_invars = grad_dummy_invars\n        self.concat_vars_mapping = concat_vars_mapping\n        self.global_outvars = global_outvars\n        self.mesh_group = mesh_group\n        self.num_mesh = len(mesh_group)\n        self.schedule = schedule\n        self.is_batch = is_batch\n        self.num_batch = num_batch\n        self.default_auto_sharding_option = default_auto_sharding_option\n        self.manual_stage_option = manual_stage_option\n        self.flop_count = flop_count\n        self.sharding_annotated_hlo_texts = [x.get_hlo_text() for x in stages]\n        self.allreduce_groups = allreduce_groups\n\n        ##### Internal states #####\n        self.uuid_counter = 0  # counter for local buffer uuid\n        global_invar_set = OrderedSet(global_invars)\n        global_batch_invar_set = OrderedSet(\n            v for v, b in zip(global_invars, is_batch) if b)\n        self.env = PipelineInstEmitterHelper(global_invar_set,\n                                             global_batch_invar_set,\n                                             grad_dummy_invars, schedule)\n        self._communicator = None\n        self._resharding_tasks = [\n            [{} for _ in range(self.num_mesh)] for _ in range(self.num_mesh)\n        ]\n\n    def _get_next_uuids(self, num) -> np.ndarray:\n        \"\"\"Get the next uuids as a numpy array of uuids.\"\"\"\n        ret = np.arange(start=self.uuid_counter,\n  
                      stop=self.uuid_counter + num,\n                        dtype=np.int64)\n        self.uuid_counter += num\n        return ret\n\n    def _compile_sharding_specs(self):\n        \"\"\"Run spmd partitioner pass for each stage to get sharding specs.\"\"\"\n        for stage_idx, stage in enumerate(self.stages):\n            mesh_indices = list(self.schedule.stage_placement(stage_idx))\n            assert len(mesh_indices) == 1\n            stage.get_spmd_partitioned()\n\n    def _compile_resharding_tasks(self):\n        \"\"\"Create and compile all resharding (send/recv/allgather) tasks.\"\"\"\n        for (src_mesh_idx, dst_mesh_idx,\n             var_spec_map) in self._communicator.task_spec_iter():\n            for var, spec in var_spec_map.items():\n                cg = self.mesh_group.collective_groups[src_mesh_idx][\n                    dst_mesh_idx]\n                src_mesh = self.mesh_group[src_mesh_idx]\n                dst_mesh = self.mesh_group[dst_mesh_idx]\n                # TODO(yonghao): delay put_resharding_XXXX_task until pipeshard\n                #  executable\n                if global_config.resharding_mode == \"send_recv\":\n                    self._resharding_tasks[src_mesh_idx][dst_mesh_idx][\n                        var] = SymbolicReshardingTask(spec, cg, src_mesh,\n                                                      dst_mesh)\n                else:\n                    self._resharding_tasks[src_mesh_idx][dst_mesh_idx][\n                        var] = SymbolicBroadcastReshardingTask(\n                            spec, cg, src_mesh, dst_mesh)\n\n    def _gather_resharding_tasks(self):\n        \"\"\"Gather all resharding tasks into a list.\"\"\"\n        tasks = []\n        for src_idx in range(self.num_mesh):\n            for dst_idx in range(self.num_mesh):\n                tasks.extend(self._resharding_tasks[src_idx][dst_idx].values())\n        return tasks\n\n    def _establish_nccl_groups(self):\n        \"\"\"\n  
      Identify NCCL groups based on resharding specs but do not instantiate\n        them.\n\n        We establish one collective group between two physical meshes, covering\n        all the devices in these two meshes that require NCCL communication.\n\n        Returns:\n            device_str_groups (List[List[set]]): a num_mesh x num_mesh matrix.\n                Only entries at device_str_groups[i][j] (i < j) are filled,\n                entries with i > j are None, because (spec[i][j], spec[j][i])\n                will share collective groups.\n        \"\"\"\n        self._communicator = CrossMeshCommunicator(self.stages, self.schedule)\n        device_str_groups = [[OrderedSet()\n                              for _ in range(self.num_mesh)]\n                             for _ in range(self.num_mesh)]\n        # Merge (i, j) and (j, i)\n        for i, j, var_spec_map in self._communicator.task_spec_iter():\n            participants = OrderedSet()\n            for _, spec in var_spec_map.items():  # for each var\n                participants = participants | spec.get_participant_device_strs()\n            if i <= j:\n                device_str_groups[i][j] = device_str_groups[i][j] | participants\n            else:\n                device_str_groups[j][i] = device_str_groups[j][i] | participants\n\n        # construct groups\n        for i in range(self.num_mesh):\n            for j in range(self.num_mesh):\n                if i >= j:\n                    assert not device_str_groups[i][j]\n                    continue\n                if not device_str_groups[i][j]:\n                    continue\n                self.mesh_group.establish_nccl_group(i, j, instantiate=False)\n        return device_str_groups\n\n    def compile(self):\n        \"\"\"Compile pipeline instructions and executables for workers.\"\"\"\n        num_mesh = len(self.mesh_group)\n\n        # Compile resharding tasks\n        self._compile_sharding_specs()\n        device_str_groups = 
self._establish_nccl_groups()\n        self._compile_resharding_tasks()\n\n        # Compile forward, backward and apply_grad computations\n        (executable_uuids,\n         executable_config_lists) = self._compile_computation_executables()\n\n        # Compile gradient buffer allocations\n        grad_uuids, instruction_lists = self._compile_grad_buffer_allocations(\n            executable_config_lists)\n\n        # Split input into micro batches\n        (input_config,\n         input_shard_specs) = self._compile_split_input_to_microbatches()\n\n        # Simulate the pipeline schedule and generate instructions\n        donation_mapping = [DisjointDict() for _ in range(num_mesh)]\n        worker_to_idx = {}\n        for mesh_idx, mesh in enumerate(self.mesh_group):\n            for worker_idx, worker in enumerate(mesh.workers):\n                worker_to_idx[worker] = (mesh_idx, worker_idx)\n\n        for _, sched in enumerate(self.schedule.schedules):\n            self._compile_exec_one_tick(sched, donation_mapping,\n                                        instruction_lists, executable_uuids,\n                                        executable_config_lists)\n\n        # Compile concate\n        self._compile_concate(instruction_lists, executable_config_lists)\n\n        # Compile information for outputs\n        output_local_uuid_list, mesh_output_indices, output_spec_list = (\n            self._compile_collect_outputs())\n        outs_handler, output_placement_specs = self._get_outs_handler(\n            mesh_output_indices, output_spec_list)\n\n        # Add gradient accumulation buffer\n        reduced_var_uuid_lists = []\n        for mesh_idx in range(num_mesh):\n            reduced_var_uuids = grad_uuids[mesh_idx]\n            reduced_var_uuids = np.array([\n                donation_mapping[mesh_idx].recursive_lookup(uuid)\n                for uuid in reduced_var_uuids\n            ])\n            reduced_var_uuid_lists.append(reduced_var_uuids)\n        
# Insert buffer free instructions\n        for worker in instruction_lists:\n            mesh_idx, worker_idx = worker_to_idx[worker]\n            used_outside = flatten_uuid_set(output_local_uuid_list[mesh_idx])\n\n            donated = set(donation_mapping[mesh_idx].keys())\n            used_outside.update(flatten_uuid_set(reduced_var_uuids))\n            instruction_lists[worker] = self._compile_free(\n                worker, used_outside, donated, instruction_lists)\n\n        # Compile load info\n        input_placement_specs = self._compile_input_placement_spec(\n            input_config.mesh_arg_indices, input_shard_specs)\n\n        # Keep the input sharding specs based on pipeline stages\n        input_shard_specs = [\n            self.stages[idx].input_sharding_specs\n            for idx in self.schedule.mesh_stage_mapping\n        ]\n\n        return PipeshardConfig(\n            # Executable configs\n            instruction_lists,\n            self.stages,\n            executable_config_lists,\n            executable_uuids,\n            self.schedule,\n            # Resharding task configs\n            device_str_groups,\n            self.allreduce_groups,\n            self._gather_resharding_tasks(),\n            # Input configs\n            input_config,\n            grad_uuids,\n            reduced_var_uuid_lists,\n            # Output configs\n            output_local_uuid_list,\n            outs_handler,\n            # Others\n            input_shard_specs,\n            input_placement_specs,\n            output_placement_specs,\n            self.default_auto_sharding_option,\n            self.manual_stage_option,\n            self.sharding_annotated_hlo_texts,\n            self.flop_count)\n\n    def _compile_get_vars_from_mesh(self, invars, dst_specs, mesh_idx,\n                                    batch_idx, comm_lists, alloc_lists,\n                                    executable_config_lists):\n        if len(invars) == 0:\n            return\n  
      # TODO(yonghao): only compile alloc once, use multiple times\n        recv_uuid_list = self._compile_alloc(invars, dst_specs, mesh_idx,\n                                             batch_idx, alloc_lists,\n                                             executable_config_lists, \"recv\")\n\n        for invar, recv_uuid in zip(invars, recv_uuid_list):\n            var_key = self.env.get_var_with_accumulate(invar, batch_idx)\n            src_idx, src_uuid = list(\n                self.env.get_var_meshes(invar, batch_idx).items())[0]\n            resharding_task = self._resharding_tasks[src_idx][mesh_idx][var_key]\n            if global_config.resharding_mode == \"send_recv\":\n                self._compile_resharding_task(src_uuid, resharding_task,\n                                              recv_uuid, comm_lists)\n            else:\n                self._compile_broadcast_resharding_task(\n                    self.mesh_group[src_idx], src_uuid, resharding_task,\n                    recv_uuid, comm_lists)\n\n    def _compile_exec_one_mesh(self, mesh_idx, task, executable_uuids,\n                               donation_mapping, worker_tmp_instructions):\n        batch_idx, stage_idx = task\n        physical_mesh = self.mesh_group[mesh_idx]\n        stage = self.stages[stage_idx]\n        for outvar in stage.outvars:\n            # get uuids of this outvar\n            output_uuid = self._get_next_uuids(1)[0]\n            self.env.set_var_mesh_uuid(outvar, batch_idx, mesh_idx, output_uuid)\n\n        exec_uuid = executable_uuids[stage_idx]\n        donated_invars = self.stages[stage_idx].donated_invars\n\n        input_uuids = np.zeros((len(stage.invars),), dtype=np.int64)\n        output_uuids = np.zeros((len(stage.outvars),), dtype=np.int64)\n        for idx, invar in enumerate(stage.invars):\n            input_uuids[idx] = self.env.get_var_mesh_uuid(\n                invar, batch_idx, mesh_idx)\n        for idx, outvar in enumerate(stage.outvars):\n           
 output_uuids[idx] = self.env.get_var_mesh_uuid(\n                outvar, batch_idx, mesh_idx)\n        for idx in range(len(stage.invars)):\n            if donated_invars[idx]:\n                donation_mapping[mesh_idx].update(input_uuids[idx],\n                                                  output_uuids[idx])\n\n        for worker in physical_mesh.workers:\n            kwargs = {\n                \"skip_grad_sync\": self.schedule.should_skip_grad_sync(task),\n                \"sync_before\": False,\n                \"sync_after\": False,\n            }\n\n            worker_tmp_instructions[worker].append(\n                PipelineInstruction.run(exec_uuid,\n                                        input_uuids,\n                                        output_uuids,\n                                        kwargs,\n                                        info=f\"stage {stage_idx}\"))\n\n    def _compile_exec_one_tick(self, sched, donation_mapping, instruction_lists,\n                               executable_uuids, executable_config_lists):\n        worker_tmp_instructions = {}\n        for mesh in self.mesh_group:\n            for worker in mesh.workers:\n                worker_tmp_instructions[worker] = []\n\n        for mesh_idx, task in enumerate(sched):\n            if not task:\n                continue\n            batch_idx, stage_idx = task\n            stage = self.stages[stage_idx]\n            # shard_args for intermediates\n            to_reshard_vars = []\n            reshard_sharding_specs = []\n            for invar, spec in zip(stage.invars, stage.input_sharding_specs):\n                if self.env.var_at(invar, batch_idx, mesh_idx):\n                    # have a copy at the current mesh\n                    continue\n                # TODO(yonghao): to avoid congestion, maybe sending from the\n                # last one (a.k.a. 
the latest one receiving it) is better, but\n                # we have to create the corresponding cross-mesh communication\n                # task.\n                # if len(self.env.get_var_meshes(invar, batch_idx)) > 1:\n                #     raise NotImplementedError(\n                #         \"Not support resharding replicated\")\n                var_key = self.env.get_var_with_accumulate(invar, batch_idx)\n                src_idx = list(\n                    self.env.get_var_meshes(invar, batch_idx).keys())[0]\n                resharding = self._resharding_tasks[src_idx][mesh_idx][var_key]\n                if resharding.is_local_allgather_task:\n                    spec = resharding.task_spec.dst_sharding_spec\n                to_reshard_vars.append(invar)\n                reshard_sharding_specs.append(spec)\n            self._compile_get_vars_from_mesh(to_reshard_vars,\n                                             reshard_sharding_specs, mesh_idx,\n                                             batch_idx, instruction_lists,\n                                             instruction_lists,\n                                             executable_config_lists)\n\n            # execute\n            self._compile_exec_one_mesh(mesh_idx, task, executable_uuids,\n                                        donation_mapping,\n                                        worker_tmp_instructions)\n\n        for worker, worker_instruction in worker_tmp_instructions.items():\n            instruction_lists[worker].extend(worker_instruction)\n\n    def _compile_computation_executables(self):\n        \"\"\"Compile executables for forward, backward, and apply_grad\n        computations.\"\"\"\n        executable_uuids = []  # List[stage_idx -> executable_uuids]\n        executable_config_lists = defaultdict(\n            list)  # Dict[worker -> List[ExecutableConfig]]\n\n        for stage_idx, stage in enumerate(self.stages):\n            exec_uuid = next_mesh_executable_uuid()\n    
        executable_uuids.append(exec_uuid)\n\n            mesh_idx = self.schedule.stage_placement(stage_idx)\n            assert len(mesh_idx) == 1\n            mesh_idx = list(mesh_idx)[0]\n            hlo = stage.get_spmd_partitioned()\n            exec_config = PartialGradWorkerExecutableConfig(\n                exec_uuid, hlo, stage.stage_plan, stage.donated_invars)\n\n            for worker in self.mesh_group[mesh_idx].workers:\n                executable_config_lists[worker].append(exec_config)\n\n        return executable_uuids, executable_config_lists\n\n    def _compile_grad_buffer_allocations(self, executable_config_lists):\n        \"\"\"Compile gradient buffer allocations.\"\"\"\n        num_mesh = len(self.mesh_group)\n        mesh_grad_vars = [{} for _ in range(num_mesh)]\n        instruction_lists = defaultdict(\n            list)  # Dict[worker -> List[PipelineInstruction]]\n        # collect gradient accumulation buffers in each mesh\n        for stage_idx, stage in enumerate(self.stages):\n            mesh_indices = list(self.schedule.stage_placement(stage_idx))\n            assert len(mesh_indices) == 1\n            mesh_idx = mesh_indices[0]\n            grad_var_spec_dict = mesh_grad_vars[mesh_idx]\n            input_specs = stage.input_sharding_specs\n            for var_idx, invar in enumerate(stage.invars):\n                if invar in self.grad_dummy_invars:\n                    if invar in grad_var_spec_dict:\n                        raise NotImplementedError(\n                            f\"accumulate {invar} at multiple stages in a mesh\")\n                    grad_var_spec_dict[invar] = input_specs[var_idx]\n\n        grad_uuids = [[] for _ in range(num_mesh)]\n        for mesh_idx in range(num_mesh):\n            grad_var_spec_dict = mesh_grad_vars[mesh_idx]\n            if len(grad_var_spec_dict):\n                grad_vars, grad_sharding_specs = list(\n                    zip(*grad_var_spec_dict.items()))\n\n                # 
TODO(yonghao): Some var has non-gradient intermediate states\n                # that need accumulation. for these vars, we need to record its\n                # first mb index when accum will take place.\n                grad_uuids[mesh_idx] = self._compile_alloc(\n                    grad_vars, grad_sharding_specs, mesh_idx,\n                    self.schedule.first_backward_batch_index, instruction_lists,\n                    executable_config_lists, \"grad acc\")\n\n        return grad_uuids, instruction_lists\n\n    def _compile_collect_mesh_input(self, mesh_idx):\n        mesh_arg_set = OrderedSet()\n        var_to_spec = {}\n        mesh_batch_vars = OrderedSet()\n        num_batch = self.num_batch\n        mesh_arg_indices = []\n        input_shard_indices = []\n        input_shard_specs = []\n        mesh_invar_is_batch = []\n        for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:\n            stage = self.stages[stage_idx]\n            for spec, invar in zip(stage.input_sharding_specs, stage.invars):\n                if invar in self.env.global_invar_set:\n                    var_to_spec[invar] = spec\n                    if invar in self.env.global_batch_invar_set:\n                        # Split batch arg\n                        for batch_idx in range(num_batch):\n                            mesh_arg_set.add((invar, batch_idx))\n                        mesh_batch_vars.add(invar)\n                    else:\n                        mesh_arg_set.add((invar, 0))\n        mesh_arg_list = list(mesh_arg_set)\n\n        for info in mesh_arg_list:\n            var, batch_idx = info\n            if batch_idx != 0:\n                continue\n\n            global_idx = self.global_invars.index(var)\n            mesh_arg_indices.append(global_idx)\n            mesh_invar_is_batch.append(self.is_batch[global_idx])\n\n            if self.is_batch[global_idx]:\n                aval = var.aval\n                batch_dim = 0\n                new_shape = 
(num_batch * aval.shape[0],) + aval.shape[1:]\n                new_spec = get_microbatch_sharding_spec(var_to_spec[var],\n                                                        batch_dim, num_batch)\n                input_shard_indices.append(\n                    pxla.spec_to_indices(new_shape, new_spec))\n                input_shard_specs.append(var_to_spec[var])\n            else:\n                input_shard_indices.append(\n                    pxla.spec_to_indices(var.aval.shape, var_to_spec[var]))\n                input_shard_specs.append(var_to_spec[var])\n        return (mesh_arg_list, mesh_arg_indices, input_shard_indices,\n                input_shard_specs, mesh_invar_is_batch)\n\n    def _compile_split_input_to_microbatches(self):\n        \"\"\"\n        Split batch arguments into micro batches.\n\n        The split is like:\n        before: a, b, c, d\n        after (b, d are batch args and #mb=2): a, b0, b1, c, d0, d1\n        \"\"\"\n        donated_invar_set = OrderedSet()\n        for stage in self.stages:\n            for invar, donate in zip(stage.invars, stage.donated_invars):\n                if donate and invar in self.env.global_invar_set:\n                    donated_invar_set.add(invar)\n        num_mesh = len(self.mesh_group)\n        mesh_arg_lists = [None for _ in range(num_mesh)]\n\n        # Dispatch args to each mesh\n        arg_last_use = {}\n        donate_invars = []\n        mesh_arg_indices = []\n        input_shard_indices = []\n        input_shard_specs = []\n        batch_invars = []\n        for mesh_idx in range(num_mesh):\n            (mesh_arg_list, arg_indices, shard_indices, shard_specs,\n             is_batch) = self._compile_collect_mesh_input(mesh_idx)\n\n            mesh_arg_lists[mesh_idx] = mesh_arg_list\n            delete_after_run = [\n                var in donated_invar_set or\n                (var in self.env.global_batch_invar_set and\n                 global_config.always_donate_micro_batch_vars)\n        
        for var, _ in mesh_arg_list\n            ]\n            donate_invars.append(delete_after_run)\n            for info in mesh_arg_list:\n                var, batch_idx = info\n                if batch_idx != 0:\n                    continue\n                arg_last_use[var] = mesh_idx\n\n            mesh_arg_indices.append(arg_indices)\n            input_shard_indices.append(shard_indices)\n            input_shard_specs.append(shard_specs)\n            batch_invars.append(is_batch)\n\n        delete_after_shard = []\n        for mesh_idx in range(num_mesh):\n            delete_after_shard.append([\n                self.global_invars[idx] in donated_invar_set and\n                arg_last_use[self.global_invars[idx]] == mesh_idx\n                for idx in mesh_arg_indices[mesh_idx]\n            ])\n\n        # Get local uuids for each input\n        input_local_uuid_lists = [[] for _ in range(num_mesh)]\n        for mesh_idx in range(num_mesh):\n            mesh_arg_list = mesh_arg_lists[mesh_idx]\n            num_args = len(mesh_arg_list)\n            # shape: (num_args, num_hosts, num_devices_per_host)\n            if num_args > 0:\n                arg_uuids = self._get_next_uuids(num_args)\n                for arg_idx, info in enumerate(mesh_arg_lists[mesh_idx]):\n                    var, batch_idx = info\n                    self.env.set_var_mesh_uuid(var, batch_idx, mesh_idx,\n                                               arg_uuids[arg_idx])\n                    input_local_uuid_lists[mesh_idx].append(arg_uuids[arg_idx])\n        input_config = PipeshardInputConfig(\n            input_local_uuid_lists=input_local_uuid_lists,\n            donate_invars=donate_invars,\n            mesh_arg_indices=mesh_arg_indices,\n            input_shard_indices=input_shard_indices,\n            delete_after_shard=delete_after_shard,\n            batch_invars=batch_invars)\n        return input_config, input_shard_specs\n\n    def _compile_concate_get_spec(self, 
to_concate_vars):\n        var_to_spec_all_meshes = []\n        output_at = defaultdict(OrderedSet)\n        num_mesh = len(self.mesh_group)\n        for mesh_idx in range(num_mesh):\n            var_to_spec = {}\n            for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:\n                stage = self.stages[stage_idx]\n                for spec, outvar in zip(stage.output_sharding_specs,\n                                        stage.outvars):\n                    if outvar in to_concate_vars:\n                        var_to_spec[outvar] = spec\n                        output_at[outvar].add(mesh_idx)\n            var_to_spec_all_meshes.append(var_to_spec)\n        return var_to_spec_all_meshes, output_at\n\n    def _compile_concate(self, instruction_lists, executable_config_lists):\n        \"\"\"\n        Generate concate instruction for variables used in non-microbatch part,\n        but are not reduced. They should be concated.\n        \"\"\"\n        batch_dim = 0\n        to_concate_vars = set(self.concat_vars_mapping.values())\n        to_concate_specs, output_at = self._compile_concate_get_spec(\n            to_concate_vars)\n        for var in self.concat_vars_mapping:\n            src_var = self.concat_vars_mapping[var]\n            dst_mesh_to_uuids = self.env.get_var_meshes(\n                var, self.schedule.last_backward_batch_index)\n            for mesh_idx in output_at[src_var]:\n                physical_mesh = self.mesh_group[mesh_idx]\n                # Get input and output uuids\n                input_args = np.zeros((self.num_batch,), dtype=np.int64)\n                for batch_idx in range(self.num_batch):\n                    input_args[batch_idx] = self.env.get_var_mesh_uuid(\n                        src_var, batch_idx, mesh_idx)\n                output_uuid = self._get_next_uuids(1)\n                dst_mesh_to_uuids[mesh_idx] = output_uuid[0]\n\n                # create and run concat executable\n                exec_uuid = 
next_mesh_executable_uuid()\n                spec = to_concate_specs[mesh_idx][src_var]\n                hlo = compile_concatenate(physical_mesh.shape, spec,\n                                          self.num_batch, batch_dim,\n                                          src_var.aval)\n                exec_config = ConcatWorkerExecutableConfig(exec_uuid, hlo)\n                kwargs = {\n                    \"sync_before\": False,\n                    \"sync_after\": False,\n                }\n                for worker in physical_mesh.workers:\n                    executable_config_lists[worker].append(exec_config)\n                    instruction_lists[worker].append(\n                        PipelineInstruction.run(exec_uuid, input_args,\n                                                output_uuid, kwargs))\n\n    def _compile_collect_outputs(self):\n        \"\"\"\n        Generate output information.\n\n        This function dispatches output information, including local uuid, local\n        indices to global indices, and output specs to each mesh.\n        \"\"\"\n        # List[mesh_idx -> List[uuid]]\n        output_local_uuid_list = [[] for _ in range(self.num_mesh)]\n        # List[arg_idx -> Dict[mesh_idx -> int]]\n        mesh_output_indices = []\n        # List[mesh_idx -> List[arg_idx -> sharding_spec]]\n        output_spec_list = [[] for _ in range(self.num_mesh)]\n\n        # collect outvar specs\n        var_to_spec_all_meshes = []\n        global_outvar_set = OrderedSet(self.global_outvars)\n        # This is only a patch. 
It will be deprecated after we move concat into\n        # a stage\n        reversed_concat = {\n            v: k\n            for k, v in self.concat_vars_mapping.items()\n            if k in global_outvar_set\n        }\n        output_at = defaultdict(OrderedSet)\n        for mesh_idx in range(self.num_mesh):\n            var_to_spec = {}\n            for stage_idx in self.schedule.mesh_stage_mapping[mesh_idx]:\n                stage = self.stages[stage_idx]\n                for spec, outvar in zip(stage.output_sharding_specs,\n                                        stage.outvars):\n                    if outvar in global_outvar_set:\n                        var_to_spec[outvar] = spec\n                        output_at[outvar].add(mesh_idx)\n                    if outvar in reversed_concat:\n                        concat_outvar = reversed_concat[outvar]\n                        var_to_spec[concat_outvar] = spec\n                        output_at[concat_outvar].add(mesh_idx)\n            var_to_spec_all_meshes.append(var_to_spec)\n        # assign indices and get specs\n        for outvar in self.global_outvars:\n            # the apply gradient only writes to microbatch 0\n            mesh_to_uuid = self.env.get_var_meshes(\n                outvar, self.schedule.last_backward_batch_index)\n            mesh_out_indices = {}\n            for mesh_idx in output_at[outvar]:\n                output_local_uuid_list[mesh_idx].append(mesh_to_uuid[mesh_idx])\n                mesh_out_indices[mesh_idx] = (\n                    len(output_local_uuid_list[mesh_idx]) - 1)\n                output_spec_list[mesh_idx].append(\n                    var_to_spec_all_meshes[mesh_idx][outvar])\n            mesh_output_indices.append(mesh_out_indices)\n\n        return output_local_uuid_list, mesh_output_indices, output_spec_list\n\n    def _compile_alloc(self, variables, sharding_specs, mesh_idx, batch_idx,\n                       instruction_lists, executable_config_lists, 
debug):\n        \"\"\"Compile an executable which allocates zero buffers.\n\n        The zero buffers are:\n        1) gradient accumulation buffers\n        2) temp buffers for receiving tensors\n        \"\"\"\n        config_class = AllocateZeroWorkerExecutableConfig\n        avals = [var.aval for var in variables]\n        sharded_shapes = [\n            get_shard_shape(aval, spec)\n            for aval, spec in zip(avals, sharding_specs)\n        ]\n        dtypes = [aval.dtype for aval in avals]\n        exec_uuid = next_mesh_executable_uuid()\n        config = config_class(exec_uuid, sharded_shapes, dtypes)\n\n        physical_mesh = self.mesh_group[mesh_idx]\n        output_uuids = self._get_next_uuids(len(variables))\n        for worker in physical_mesh.workers:\n            executable_config_lists[worker].append(config)\n            in_uuids = []\n            out_uuids = output_uuids\n            instruction_lists[worker].append(\n                PipelineInstruction.run(config.exec_uuid,\n                                        in_uuids,\n                                        out_uuids, {\n                                            \"sync_before\": False,\n                                            \"sync_after\": False\n                                        },\n                                        info=\"allocate zero for \" + debug))\n\n        # shape: (#args, num_hosts, num_devices_per_host)\n        for var_idx, var in enumerate(variables):\n            self.env.set_var_mesh_uuid(var, batch_idx, mesh_idx,\n                                       output_uuids[var_idx])\n        return output_uuids\n\n    def _get_outs_handler(self, mesh_output_indices, output_spec_list):\n        \"\"\"\n        Setup outs handlers that assemble RemoteBufs into DistributedArrays.\n        \"\"\"\n        outvar_idx_to_mesh_idx = {}  # Dict[var_idx -> List[mesh_idx]]\n        for i, _ in enumerate(self.global_outvars):\n            outvar_idx_to_mesh_idx[i] = 
list(mesh_output_indices[i].keys())\n\n        avals = [outvar.aval for outvar in self.global_outvars]\n        is_replicated = [\n            bool(len(outvar_idx_to_mesh_idx[i]) > 1)\n            for i, _ in enumerate(self.global_outvars)\n        ]\n\n        mesh_idx_list = []\n        outvar_index_on_mesh_list = []\n        spec_list = []\n        indices_list = []\n        output_placement_specs = []\n\n        # Generate cached info\n        for i, aval in enumerate(avals):\n            if not is_replicated[i]:\n                # for DistributedArray\n                mesh_idx = outvar_idx_to_mesh_idx[i][0]\n                outvar_index_on_mesh = mesh_output_indices[i][mesh_idx]\n                spec = output_spec_list[mesh_idx][outvar_index_on_mesh]\n                mesh_idx_list.append(mesh_idx)\n                outvar_index_on_mesh_list.append(outvar_index_on_mesh)\n                spec_list.append(spec)\n                indices_list.append(pxla.spec_to_indices(aval.shape, spec))\n\n                output_placement_specs.append(\n                    PlacementSpec(aval, (mesh_idx_list[-1],), (spec_list[-1],)))\n            else:\n                # for ReplicatedDistributedArray\n                mesh_idx_list.append([])\n                outvar_index_on_mesh_list.append([])\n                spec_list.append([])\n                indices_list.append([])\n\n                for mesh_idx in outvar_idx_to_mesh_idx[i]:\n                    outvar_index_on_mesh = mesh_output_indices[i][mesh_idx]\n                    spec = output_spec_list[mesh_idx][outvar_index_on_mesh]\n\n                    mesh_idx_list[-1].append(mesh_idx)\n                    outvar_index_on_mesh_list[-1].append(outvar_index_on_mesh)\n                    spec_list[-1].append(spec)\n                    indices_list[-1].append(\n                        pxla.spec_to_indices(aval.shape, spec))\n                output_placement_specs.append(\n                    PlacementSpec(aval, 
tuple(mesh_idx_list[-1]),\n                                  tuple(spec_list[-1])))\n\n        def outs_handler(mesh_group, refs):\n            ret = []\n            for i, aval in enumerate(avals):\n                if not is_replicated[i]:\n                    # construct DistributedArray\n                    mesh_idx = mesh_idx_list[i]\n                    device_mesh = mesh_group[mesh_idx]\n                    arr = DistributedArray(\n                        device_mesh=device_mesh,\n                        aval=aval,\n                        sharding_spec=spec_list[i],\n                        remote_ref=refs[mesh_idx][outvar_index_on_mesh_list[i]],\n                        indices=indices_list[i])\n                else:\n                    # construct ReplicatedDistributedArray\n                    meshes = []\n                    distributed_arrays = []\n                    for j, mesh_idx in enumerate(mesh_idx_list[i]):\n                        outvar_index_on_mesh = outvar_index_on_mesh_list[i][j]\n                        spec = spec_list[i][j]\n                        meshes.append(mesh_group[mesh_idx])\n                        distributed_arrays.append(\n                            DistributedArray(\n                                device_mesh=mesh_group[mesh_idx],\n                                aval=aval,\n                                sharding_spec=spec,\n                                remote_ref=refs[mesh_idx][outvar_index_on_mesh],\n                                indices=indices_list[i][j]))\n                    arr = ReplicatedDistributedArray(meshes, distributed_arrays)\n                ret.append(arr)\n            return ret\n\n        return outs_handler, output_placement_specs\n\n    def _compile_input_placement_spec(self, mesh_arg_indices,\n                                      input_shard_specs):\n        # build spec_arr: List[flatten global index -> PlacementSpec]\n        spec_arr = [None] * len(self.is_batch)\n        for mesh_idx, 
physical_mesh in enumerate(self.mesh_group):\n            for local_idx, global_idx in enumerate(mesh_arg_indices[mesh_idx]):\n                shard_spec = input_shard_specs[mesh_idx][local_idx]\n                if spec_arr[global_idx] is None:\n                    spec_arr[global_idx] = PlacementSpec(\n                        self.global_invars[global_idx].aval,\n                        (physical_mesh.mesh_id,), (shard_spec,))\n                else:\n                    old_val = spec_arr[global_idx]\n                    spec_arr[global_idx] = PlacementSpec(\n                        old_val.aval,\n                        old_val.mesh_ids + (physical_mesh.mesh_id,),\n                        old_val.sharding_specs + (shard_spec,))\n\n        return spec_arr\n\n    # TODO(yonghao): set empty buffer is not compatiable with local allgather\n    @staticmethod\n    def _compile_resharding_task(src_uuid: int,\n                                 resharding_task: SymbolicReshardingTask,\n                                 recv_uuid: int,\n                                 instruction_lists,\n                                 set_empty_buffer=False):\n        \"\"\"\n        Compile and generate SEND and RECV PipelineInstructions for a\n        ReshardingTask.\n\n        Args:\n            src_mesh: the src mesh\n            dst_mesh: the dst mesh\n            src_uuids: uuids of resharded buffer in src mesh\n            resharding_task: the task to be compiled\n            recv_uuids: uuids of resharded buffer in dst mesh\n            set_empty_buffer: set the empty buffer when recv or not\n        \"\"\"\n\n        # add send tasks for each worker\n        for w, task_uuid in resharding_task.send_worker_task_ids.items():\n            instruction_lists[w].append(\n                PipelineInstruction.send(task_uuid, [src_uuid]))\n\n        # add recv task for each worker\n        allgather_uuid = (resharding_task.allgather_uuid\n                          if 
resharding_task.is_local_allgather_task else None)\n        for w, task_uuid in resharding_task.recv_worker_task_ids.items():\n            instruction_lists[w].append(\n                PipelineInstruction.recv(task_uuid, [recv_uuid],\n                                         set_empty_buffer, allgather_uuid))\n\n    @staticmethod\n    def _compile_broadcast_resharding_task(\n            src_mesh, src_uuid: int,\n            resharding_task: SymbolicBroadcastReshardingTask, recv_uuid: int,\n            instruction_lists):\n\n        # add broadcast-based resharding task for each worker\n        for w, task_uuid in resharding_task.broadcast_worker_task_ids.items():\n            output_uuid = None\n            input_uuid = None\n            if w in src_mesh.workers:\n                input_uuid = [src_uuid]\n            else:\n                output_uuid = [recv_uuid]\n            instruction_lists[w].append(\n                PipelineInstruction.broadcast(task_uuid, input_uuid,\n                                              output_uuid, \"broadcast\"))\n\n    @staticmethod\n    def _compile_free(worker, used_outside, donated, instruction_lists):\n        \"\"\"Compile and generate FREE PipelineInstruction to recycle memory.\"\"\"\n        instruction_list = instruction_lists[worker]\n        new_list = []\n        cannot_free_uuids = OrderedSet(used_outside)\n        cannot_free_uuids.update(donated)\n        for instruction in reversed(instruction_list):\n            # for free instruction, do not free again\n            if instruction.input_uuids is None:\n                new_list.append(instruction)\n                continue\n            input_uuids = flatten_uuid_set(instruction.input_uuids)\n            if not instruction.opcode == PipelineInstType.FREE:\n                unused_uuids = input_uuids.difference(cannot_free_uuids)\n                if len(unused_uuids) > 0:\n                    new_list.append(\n                        
PipelineInstruction.free(np.array(list(unused_uuids))))\n            cannot_free_uuids.update(input_uuids)\n            new_list.append(instruction)\n        return list(reversed(new_list))\n\n\nclass OverlapFriendlyPipelineInstEmitter(PipelineInstEmitter):\n    \"\"\"Pipeline instruction emitter that allocates buffers earlier.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        outvar_def_order = kwargs.pop(\"outvar_def_order\")\n        super().__init__(*args, **kwargs)\n        # Based on stage info, generate cross-mesh communication requirements\n        # This formulates what send task is required\n        # Dict[int, Dict[int, Tuple(List, List)]]\n        # src_mesh_idx -> (dst_mesh_idx -> (Vars, Sharding Specs))\n        self.stage_send_vars = [[] for _ in range(len(self.stages))]\n        self._get_stage_send_vars(outvar_def_order)\n\n    def _get_stage_send_vars(self, outvar_def_order):\n        self._compile_sharding_specs()\n        var_defined = {}\n        var_at_mesh = {}\n        global_invar_set = set(self.global_invars)\n        # mesh_idx -> set of stage_idx\n        for stage_idx, stage in enumerate(self.stages):\n            assert len(self.schedule.stage_placement(stage_idx)) == 1\n            mesh_idx = list(self.schedule.stage_placement(stage_idx))[0]\n            for var_idx, var in enumerate(stage.invars):\n                if (var in global_invar_set or var in self.grad_dummy_invars or\n                        mesh_idx in var_at_mesh[var]):\n                    continue\n                else:\n                    # Currently we use the first mesh, since there is almost no\n                    # redundant computation and the first sends earlier. If the\n                    # var is required multiple times, then we might need round-\n                    # robin to avoid congestion.\n                    src_stage_idx = list(var_defined[var])[0]\n                    # once the var is received, it is permanently stored. 
Maybe\n                    # we will can an option to config it.\n                    var_at_mesh[var].add(mesh_idx)\n                    # insert the recv task\n                    self.stage_send_vars[src_stage_idx].append(\n                        (mesh_idx, var, stage.input_sharding_specs[var_idx]))\n\n            for var in stage.outvars:\n                var_defined.setdefault(var, OrderedSet()).add(stage_idx)\n                var_at_mesh.setdefault(var, OrderedSet()).add(mesh_idx)\n        # Reorder send and merge\n        for stage_idx, stage in enumerate(self.stages):\n            send_vars = self.stage_send_vars[stage_idx]\n            var_def_order = {\n                k: i for i, k in enumerate(outvar_def_order[stage_idx])\n            }\n            send_vars = sorted(send_vars,\n                               key=lambda sv, order=var_def_order:\n                               (order[sv[1]], sv[0]))\n            final_send_seq = []\n            for recv_stage_idx, v, spec in send_vars:\n                if (len(final_send_seq) != 0 and\n                    (final_send_seq[-1][0] == recv_stage_idx)):\n                    final_send_seq[-1][1].append(v)\n                    final_send_seq[-1][2].append(spec)\n                else:\n                    final_send_seq.append((recv_stage_idx, [v], [spec]))\n            self.stage_send_vars[stage_idx] = final_send_seq\n\n    def _compile_exec_one_tick(self, sched, donation_mapping, instruction_lists,\n                               executable_uuids, executable_config_lists):\n        exec_insts = {}\n        comm_insts = {}\n        for mesh in self.mesh_group:\n            for worker in mesh.workers:\n                exec_insts[worker] = []\n                comm_insts[worker] = []\n        for mesh_idx, task in enumerate(sched):\n            if not task:\n                continue\n            # execute\n            self._compile_exec_one_mesh(mesh_idx, task, executable_uuids,\n                                
        donation_mapping, exec_insts)\n        # send immediately after the result is created.\n        # we use another iteration to launch exec before alloc zero for recv\n        for mesh_idx, task in enumerate(sched):\n            if not task:\n                continue\n            batch_idx, stage_idx = task\n            if len(self.stage_send_vars[stage_idx]) > 0:\n                for recv_info in self.stage_send_vars[stage_idx]:\n                    (receiver_idx, received_vars,\n                     received_sharding_specs) = recv_info\n                    self._compile_get_vars_from_mesh(received_vars,\n                                                     received_sharding_specs,\n                                                     receiver_idx, batch_idx,\n                                                     comm_insts,\n                                                     instruction_lists,\n                                                     executable_config_lists)\n        for worker, insts in exec_insts.items():\n            instruction_lists[worker].extend(insts)\n            instruction_lists[worker].extend(comm_insts[worker])\n"
  },
  {
    "path": "alpa/pipeline_parallel/schedules.py",
    "content": "\"\"\"Generate pipeline schedules.\"\"\"\nimport itertools\nimport logging\nfrom abc import abstractmethod, ABCMeta\nfrom typing import Dict, List, Tuple\n\nimport numpy as np\n\nfrom alpa.pipeline_parallel.computation import PipelineComputation\nfrom alpa.util import cached_property, OrderedSet\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef gen_dependency_with_stages(\n    compute_stages: List[PipelineComputation],\n    num_mesh: int,\n    apply_grad_stages: List[PipelineComputation] = ()):\n    \"\"\"Generate the dependency matrix for a list of pipeline stages.\"\"\"\n    n_stages = len(compute_stages) + len(apply_grad_stages)\n    d = np.zeros([n_stages, n_stages], dtype=int)\n    var_stage_id = {}\n    fwd_intermediate_vars = OrderedSet()\n    for i, stage in enumerate(itertools.chain(compute_stages,\n                                              apply_grad_stages)):\n        for var in stage.invars:\n            if var in var_stage_id:\n                d[i, var_stage_id[var]] = 1\n                if i < num_mesh and var_stage_id[var] != 2 * num_mesh - i - 1:\n                    # not the var from forward to backward. we don't care them.\n                    # not the var on the backward side\n                    fwd_intermediate_vars.add(var)\n            else:\n                # Assume the var is from global_invars\n                pass\n        for var in stage.outvars:\n            var_stage_id[var] = i\n\n    return d, fwd_intermediate_vars\n\n\ndef gen_linear_pipeline_dependency(num_stage):\n    \"\"\"\n    Generate a dependency matrix.\n\n    The matrix marks forward/backward stage pairs as neighbors. 
For test only.\n    \"\"\"\n    assert num_stage % 2 == 0\n    d = np.zeros([num_stage, num_stage], dtype=int)\n    for i in range(num_stage - 1):\n        d[i + 1][i] = 1\n    for i in range(num_stage // 2):\n        d[num_stage - 1 - i][i] = 1\n    return d\n\n\nclass PipelineSchedule(metaclass=ABCMeta):\n    \"\"\"\n    A pipeline schedule used by the distributed runtime.\n\n    The core interface of this schedule is .schedule object.\n\n    Args:\n        dependency (np.array): dependency adjacency matrix.\n        sliced_mesh (List[VirtualPhysicalMesh]): a list of pre-sliced virtual\n            meshes to assign stages on.\n        apply_grad_placement (Dict[int, int]): A map from apply grad's stage idx\n            to the worker it is assigned.\n        num_batch (int): number of microbatches.\n    \"\"\"\n\n    def __init__(self,\n                 *,\n                 dependency,\n                 meshes,\n                 apply_grad_placement,\n                 num_batch=1):\n        self.dependency = dependency\n        self.meshes = meshes\n        self.apply_grad_placement = apply_grad_placement\n        self.num_batch = num_batch\n\n        self._schedules: List[List[Tuple]] = self._generate_schedule()\n\n    @property\n    @abstractmethod\n    def name(self):\n        raise NotImplementedError()\n\n    @abstractmethod\n    def _generate_schedule(self):\n        \"\"\"Implementation of the schedule.\"\"\"\n        raise NotImplementedError()\n\n    def pprint_schedule(self, to_print=False):\n        \"\"\"Pretty print the schedule.\"\"\"\n        printout = \"\\n\"\n        device_str = \" \".join([f\"d{d:<8}\" for d in range(self.num_mesh)])\n        printout = printout + f\"Clock k : {device_str} \\n\"\n        for clock, scheds in enumerate(self.schedules):\n            sched_str = \" \".join([f\"{str(sched):<8}\" for sched in scheds])\n            printout = printout + f\"Clock {clock:<2}: {sched_str} \\n\"\n        if to_print:\n            
logger.info(printout)\n        return printout\n\n    @property\n    def schedules(self):\n        \"\"\"Return the schedules.\"\"\"\n        return self._schedules\n\n    @property\n    def num_stage(self):\n        \"\"\"Return the number of stage, including apply_grad stages.\"\"\"\n        return self.dependency.shape[0]\n\n    @property\n    def num_mesh(self):\n        \"\"\"Return the number of meshes.\"\"\"\n        return len(self.meshes)\n\n    @property\n    def num_clock(self):\n        \"\"\"Return the number of clocks in the schedule.\"\"\"\n        return len(self._schedules)\n\n    @cached_property\n    def stage_mesh_mapping(self):\n        \"\"\"Generate a stage-worker mapping according to the schedule.\"\"\"\n        placements = {}\n        for tasks in self._schedules:\n            for mesh_idx, task in enumerate(tasks):\n                if task:\n                    _, stage_idx = task\n                    if stage_idx not in placements:\n                        placements[stage_idx] = OrderedSet()\n                    if mesh_idx not in placements[stage_idx]:\n                        placements[stage_idx].add(mesh_idx)\n        return placements\n\n    @cached_property\n    def mesh_stage_mapping(self):\n        \"\"\"Generate a worker-stage mapping according to the schedule.\"\"\"\n        ownership = {}\n        for tasks in self._schedules:\n            for mesh_idx, task in enumerate(tasks):\n                if task:\n                    _, stage_idx = task\n                    if mesh_idx not in ownership:\n                        ownership[mesh_idx] = OrderedSet()\n                    if stage_idx not in ownership[mesh_idx]:\n                        ownership[mesh_idx].add(stage_idx)\n        return ownership\n\n    def stage_placement(self, stage_idx):\n        \"\"\"Query the placement of a stage given its stage index.\"\"\"\n        return self.stage_mesh_mapping[stage_idx]\n\n    def mesh_placement(self, mesh_idx):\n        
\"\"\"Query the responsible stages of a worker given a worker index.\"\"\"\n        return self.mesh_stage_mapping[mesh_idx]\n\n    def should_skip_grad_sync(self, task):\n        \"\"\"\n        Query if grad sync (w/ other date replicas) should be skipped on a task.\n\n        Args:\n            task (Tuple[int]): (batch index, stage index).\n        \"\"\"\n        batch_idx, _ = task\n        return batch_idx != self.last_backward_batch_index\n\n    @abstractmethod\n    def previous_backward_batch_index(self, batch_idx):\n        \"\"\"Return microbatch index during backward prior to batch_idx.\"\"\"\n        raise NotImplementedError()\n\n    @property\n    @abstractmethod\n    def first_backward_batch_index(self):\n        \"\"\"Return the index of the first microbatch at backward pass.\"\"\"\n        raise NotImplementedError()\n\n    @property\n    @abstractmethod\n    def last_backward_batch_index(self):\n        \"\"\"Return the index of the last microbatch at backward pass.\"\"\"\n        raise NotImplementedError()\n\n\nclass GpipeSchedule(PipelineSchedule):\n    \"\"\"Construct a Gpipe-like schedule.\"\"\"\n\n    @property\n    def name(self):\n        return \"gpipe\"\n\n    def _generate_schedule(self):\n        \"\"\"\n        Generate a Gpipe-like schedule.\n\n        Note that here we always assume num_pipeline_workers = num_stage / 2.\n\n        The schedule will look like below:\n        i: index of micro-batch\n        j: index of partition/device\n        k: clock number\n\n        k (i,j) (i,j) (i,j)\n        - ----- ----- -----\n        0 (0,0)\n        1 (1,0) (0,1)\n        2 (2,0) (1,1) (0,2)\n        3       (2,1) (1,2)\n        4             (2,2)\n        5 reverse...\n        \"\"\"\n        m = self.num_batch\n        n = self.num_mesh\n        num_clock = m + n - 1\n        schedules = []\n        for k in range(num_clock):\n            scheds = [None] * n\n            for d in range(max(1 + k - m, 0), min(1 + k, n)):\n              
  scheds[d] = (k - d, d)\n            schedules.append(scheds)\n\n        def reverse(scheds):\n            rev = []\n            for task in scheds:\n                if not task:\n                    rev.append(None)\n                else:\n                    rev.append((m - 1 - task[0], 2 * n - 1 - task[1]))\n                    # rev.append((task[0], 2 * n - 1 - task[1]))\n            return rev\n\n        # backward schedules\n        # Note: large microbatch index is executed earlier in backward now.\n        for k in range(num_clock):\n            mapped_scheds = schedules[num_clock - k - 1]\n            schedules.append(reverse(mapped_scheds))\n\n        # apply_grad schedules\n        scheds = [None] * n\n        for stage_idx, worker in self.apply_grad_placement.items():\n            scheds[worker] = (self.last_backward_batch_index, stage_idx)\n        schedules.append(scheds)\n        return schedules\n\n    @property\n    def first_backward_batch_index(self):\n        \"\"\"Return the index of the first microbatch at backward pass.\"\"\"\n        return 0\n        # return self.num_batch - 1\n\n    @property\n    def last_backward_batch_index(self):\n        \"\"\"Return the index of the last microbatch at backward pass.\"\"\"\n        return self.num_batch - 1\n        # return 0\n\n    def previous_backward_batch_index(self, batch_idx):\n        \"\"\"Return the index of the previous microbatch at backward pass.\"\"\"\n        assert batch_idx > 0\n        return batch_idx - 1\n        # return batch_idx + 1\n\n\nclass PipeDreamFlush(PipelineSchedule):\n    \"\"\"\n    Generate a PipeDream-Flush schedule (a.k.a. 
1F1B).\n\n    It has similar latency to GPipe but is more memory-efficient.\n    \"\"\"\n\n    @property\n    def name(self):\n        return \"1f1b\"\n\n    def _generate_schedule(self):\n        \"\"\"\n        Using the same notation as GPipeSchedule but adding the F for forward\n        and B for backward, this schedule can be represented as\n        k (i,j)   (i,j)   (i,j)\n        - ------- ------- -------\n        0 (0,0,F)\n        1 (1,0,F) (0,1,F)\n        2 (2,0,F) (1,1,F) (0,2,F)\n        3                 (0,2,B)\n        4         (0,1,B) (1,2,F)\n        5 (0,0,B) (2,1,F) (1,2,B)\n        6 (3,0,F) (1,1,B) (2,2,F)\n        ...\n        \"\"\"\n        m = self.num_batch\n        n = self.num_mesh\n\n        # equal to gpipe\n        num_clock = (m + n - 1) * 2\n        schedules = [[None] * n for k in range(num_clock)]\n\n        num_warmup_microbatches = [min(n - i - 1, m) for i in range(n)]\n        num_microbatches_remaining = [m - i for i in num_warmup_microbatches]\n\n        next_fwd_mb_idx = [0 for _ in range(n)]\n        next_bwd_mb_idx = [0 for _ in range(n)]\n        next_available_clock = list(range(n))\n        finished_bwd_batch_indices = np.zeros(shape=[num_clock, n],\n                                              dtype=np.int32)\n\n        # warm-up clocks\n        for i in range(n):\n            for _ in range(num_warmup_microbatches[i]):\n                schedules[next_available_clock[i]][i] = (next_fwd_mb_idx[i], i)\n                next_available_clock[i] = next_available_clock[i] + 1\n                next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1\n\n        # run 1F1B\n        for i in reversed(range(n)):\n            # from the last device to the first\n            for _ in range(num_microbatches_remaining[i]):\n                # running through all the remaining microbatches\n                # forward\n                next_clock = next_available_clock[i]\n                schedules[next_clock][i] = (next_fwd_mb_idx[i], i)\n          
      next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1\n                finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n                next_clock = next_clock + 1\n\n                # backward\n                # first, offset the next available clock to the clock\n                # when the previous stage has just finished backward of the\n                # target mb.\n                if i + 1 < n:  # not the last device\n                    # find the next possible backward clock\n                    while finished_bwd_batch_indices[next_clock][\n                            i + 1] <= next_bwd_mb_idx[i]:\n                        assert finished_bwd_batch_indices[\n                            next_clock - 1][i] == next_bwd_mb_idx[i]\n                        finished_bwd_batch_indices[next_clock][\n                            i] = finished_bwd_batch_indices[next_clock - 1][i]\n                        next_clock = next_clock + 1\n\n                schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i)\n                finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n                next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1\n                next_available_clock[i] = next_clock + 1\n\n        # run cooldown passes\n        for i in reversed(range(n)):\n            for _ in range(num_warmup_microbatches[i]):\n                assert i + 1 < n\n                next_clock = next_available_clock[i]\n                while finished_bwd_batch_indices[next_clock][\n                        i + 1] <= next_bwd_mb_idx[i]:\n                    finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[\n                        i]\n                    next_clock = next_clock + 1\n                schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i)\n                finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n                next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1\n                next_available_clock[i] = 
next_clock + 1\n            # update status matrix for the last worker\n            if i > 0:\n                finished_bwd_batch_indices[next_available_clock[i]:num_clock,\n                                           i] = m\n\n        # append apply_grad schedules\n        scheds = [None] * n\n        for stage_idx, worker in self.apply_grad_placement.items():\n            scheds[worker] = (self.last_backward_batch_index, stage_idx)\n        schedules.append(scheds)\n        return schedules\n\n    @property\n    def first_backward_batch_index(self):\n        \"\"\"Return the index of the first microbatch at backward pass.\"\"\"\n        return 0\n\n    @property\n    def last_backward_batch_index(self):\n        \"\"\"Return the index of the last microbatch at backward pass.\"\"\"\n        return self.num_batch - 1\n\n    def previous_backward_batch_index(self, batch_idx):\n        \"\"\"Return the index of the previous microbatch at backward pass.\"\"\"\n        assert batch_idx > 0\n        return batch_idx - 1\n\n\nclass InferenceSchedule(PipelineSchedule):\n    \"\"\"Construct a Gpipe-like forward-only schedule for inference.\"\"\"\n\n    @property\n    def name(self):\n        return \"inference\"\n\n    def _generate_schedule(self):\n        \"\"\"\n        Generate a forward-only schedule.\n\n        The schedule will look like below:\n        i: index of micro-batch\n        j: index of partition/device\n        k: clock number\n\n        k (i,j) (i,j) (i,j)\n        - ----- ----- -----\n        0 (0,0)\n        1 (1,0) (0,1)\n        2 (2,0) (1,1) (0,2)\n        3       (2,1) (1,2)\n        4             (2,2)\n        \"\"\"\n        m = self.num_batch\n        n = self.num_mesh\n        num_clock = m + n - 1\n        schedules = []\n        for k in range(num_clock):\n            scheds = [None] * n\n            for d in range(max(1 + k - m, 0), min(1 + k, n)):\n                scheds[d] = (k - d, d)\n            schedules.append(scheds)\n\n        # There should be no apply_grad 
tasks in the inference schedule.\n        # apply_grad schedules\n        scheds = [None] * n\n        for stage_idx, worker in self.apply_grad_placement.items():\n            scheds[worker] = (self.last_backward_batch_index, stage_idx)\n        schedules.append(scheds)\n\n        return schedules\n\n    @property\n    def first_backward_batch_index(self):\n        \"\"\"Return the index of the first microbatch at backward pass.\"\"\"\n        return 0\n\n    @property\n    def last_backward_batch_index(self):\n        \"\"\"Return the index of the last microbatch at backward pass.\"\"\"\n        return self.num_batch - 1\n\n    def previous_backward_batch_index(self, batch_idx):\n        \"\"\"Return the index of the previous microbatch at backward pass.\"\"\"\n        assert batch_idx > 0\n        return batch_idx - 1\n\n\nclass OverlapFriendlyPipeDreamSchedule(PipeDreamFlush):\n    \"\"\"\n    Generate a PipeDream-Flush schedule (a.k.a. 1F1B) but is more communication-\n    computation-overlap-friendly.\n\n    It has similar latency to 1F1B but costs more memory to store intermediates.\n    \"\"\"\n\n    def _generate_schedule(self):\n        \"\"\"\n        This schedule is very close to that of PipeDream, but runs forward\n        microbatches as much as possible to create more opportunity for\n        overlapping communication and computation. 
The trade-off is it uses more\n        memory to store intermediate activations for more microbatches.\n\n        Using the same notation as PipeDreamFlush, this schedule is as:\n        k (i,j)   (i,j)   (i,j)\n        - ------- ------- -------\n        0 (0,0,F)\n        1 (1,0,F) (0,1,F)\n        2 (2,0,F) (1,1,F) (0,2,F)\n        3 (3,0,F) (2,1,F) (0,2,B)\n        4 (4,0,F) (0,1,B) (1,2,F)\n        5 (0,0,B) (3,1,F) (1,2,B)\n        6 (5,0,F) (1,1,B) (2,2,F)\n        ...\n        The overlapping is only for forward communication but not for backward\n        due to data dependency.\n        \"\"\"\n        batch = self.num_batch\n        mesh = self.num_mesh\n\n        num_clock = (mesh + batch - 1) * 2\n        schedules = [[None] * mesh for _ in range(num_clock)]\n        for mesh_idx in range(mesh):\n            # The warmup batch number doubles\n            num_warmup_batch = min(batch, 2 * (mesh - mesh_idx) - 1)\n            fwd_stage_idx = mesh_idx\n            bwd_stage_idx = mesh * 2 - mesh_idx - 1\n            tic = mesh_idx\n            is_forward = True\n            fwd_idx = -1\n            bwd_idx = -1\n            for exec_idx in range(batch * 2):\n                if exec_idx >= num_warmup_batch:\n                    if ((is_forward and bwd_idx < batch - 1) or\n                        (not is_forward and fwd_idx < batch - 1)):\n                        is_forward = not is_forward\n                if is_forward:\n                    fwd_idx += 1\n                    schedules[tic][mesh_idx] = (fwd_idx, fwd_stage_idx)\n                else:\n                    bwd_idx += 1\n                    # Do not launch too early at cooldown period. 
This is for\n                    # potential use of centralized runtime or debug.\n                    min_available_tic = ((mesh - 1) + (bwd_idx * 2 + 1) +\n                                         (mesh - 1 - mesh_idx))\n                    final_tic = max(tic, min_available_tic)\n                    schedules[final_tic][mesh_idx] = (bwd_idx, bwd_stage_idx)\n                tic += 1\n\n        # append apply_grad schedules\n        scheds = [None] * mesh\n        for stage_idx, mesh_idx in self.apply_grad_placement.items():\n            scheds[mesh_idx] = (self.last_backward_batch_index, stage_idx)\n        schedules.append(scheds)\n        return schedules\n\n\npipeline_schedule: Dict[str, PipelineSchedule] = {}\npipeline_schedule[\"gpipe\"] = GpipeSchedule\npipeline_schedule[\"1f1b\"] = PipeDreamFlush\npipeline_schedule[\"inference\"] = InferenceSchedule\npipeline_schedule[\"1f1b_overlap_friendly\"] = OverlapFriendlyPipeDreamSchedule\n\n\ndef create_pipeline_schedule(name, dependency, meshes, apply_grad_placement,\n                             num_batch):\n    return pipeline_schedule[name](dependency=dependency,\n                                   meshes=meshes,\n                                   apply_grad_placement=apply_grad_placement,\n                                   num_batch=num_batch)\n"
  },
  {
    "path": "alpa/pipeline_parallel/stage_construction.py",
    "content": "\"\"\"\nCore implementations for stage construction algorithms.\nThe algorithm groups layers into pipeline stages.\n\"\"\"\nfrom dataclasses import dataclass\nimport logging\nfrom typing import Sequence, List, Tuple, Dict, Union, Optional\n\nfrom jax._src.lib import xla_extension as xe\nfrom jax.core import Var\nimport numpy as np\n\nfrom alpa.device_mesh import VirtualPhysicalMesh\nfrom alpa.global_env import global_config\nfrom alpa.pipeline_parallel.computation import (\n    JaxPipelineComputation, merge_marked_jaxprs_with_named_call)\nfrom alpa.pipeline_parallel.stage_profiling import (get_compute_cost,\n                                                    last_compute_cost_file_name)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.timer import timers\nfrom alpa.util import OrderedSet, maybe_numba_jit, jaxpr_to_hlo\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\n@dataclass\nclass AutoStageOption:\n    \"\"\"Options of auto stage construction algorithm.\"\"\"\n    # The search space of the physical submesh shapes.\n    # Possible choices: {\"power_of_two\", \"small_power_of_two\", \"all\"}.\n    submesh_physical_shape_space: str = \"power_of_two\"\n    # The search space of the logical mesh shapes.\n    # Possible choices: {\"same_as_physical\", \"data_parallel_only\",\n    #                    \"single_node_model_parallel\", \"all\", \"manual\"}.\n    # If \"manual\", the user needs to specify the logical mesh shape.\n    manually_specified_submeshes: Sequence[Tuple[int, int]] = None\n    # The search space for the logical mesh shapes.\n    # Possible choices: {\"all\", \"single_node_model_parallel\",\n    #                    \"same_as_physical\", \"data_parallel_only\",\n    #                    \"model_parallel_only\"}.\n    submesh_logical_shape_space: str = \"single_node_model_parallel\"\n    # Profile only individual layers or composition different layers.\n    # Possible choices: 
{\"individual\", \"composition\"}.\n    layer_profile_mode: str = \"composition\"\n    # The tolerance of imbalance in the auto-stage construction.\n    stage_imbalance_tolerance: float = np.inf\n    # Use HLO cost model for computational cost or profile for the cost.\n    use_hlo_cost_model: bool = False\n    # The filename of profiling result database.\n    profiling_database_filename: Optional[str] = None\n    # The file name of the cached compute cost.\n    cached_profile_result: Optional[str] = None\n\n\n@dataclass\nclass ManualStageOption:\n    \"\"\"Options of manual stage assignment.\"\"\"\n    # Layer IDs of each forward stage.\n    forward_stage_layer_ids: Sequence[Sequence[int]]\n    # The physical shapes of submeshes of each stage.\n    submesh_physical_shapes: Sequence[Sequence[int]]\n    # The logical shapes of submeshes of each stage.\n    submesh_logical_shapes: Sequence[Sequence[int]]\n    # The auto-sharding options of each stage.\n    submesh_autosharding_option_dicts: Sequence[dict]\n\n\n@dataclass\nclass UniformStageOption:\n    # The number of stages.\n    num_stages: int = None\n    # The physical shape of all submeshes.\n    submesh_physical_shape: Sequence[int] = None\n    # The logical shape of all submeshes.\n    submesh_logical_shape: Sequence[int] = None\n    # The auto-sharding option of all stages.\n    submesh_autosharding_option: dict = None\n\n\nStageOption = Union[AutoStageOption, ManualStageOption, UniformStageOption]\n\n# Get results for debugging\nlast_forward_stage_layer_ids = None\nlast_submesh_shapes = None\nlast_logical_mesh_shapes = None\nlast_autosharding_option_dicts = None\n\n\ndef get_last_dp_result():\n    \"\"\"Gets the DP result of the last run.\"\"\"\n    return (last_compute_cost_file_name, last_forward_stage_layer_ids,\n            last_submesh_shapes, last_logical_mesh_shapes,\n            last_autosharding_option_dicts)\n\n\n@maybe_numba_jit\ndef get_optimal_submeshes(best_s, f_argmin, num_devices, 
num_layers,\n                          submesh_n_devices):\n    current_s = best_s\n    current_layer = 0\n    current_devices = num_devices\n\n    res = []\n    while current_s > 0 and current_layer < num_layers and current_devices > 0:\n        next_start_layer, submesh_choice, autosharding_choice = (\n            f_argmin[current_s, current_layer, current_devices])\n        assert next_start_layer != -1 and current_devices != -1\n        res.append(((current_layer, next_start_layer), submesh_choice,\n                    autosharding_choice))\n        current_s -= 1\n        current_layer = next_start_layer\n        current_devices -= submesh_n_devices[submesh_choice]\n    assert (current_s == 0 and current_layer == num_layers and\n            current_devices == 0)\n\n    return res\n\n\n@maybe_numba_jit\ndef training_dp_impl_2(num_layers, num_devices, submesh_sizes,\n                       valid_idxs_and_costs, max_n_succ_stages):\n    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1),\n                np.inf,\n                dtype=np.float32)\n    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1),\n                          0.0,\n                          dtype=np.float32)\n    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3),\n                       -1,\n                       dtype=np.int32)\n    f[0, num_layers, 0] = 0\n    for d in range(1, num_devices + 1):\n        for l, i, submesh_id, n_config, stage_cost in valid_idxs_and_costs:\n            l, i, submesh_id, n_config = map(int, (l, i, submesh_id, n_config))\n            n_submesh_devices = submesh_sizes[submesh_id]\n            if n_submesh_devices <= d:\n                for s in range(1, num_layers + 1):\n                    if s - 1 > max_n_succ_stages[l, i, submesh_id, n_config]:\n                        continue\n\n                    new_cost = f[s - 1, i + 1,\n                                 d - n_submesh_devices] + stage_cost\n          
          if new_cost < f[s, l, d]:\n                        f[s, l, d] = new_cost\n                        f_argmin[s, l, d] = (i + 1, submesh_id, n_config)\n                        f_stage_max[s, l, d] = max(\n                            f_stage_max[s - 1, i + 1, d - n_submesh_devices],\n                            stage_cost)\n\n    return f, f_stage_max, f_argmin\n\n\ndef training_dp_2(\n    num_devices,\n    num_microbatches,\n    submesh_choices,\n    compute_cost,\n    max_n_succ_stages,\n):\n    \"\"\"Faster implementation of the training DP algorithm.\"\"\"\n    # TODO(zhuohan): Further verify the correctness of this implementation.\n    timers(\"stage-construction-dp\").start()\n\n    num_layers = len(compute_cost)\n    all_possible_stage_costs = np.sort(np.unique(compute_cost))\n    best_cost = np.inf\n    best_solution = None\n    last_max_stage_cost = 0.0\n    # FIXME(zhuohan): Set this gap as a tunable parameter in global config\n    gap = 1e-6\n    assert len(\n        all_possible_stage_costs), \"no solution in auto stage construction.\"\n\n    submesh_sizes = np.array([n * m for (n, m) in submesh_choices],\n                             dtype=np.int64)\n\n    for max_stage_cost in all_possible_stage_costs:\n        if max_stage_cost - last_max_stage_cost < gap:\n            continue\n        if max_stage_cost * num_microbatches >= best_cost:\n            break\n\n        # Lifts check for stage_cost <= t_max_stage_cost out of the inner dp\n        # loop.\n        valid_cost_idxs = np.transpose(\n            (compute_cost <= max_stage_cost).nonzero())\n        # This corresponds to the i of k <= i <= K from eqn. 
3 in the alpa\n        # paper.\n        valid_cost_idxs = valid_cost_idxs[\n            valid_cost_idxs[:, 0] <= valid_cost_idxs[:, 1]]\n        if len(valid_cost_idxs) == 0:\n            continue\n        valid_costs = compute_cost[tuple(valid_cost_idxs.T)]\n        valid_idxs_and_costs = np.hstack(\n            [valid_cost_idxs, valid_costs[:, np.newaxis]])\n        # Sort by descending layer idx because DP initializes\n        # F[0, num_layers, 0] = 0\n        valid_idxs_and_costs = valid_idxs_and_costs[np.flip(\n            valid_cost_idxs[:, 1].argsort())]\n\n        # Don't perform backtracking each time (do it only for the best\n        # solution).\n        f, f_stage_max, f_argmin = training_dp_impl_2(\n            num_layers,\n            num_devices,\n            submesh_sizes,\n            valid_idxs_and_costs,\n            max_n_succ_stages,\n        )\n\n        best_s = f[:, 0, num_devices].argmin()\n        best_total_cost = f[best_s, 0, num_devices]\n        if np.isinf(best_total_cost):\n            continue\n        stage_cost = (num_microbatches - 1) * f_stage_max[best_s, 0,\n                                                          num_devices]\n\n        if best_total_cost + stage_cost < best_cost:\n            best_cost = best_total_cost + stage_cost\n            best_solution = best_s, f_argmin\n        last_max_stage_cost = max_stage_cost\n\n    assert best_solution is not None, (\n        \"Unable to find any solution to inter-op dp.\")\n    best_s, f_argmin = best_solution\n    best_solution = get_optimal_submeshes(best_s, f_argmin, num_devices,\n                                          num_layers, submesh_sizes)\n\n    timers(\"stage-construction-dp\").stop()\n    return best_cost, best_solution\n\n\n@maybe_numba_jit\ndef training_dp_impl(num_layers, num_devices, num_microbatches, submesh_choices,\n                     num_autosharding_configs, compute_cost, max_n_succ_stages,\n                     max_stage_cost):\n    \"\"\"The core 
implementation of the DP algorithm.\"\"\"\n    # For f, layer ID start from 0\n    # f[#pipeline stages,\n    #   layer id that is currently being considered,\n    #   number of devices used]\n    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1),\n                np.inf,\n                dtype=np.float32)\n    f_stage_max = np.full((num_layers + 1, num_layers + 1, num_devices + 1),\n                          0.0,\n                          dtype=np.float32)\n    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3),\n                       -1,\n                       dtype=np.int32)\n    f[0, num_layers, 0] = 0\n    for s in range(1, num_layers + 1):  # pylint: disable=too-many-nested-blocks\n        for i in range(num_layers - 1, -1, -1):\n            for j in range(1, num_devices + 1):\n                for k in range(num_layers, i, -1):\n                    for m, submesh in enumerate(submesh_choices):\n                        n_submesh_devices = np.prod(np.array(submesh))\n                        if n_submesh_devices <= j:\n                            # TODO(zhuohan): This level of for loop is not\n                            #   necessary. 
It can be optimized by sorting\n                            #   the logical mesh shapes.\n                            for n_config in range(num_autosharding_configs):\n                                if s - 1 <= max_n_succ_stages[i, k - 1, m,\n                                                              n_config]:\n                                    stage_cost = compute_cost[i, k - 1, m,\n                                                              n_config]\n                                    new_cost = f[s - 1, k, j -\n                                                 n_submesh_devices] + stage_cost\n                                    if (stage_cost <= max_stage_cost and\n                                            new_cost < f[s, i, j]):\n                                        f[s, i, j] = new_cost\n                                        f_stage_max[s, i, j] = max(\n                                            f_stage_max[s - 1, k,\n                                                        j - n_submesh_devices],\n                                            stage_cost)\n                                        f_argmin[s, i, j] = (k, m, n_config)\n\n    best_s = -1\n    best_total_cost = np.inf\n    for s in range(1, num_layers + 1):\n        if f[s, 0, num_devices] < best_total_cost:\n            best_s = s\n            best_total_cost = f[s, 0, num_devices]\n\n    if np.isinf(best_total_cost):\n        return np.inf, None\n\n    total_cost = f[best_s, 0, num_devices] + (\n        num_microbatches - 1) * f_stage_max[best_s, 0, num_devices]\n    current_s = best_s\n    current_layer = 0\n    current_devices = num_devices\n\n    res = []\n    while current_s > 0 and current_layer < num_layers and current_devices > 0:\n        next_start_layer, submesh_choice, autosharding_choice = (\n            f_argmin[current_s, current_layer, current_devices])\n        assert next_start_layer != -1 and current_devices != -1\n        res.append(((current_layer, 
next_start_layer), submesh_choice,\n                    autosharding_choice))\n        current_s -= 1\n        current_layer = next_start_layer\n        current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))\n    assert (current_s == 0 and current_layer == num_layers and\n            current_devices == 0)\n\n    return total_cost, res\n\n\ndef training_dp(num_layers, num_devices, num_microbatches, submesh_choices,\n                num_autosharding_configs, compute_cost, max_n_succ_stages):\n    \"\"\"Auto stage dynamic programming.\"\"\"\n    timers(\"stage-construction-dp\").start()\n\n    all_possible_stage_costs = np.sort(np.unique(compute_cost))\n    best_cost = np.inf\n    best_solution = None\n    last_max_stage_cost = 0.0\n    # FIXME(zhuohan): Set this gap as a tunable parameter in global config\n    gap = 1e-6\n    assert len(\n        all_possible_stage_costs), \"no solution in auto stage construction.\"\n    for max_stage_cost in all_possible_stage_costs:\n        if max_stage_cost * num_microbatches >= best_cost:\n            break\n        if max_stage_cost - last_max_stage_cost < gap:\n            continue\n        cost, solution = training_dp_impl(num_layers, num_devices,\n                                          num_microbatches, submesh_choices,\n                                          num_autosharding_configs,\n                                          compute_cost, max_n_succ_stages,\n                                          max_stage_cost)\n        if cost < best_cost:\n            best_cost = cost\n            best_solution = solution\n        last_max_stage_cost = max_stage_cost\n\n    timers(\"stage-construction-dp\").stop()\n    return best_cost, best_solution\n\n\n@maybe_numba_jit\ndef inference_dp_impl(num_layers, num_devices, submesh_choices,\n                      num_autosharding_configs, compute_cost):\n    \"\"\"The core implementation of the DP algorithm.\"\"\"\n    # For f, layer ID start from 0\n    # f[#pipeline 
stages,\n    #   layer id that is currently being considered,\n    #   number of devices used]\n    f = np.full((num_layers + 1, num_layers + 1, num_devices + 1),\n                np.inf,\n                dtype=np.float32)\n    f_argmin = np.full((num_layers + 1, num_layers + 1, num_devices + 1, 3),\n                       -1,\n                       dtype=np.int32)\n    f[0, 0, 0] = 0\n    for s in range(1, num_layers + 1):  # pylint: disable=too-many-nested-blocks\n        for i in range(1, num_layers + 1):\n            for j in range(1, num_devices + 1):\n                for k in range(0, i):\n                    for m, submesh in enumerate(submesh_choices):\n                        n_submesh_devices = np.prod(np.array(submesh))\n                        if n_submesh_devices <= j:\n                            for n_config in range(num_autosharding_configs):\n                                stage_cost = compute_cost[k, i - 1, m, n_config]\n                                new_cost = max(\n                                    f[s - 1, k, j - n_submesh_devices],\n                                    stage_cost)\n                                if new_cost < f[s, i, j]:\n                                    f[s, i, j] = new_cost\n                                    f_argmin[s, i, j] = (k, m, n_config)\n\n    best_s = -1\n    best_total_cost = np.inf\n    for s in range(1, num_layers + 1):\n        if f[s, num_layers, num_devices] * s < best_total_cost:\n            best_s = s\n            best_total_cost = f[s, num_layers, num_devices] * s\n\n    if np.isinf(best_total_cost):\n        return np.inf, None\n\n    current_s = best_s\n    current_layer = num_layers\n    current_devices = num_devices\n\n    res = []\n    while current_s > 0 and current_layer > 0 and current_devices > 0:\n        next_end_layer, submesh_choice, autosharding_choice = (\n            f_argmin[current_s, current_layer, current_devices])\n        assert next_end_layer != -1\n        
res.append(((next_end_layer, current_layer), submesh_choice,\n                    autosharding_choice))\n        current_s -= 1\n        current_layer = next_end_layer\n        current_devices -= np.prod(np.array(submesh_choices[submesh_choice]))\n    assert (current_s == 0 and current_layer == 0 and current_devices == 0)\n\n    return best_total_cost, res\n\n\ndef inference_dp(num_layers, num_devices, submesh_choices,\n                 num_autosharding_configs, compute_cost):\n    \"\"\"Auto stage dynamic programming.\"\"\"\n    timers(\"stage-construction-dp\").start()\n    cost, solution = inference_dp_impl(num_layers, num_devices, submesh_choices,\n                                       num_autosharding_configs, compute_cost)\n    solution = list(reversed(solution))\n    timers(\"stage-construction-dp\").stop()\n    return cost, solution\n\n\ndef get_submesh_choices(\n        num_hosts: int,\n        num_devices_per_host: int,\n        space: str,\n        manually_specified_submeshes: Optional[Sequence[Tuple[int,\n                                                              int]]] = None):\n    \"\"\"Gets the valid choices of submesh shapes.\"\"\"\n    if global_config.overwrite_submesh_choices is not None:\n        return global_config.overwrite_submesh_choices\n    submesh_choices = []\n\n    # smaller submeshes:\n    i = 1\n    while i <= num_devices_per_host:\n        submesh_choices.append((1, i))\n        i *= 2\n    assert submesh_choices[-1][1] == num_devices_per_host, (\n        \"Only supports the cases where num_devices_per_host is power of two, \"\n        f\"while now num_devices_per_host = {num_devices_per_host}\")\n\n    # larger meshes:\n    if space == \"all\":\n        for i in range(2, num_hosts + 1):\n            submesh_choices.append((i, num_devices_per_host))\n    elif space == \"power_of_two\":\n        i = 2\n        while i <= num_hosts:\n            submesh_choices.append((i, num_devices_per_host))\n            i *= 2\n    elif 
space == \"small_power_of_two\":\n        i = 2\n        while i <= min(num_hosts, 4):\n            submesh_choices.append((i, num_devices_per_host))\n            i *= 2\n    elif space == \"manual\":\n        submesh_choices = manually_specified_submeshes\n    else:\n        raise ValueError(f\"Invalid submesh space: {space}\")\n\n    return tuple(submesh_choices)\n\n\ndef get_one_submesh_autosharding_config_choices(\n        virtual_submesh: VirtualPhysicalMesh, space: str, batch_size: int):\n    \"\"\"\n    Return a list of logical meshes and autosharding configs.\n    Which will be used by the auto stage construction algorithm.\n\n    Args:\n        virtual_submesh: a submesh.\n        space: The search space of the logical mesh shapes.\n            possible choices: {\"same_as_physical\", \"data_parallel_only\",\n                               \"single_node_model_parallel\", \"all\"}.\n        batch_size: the batch size used.\n    \"\"\"\n    results = []\n    num_devices = virtual_submesh.num_devices\n    if space in [\"all\", \"single_node_model_parallel\"]:\n        if space == \"all\":\n            max_mp_dimension = num_devices\n        else:  # space == \"single_node_model_parallel\"\n            max_mp_dimension = virtual_submesh.num_devices_per_host\n\n        for mp_size in range(1, max_mp_dimension + 1):\n            if num_devices % mp_size == 0:\n                dp_size = num_devices // mp_size\n                if batch_size % dp_size == 0:\n                    results.append((virtual_submesh.get_logical_mesh(\n                        (dp_size, mp_size)), {\n                            \"force_batch_dim_to_mesh_dim\": 0\n                        }))\n        results.append((virtual_submesh.get_logical_mesh((num_devices, 1)), {}))\n    elif space == \"same_as_physical\":\n        results.append((virtual_submesh.get_logical_mesh(), {}))\n    elif space == \"data_parallel_only\":\n        results.append((virtual_submesh.get_logical_mesh((num_devices, 
1)), {\n            \"force_batch_dim_to_mesh_dim\": 0\n        }))\n    elif space == \"model_parallel_only\":\n        results.append((virtual_submesh.get_logical_mesh((1, num_devices)), {\n            \"force_batch_dim_to_mesh_dim\": 0\n        }))\n    else:\n        raise ValueError(f\"Invalid space for get_one_submesh_autosharding\"\n                         f\"_config_choices: {space}\")\n    return results\n\n\ndef get_all_submesh_autosharding_config_choices(virtual_mesh, submesh_choices,\n                                                space, batch_size):\n    \"\"\"Get all possible auto sharding config choices for all possible submesh\n    shapes.\"\"\"\n    # A config is: Tuple(logical_mesh_shape, autosharding_option_dict).\n    # Enumerate all (2D Mesh with force batch dim) + one (1D Mesh with mix batch\n    # dim).\n    autosharding_configs = []\n    for submesh in submesh_choices:\n        num_hosts, num_devices_per_host = submesh\n        virtual_submesh = virtual_mesh.slice_2d(\n            tuple(range(num_hosts)),\n            (tuple(range(num_devices_per_host)),) * num_hosts)\n        submesh_autosharding_configs = (\n            get_one_submesh_autosharding_config_choices(virtual_submesh, space,\n                                                        batch_size))\n        autosharding_configs.append(submesh_autosharding_configs)\n\n    # Pad all submesh to the maximum number of configs\n    max_num_autosharding_configs = max(\n        len(configs) for configs in autosharding_configs)\n    for configs in autosharding_configs:\n        configs += [None] * (max_num_autosharding_configs - len(configs))\n\n    return autosharding_configs\n\n\ndef get_sliced_virtual_submeshes(virtual_mesh, submesh_shapes):\n    \"\"\"Slice the origin mesh into submeshes given submesh shapes.\"\"\"\n    num_hosts = virtual_mesh.num_hosts\n    num_devices_per_host = virtual_mesh.num_devices_per_host\n    submesh_sizes = [np.prod(submesh) for submesh in submesh_shapes]\n 
   virtual_submeshes = [None] * len(submesh_shapes)\n    assert sum(submesh_sizes) == virtual_mesh.num_devices\n    sorted_submesh_indices = np.argsort(submesh_sizes, kind=\"stable\")\n    current_host_id = 0\n    current_device_id = 0\n    for i in reversed(sorted_submesh_indices):\n        required_num_hosts, required_num_devices = submesh_shapes[i]\n        if required_num_devices == num_devices_per_host:\n            assert current_device_id == 0\n            assert current_host_id + required_num_hosts <= num_hosts, (\n                \"Do not have enough hosts for the solution.\")\n            virtual_submeshes[i] = virtual_mesh.slice_2d(\n                tuple(\n                    range(current_host_id,\n                          current_host_id + required_num_hosts)),\n                (tuple(range(num_devices_per_host)),) * required_num_hosts)\n            current_host_id += required_num_hosts\n        else:\n            assert required_num_hosts == 1\n            assert required_num_devices < num_devices_per_host\n            assert (current_device_id + required_num_devices <=\n                    num_devices_per_host), (\n                        \"Do not have enough devices in a host for the solution\")\n            virtual_submeshes[i] = virtual_mesh.slice_2d([current_host_id], [\n                tuple(\n                    range(current_device_id,\n                          current_device_id + required_num_devices))\n            ])\n            current_device_id += required_num_devices\n            if current_device_id == num_devices_per_host:\n                current_host_id += 1\n                current_device_id = 0\n    assert current_host_id == num_hosts\n    assert current_device_id == 0\n    return virtual_submeshes\n\n\ndef cluster_layers_and_slice_mesh(\n        layers: Sequence[JaxPipelineComputation],\n        virtual_mesh: VirtualPhysicalMesh, accumulator_mapping: Dict[Var, Var],\n        acc_grad_invars: Sequence[Var], acc_grad_outvars: 
Sequence[Var],\n        num_micro_batches: int, batch_size: int,\n        jax_apply_layers: Sequence[JaxPipelineComputation],\n        apply_grad_global_info: Tuple, pipeline_schedule: str,\n        default_as_option: AutoShardingOption, stage_option: StageOption):\n    \"\"\"\n    Stage-mesh assignment.\n\n    This function clusters pipeline layers into stages, slices the device\n    mesh into multiple submeshes, and assigns the stages to the submeshes.\n    We first profile the compute cost of layers on different choices\n    of submeshes and find the optimal solution with DP.\n\n    Args:\n        layers: All the layers.\n        virtual_mesh: The virtual device mesh.\n        accumulator_mapping: The donation_mapping for the layers.\n        acc_grad_invars: invars of the gradient accumulation layers.\n        acc_grad_outvars: outvars of the gradient accumulation layers.\n        num_micro_batches: The number of microbatches.\n        batch_size: The micro batch size.\n        jax_apply_layers: The apply gradient computations corresponding\n          to each forward layer.\n        pipeline_schedule: The pipeline schedule.\n        default_as_option: The default auto-sharding option.\n        stage_option: The options controlling how to construct stages.\n    \"\"\"\n    timers(\"stage-construction\").start()\n\n    inference_mode = (pipeline_schedule == \"inference\")\n    if virtual_mesh.launched_physical_mesh_group is None:\n        given_mesh = False\n    else:\n        given_mesh = True\n\n    if inference_mode:\n        num_layers = len(layers)\n    else:\n        # Assume each forward layer corresponds to a backward layer\n        assert len(layers) % 2 == 0\n        num_layers = len(layers) // 2\n\n    if isinstance(stage_option, AutoStageOption):\n        if given_mesh:\n            # TODO(zhuohan): Implement the auto slicing with given mesh.\n            raise NotImplementedError(\"automatically slicing layers with \"\n                                  
    \"existing physical meshes is not \"\n                                      \"supported yet.\")\n\n        submesh_choices = get_submesh_choices(\n            virtual_mesh.num_hosts, virtual_mesh.num_devices_per_host,\n            stage_option.submesh_physical_shape_space,\n            stage_option.manually_specified_submeshes)\n        autosharding_configs = get_all_submesh_autosharding_config_choices(\n            virtual_mesh, submesh_choices,\n            stage_option.submesh_logical_shape_space, batch_size)\n        num_autosharding_configs = len(autosharding_configs[0])\n\n        # Use DP to find the optimal solution.\n        compute_cost, max_n_succ_stages = get_compute_cost(\n            virtual_mesh, submesh_choices, autosharding_configs, layers,\n            accumulator_mapping, acc_grad_invars, acc_grad_outvars,\n            jax_apply_layers, apply_grad_global_info, num_micro_batches,\n            default_as_option, stage_option, inference_mode)\n        if inference_mode:\n            _, solution = inference_dp(num_layers, virtual_mesh.num_devices,\n                                       submesh_choices,\n                                       num_autosharding_configs, compute_cost)\n        else:\n            _, solution = training_dp(num_layers, virtual_mesh.num_devices,\n                                      num_micro_batches, submesh_choices,\n                                      num_autosharding_configs, compute_cost,\n                                      max_n_succ_stages)\n\n        assert solution is not None, \"no solution in auto stage construction.\"\n\n        # Parse solution\n        forward_stage_layer_ids = [\n            list(range(start_id, end_id))\n            for (start_id, end_id), _, _ in solution\n        ]\n        submesh_shapes = [\n            submesh_choices[submesh_id] for _, submesh_id, _ in solution\n        ]\n        selected_autosharding_configs = [\n            
autosharding_configs[submesh_id][autosharding_config_id]\n            for _, submesh_id, autosharding_config_id in solution\n        ]\n        logical_mesh_shapes = [\n            mesh.shape for mesh, _ in selected_autosharding_configs\n        ]\n        autosharding_option_dicts = [\n            option_dict for _, option_dict in selected_autosharding_configs\n        ]\n\n        # Print and store the results\n        print(\"Result forward_stage_layer_ids:\", forward_stage_layer_ids)\n        print(\"Result mesh_shapes:\", submesh_shapes)\n        print(\"Result logical_mesh_shapes:\", logical_mesh_shapes)\n        print(\"Result autosharding_option_dicts:\", autosharding_option_dicts)\n        global last_forward_stage_layer_ids, last_submesh_shapes\n        global last_logical_mesh_shapes, last_autosharding_option_dicts\n        last_forward_stage_layer_ids = forward_stage_layer_ids\n        last_submesh_shapes = submesh_shapes\n        last_logical_mesh_shapes = logical_mesh_shapes\n        last_autosharding_option_dicts = autosharding_option_dicts\n    elif isinstance(stage_option, ManualStageOption):\n        # Check forward_stage_layer_ids is a partition of range(num_layers)\n        forward_stage_layer_ids = stage_option.forward_stage_layer_ids\n        last_layer_id = 0\n        for stage_layer_ids in forward_stage_layer_ids:\n            for layer_id in stage_layer_ids:\n                assert layer_id == last_layer_id\n                last_layer_id += 1\n        assert last_layer_id == num_layers, (\n            f\"{last_layer_id} layers in stage option, but {num_layers} marked\")\n        submesh_shapes = stage_option.submesh_physical_shapes\n        logical_mesh_shapes = (stage_option.submesh_logical_shapes or\n                               submesh_shapes)\n        autosharding_option_dicts = (\n            stage_option.submesh_autosharding_option_dicts)\n    elif isinstance(stage_option, UniformStageOption):\n        num_stages = 
stage_option.num_stages or num_layers\n        if stage_option.submesh_physical_shape is not None:\n            assert stage_option.submesh_logical_shape is not None\n            submesh_logical_shape = stage_option.submesh_logical_shape\n            submesh_shapes = [stage_option.submesh_physical_shape] * num_stages\n            logical_mesh_shapes = [submesh_logical_shape] * num_stages\n            assert virtual_mesh.num_devices == np.prod(\n                submesh_logical_shape) * num_stages\n            forward_stage_layer_ids = _cluster_layers_with_even_tflops(\n                layers[:num_layers], num_stages)\n            autosharding_option = stage_option.submesh_autosharding_option\n            if autosharding_option is None:\n                autosharding_option = {}\n            autosharding_option_dicts = [autosharding_option] * num_stages\n        else:\n            if given_mesh:\n                submesh_shapes = [\n                    x.shape\n                    for x in virtual_mesh.launched_physical_mesh_group.meshes\n                ]\n                logical_mesh_shapes = submesh_shapes\n            else:\n                num_devices = virtual_mesh.num_devices\n\n                assert num_devices >= num_stages, \"No enough devices\"\n                assert num_devices % num_stages == 0\n                num_devices_per_mesh = num_devices // num_stages\n                if num_devices_per_mesh > virtual_mesh.num_devices_per_host:\n                    assert (num_devices_per_mesh %\n                            virtual_mesh.num_devices_per_host == 0)\n                    submesh_shape = (num_devices_per_mesh //\n                                     virtual_mesh.num_devices_per_host,\n                                     virtual_mesh.num_devices_per_host)\n                else:\n                    assert (virtual_mesh.num_devices_per_host %\n                            num_devices_per_mesh == 0)\n                    submesh_shape = (1, 
num_devices_per_mesh)\n                submesh_shapes = [submesh_shape] * num_stages\n                logical_mesh_shapes = [submesh_shape] * num_stages\n\n            forward_stage_layer_ids = [[i] for i in range(num_layers)]\n            autosharding_option_dicts = [{}] * num_stages\n    else:\n        raise ValueError(f\"Invalid pipeline stage option: {stage_option}\")\n\n    if given_mesh:\n        sliced_meshes = [\n            mesh.get_virtual_physical_mesh()\n            for mesh in virtual_mesh.launched_physical_mesh_group\n        ]\n    else:\n        sliced_meshes = get_sliced_virtual_submeshes(virtual_mesh,\n                                                     submesh_shapes)\n\n    num_forward_stages = len(forward_stage_layer_ids)\n\n    if inference_mode:\n        stage_layer_ids = forward_stage_layer_ids\n        stage_to_mesh = list(range(num_forward_stages))\n    else:\n        backward_stage_layer_ids = [[\n            2 * num_layers - 1 - i for i in reversed(layer_ids)\n        ] for layer_ids in reversed(forward_stage_layer_ids)]\n        stage_layer_ids = forward_stage_layer_ids + backward_stage_layer_ids\n        stage_to_mesh = list(range(num_forward_stages)) + list(\n            reversed(range(num_forward_stages)))\n\n    stage_outvars = get_stage_outvars(layers, stage_layer_ids, acc_grad_outvars)\n    merged_stages = []\n    for stage_id, layer_ids in enumerate(stage_layer_ids):\n        if len(layer_ids) == 1:\n            merged_stages.append(layers[layer_ids[0]])\n            continue\n\n        stage_layer_jaxprs = [layers[i].closed_jaxpr() for i in layer_ids]\n        stage_name = str(stage_id)\n        merged_stage_jaxpr = merge_marked_jaxprs_with_named_call(\n            stage_layer_jaxprs,\n            stage_outvars[stage_id],\n            accumulator_mapping,\n            stage_name,\n            wrap_with_marker=True)\n        merged_stage = JaxPipelineComputation.from_closed_jaxpr(\n            stage_name, merged_stage_jaxpr)\n   
     merged_stages.append(merged_stage)\n    stages = merged_stages\n\n    # Check the validity of logical mesh shapes\n    assert len(logical_mesh_shapes) == len(sliced_meshes)\n    for logical_mesh_shape, submesh in zip(logical_mesh_shapes, sliced_meshes):\n        assert np.prod(logical_mesh_shape) == submesh.num_devices\n\n    if autosharding_option_dicts is not None:\n        assert len(autosharding_option_dicts) == len(sliced_meshes)\n    else:\n        autosharding_option_dicts = [{}] * len(sliced_meshes)\n\n    manual_stage_option = ManualStageOption(\n        forward_stage_layer_ids, tuple(x.shape for x in sliced_meshes),\n        logical_mesh_shapes, autosharding_option_dicts)\n\n    timers(\"stage-construction\").stop()\n    return stages, stage_to_mesh, sliced_meshes, manual_stage_option\n\n\ndef get_stage_outvars(layers: Sequence[JaxPipelineComputation],\n                      layer_assignment, global_outvars) -> List[OrderedSet]:\n    \"\"\"\n    Get the outvars of a stage used by another stage by liveness analysis.\n\n    Args:\n        layers: clustered layers\n        layer_assignment: the assignment of layers to stages\n        global_outvars: global outvars\n\n    Returns:\n        A list of outvars for each stage\n    \"\"\"\n    n_stages = len(layer_assignment)\n    used = OrderedSet(global_outvars)\n    stage_outvars = [OrderedSet() for _ in range(n_stages)]\n    for stage_id, layer_ids in reversed(list(enumerate(layer_assignment))):\n        for layer_id in layer_ids:\n            for var in layers[layer_id].outvars:\n                if var in used:\n                    stage_outvars[stage_id].add(var)\n            for var in layers[layer_id].invars:\n                used.add(var)\n    return stage_outvars\n\n\ndef _cluster_layers_with_even_tflops(layers, num_stage):\n    # prefix sum: total flops till layer_i\n    flops = [0]\n    for layer in layers:\n        hlo = jaxpr_to_hlo(\"tmp\", layer.closed_jaxpr(),\n                           
[False] * len(layer.invars))\n        layer_flops = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())\n        flops.append(flops[-1] + layer_flops)\n    avg_flop = flops[-1] / num_stage\n    # the last one is to avoid IndexError\n    flops = flops[1:] + [flops[-1] + 1]\n    forward_layer_ids = [[-1]]\n    nxt_bound = avg_flop\n    for i in range(len(layers)):\n        # if flops already exceeds threshold or cutting at current layer is\n        # closer to the ideal average, then choose it to cut.\n        # The first condition is to avoid a too large layer that occupies\n        # several times of average flops\n        if ((flops[i] >= nxt_bound * (1 - 1e-5)) or\n            (flops[i + 1] >= nxt_bound and\n             abs(flops[i + 1] - nxt_bound) > abs(flops[i] - nxt_bound))):\n            nxt_bound += avg_flop\n            forward_layer_ids.append(\n                tuple(range(forward_layer_ids[-1][-1] + 1, i + 1)))\n    forward_layer_ids = forward_layer_ids[1:]\n    return forward_layer_ids\n"
  },
  {
    "path": "alpa/pipeline_parallel/stage_profiling.py",
    "content": "\"\"\"Functionalities about profiling the stages.\"\"\"\nfrom abc import ABC, abstractmethod\nfrom collections import namedtuple\nimport dataclasses\nfrom time import time\nfrom datetime import datetime\nimport gc\nimport logging\nimport pickle\nfrom typing import Dict, Sequence, Tuple\n\nimport jax.numpy as jnp\nfrom jax.core import (ClosedJaxpr, Var, gensym)\nfrom jax.interpreters import pxla\nfrom jax._src.lib import xla_bridge as xb, xla_extension as xe\nimport numpy as np\nimport tqdm\nimport ray\nfrom ray.exceptions import RayActorError\nfrom ray.util import ActorPool\n\nfrom alpa.device_mesh import (DistributedArray, PhysicalDeviceMesh,\n                              VirtualPhysicalMesh, _shard_device_array,\n                              get_global_cluster)\nfrom alpa.global_env import global_config\nfrom alpa.mesh_executable import (PartialGradAccMeshDriverExecutable,\n                                  get_grad_sync_channel_ids)\nfrom alpa.mesh_profiling import (ProfilingResultDatabase,\n                                 estimate_hlo_module_cost)\nfrom alpa.pipeline_parallel.apply_grad import APPLY_GRAD_MARKER_SUFFIX\nfrom alpa.pipeline_parallel.computation import (\n    JaxPipelineComputation, get_local_donation_mapping_and_add_missing_invars,\n    merge_marked_jaxprs_with_named_call, merge_unmarked_with_call)\nfrom alpa.pipeline_parallel.cross_mesh_resharding import (\n    CrossMeshCommunicator, SymbolicReshardingTask, CollectiveGroup,\n    ReshardingTaskSpec, SymbolicBroadcastReshardingTask)\nfrom alpa.pipeline_parallel.layer_stats import eqn_flops\nfrom alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray\nfrom alpa.shard_parallel.auto_sharding import (AutoShardingOption,\n                                               LogicalDeviceMesh,\n                                               run_auto_sharding_pass,\n                                               run_spmd_partitioner_pass,\n                                     
          run_backend_compilation,\n                                               hlo_sharding_to_sharding_spec)\nfrom alpa.timer import timers\nfrom alpa.util import (get_shard_shape, jaxpr_to_hlo, OrderedSet,\n                       retrieve_placement_group, get_num_available_gpus,\n                       setup_computation_alias)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\nlast_compute_cost_file_name = None\n\nINFINITY_N_STAGES = 2**20\nGB = 1024**3\n\nModuleCompileOutput = namedtuple(\n    \"ModuleCompileOutput\",\n    [\"hlo\", \"input_sharding_protos\", \"output_sharding_proto\"])\n\nCompileOutput = namedtuple(\"CompileOutput\", [\n    \"acc_grad_module_compile_outputs\", \"stage_plan\",\n    \"apply_grad_input_sharding_protos\"\n])\n\nCompileConfig = namedtuple(\n    \"CompileConfig\",\n    [\"hlo\", \"names\", \"module_donate_invars\", \"module_acc_grad_outvars_indices\"])\n\nModuleProfileConfig = namedtuple(\"ModuleProfileConfig\", [\n    \"invar_names\", \"outvar_names\", \"invar_avals\", \"outvar_avals\",\n    \"donated_invars\", \"acc_grad_invars_indices\", \"acc_grad_outvars_indices\"\n])\n\nApplyGradConfig = namedtuple(\"ApplyGradConfig\",\n                             [\"invars\", \"apply_grad_only_invars\"])\n\nStageConfig = namedtuple(\"StageConfig\", [\n    \"n_modules\", \"compile_config\", \"module_profile_configs\", \"apply_grad_config\"\n])\n\n\nclass ModuleProfileResult(\n        namedtuple(\"ModuleProfileResult\", [\n            \"compute_cost\", \"peak_memory\", \"temp_buffer_size\", \"invar_names\",\n            \"outvar_names\", \"invar_sizes\", \"outvar_sizes\", \"donated_invars\",\n            \"acc_grad_invars_indices\", \"acc_grad_outvars_indices\",\n            \"available_memory\"\n        ])):\n    \"\"\"Profile result of a module.\"\"\"\n\n    def __str__(self):\n        invar_size = sum(self.invar_sizes)\n        outvar_size = sum(self.outvar_sizes)\n        return (f\"ModuleProfileResult(\"\n        
        f\"compute_cost={self.compute_cost:.3f}, \"\n                f\"peak_memory={self.peak_memory / GB:.3f} GB, \"\n                f\"invar_size={invar_size / GB:.3f} GB, \"\n                f\"outvar_size={outvar_size / GB:.3f} GB, \"\n                f\"temp_buffer_size={self.temp_buffer_size / GB:.3f} GB, \"\n                f\"available_memory={self.available_memory / GB:.3f} GB)\")\n\n\nclass StageProfileResult:\n    \"\"\"Profile result of a stage.\"\"\"\n\n    def __init__(self, n_modules, initial_var_names, initial_var_sizes):\n        self.n_modules = n_modules\n        self.module_profile_results: Sequence[ModuleProfileResult] = [\n            None\n        ] * n_modules\n        self.available_memory = None\n        self.initial_var_names = tuple(initial_var_names)\n        self.initial_var_sizes = tuple(initial_var_sizes)\n\n    def fully_profiled(self):\n        return all(r is not None for r in self.module_profile_results)\n\n    def is_module_profiled(self, module_idx):\n        return self.module_profile_results[module_idx] is not None\n\n    def add_module_profile_result(self, module_idx, result):\n        self.module_profile_results[module_idx] = result\n        if self.available_memory is None:\n            self.available_memory = result.available_memory\n        else:\n            self.available_memory = min(self.available_memory,\n                                        result.available_memory)\n\n    def __str__(self):\n        total_initial_var_size = sum(self.initial_var_sizes)\n        return (f\"StageProfileResult(\"\n                f\"available_memory={self.available_memory / GB:.3f} GB, \"\n                f\"initial_var_size={total_initial_var_size / GB:.3f} GB, \"\n                f\"module_profile_results={self.module_profile_results})\")\n\n\nclass BaseWorkerPoolWrapper(ABC):\n    \"\"\"Basic wrapper of ray's ActorPool.\"\"\"\n\n    @abstractmethod\n    def __init__(self):\n        self.actors = None\n        self.pool = None\n 
       self.is_shutdown = False\n\n    def submit(self, fn, value):\n        \"\"\"See ray.util.ActorPool.submit.\"\"\"\n        self.pool.submit(fn, value)\n\n    def get_next(self):\n        \"\"\"See ray.util.ActorPool.get_next.\"\"\"\n        return self.pool.get_next()\n\n    def get_next_unordered(self):\n        \"\"\"See ray.util.ActorPool.get_next_unordered.\"\"\"\n        return self.pool.get_next_unordered(\n            timeout=global_config.profile_timeout)\n\n    def shutdown(self, force=True):\n        \"\"\"Shut down the worker.\"\"\"\n        for w in self.actors:\n            if force:\n                ray.kill(w)\n            else:\n                w.__ray_terminate__.remote()\n        gc.collect()\n        self.is_shutdown = True\n\n    def __del__(self):\n        if not self.is_shutdown:\n            self.shutdown()\n\n\ndef get_input_output_sharding_proto(hlo_module, num_devices):\n    \"\"\"Given proto of XlaComputation, return its input and output sharding.\"\"\"\n    if num_devices <= 1:\n        return None, None\n    hlo_module.infer_spmd_shardings()\n    input_shardings = hlo_module.spmd_parameters_shardings()\n    output_sharding = hlo_module.spmd_output_sharding()\n    input_sharding_protos = [\n        x.to_proto().SerializeToString() for x in input_shardings\n    ]\n    output_sharding_proto = output_sharding.to_proto().SerializeToString()\n    return input_sharding_protos, output_sharding_proto\n\n\nclass CompileWorker:\n    \"\"\"\n    A ray actor to compile Jaxpr to HLO Proto using distributed workers.\n\n    To activate the worker, a gpu resource is required.\n    \"\"\"\n\n    def compile_stage_for_profiling(self, stage_id, config: CompileConfig,\n                                    logical_mesh, autosharding_option,\n                                    num_micro_batches):\n        \"\"\"\n        Compile a single stage with auto sharding for profiling.\n\n        Args:\n            stage_id: the index of the input stage.\n       
     config: configs for compilation.\n            logical_mesh: the logical mesh for compilation.\n            autosharding_option: the global config dictionary for compilation\n                setting.\n            num_micro_batches: the number of microbatches.\n\n        Returns:\n            hlo: The WrappedHlo of the compiled executable for accumulate grad\n            stage_plan: The sharding strategy from auto sharding\n            input_sharding_protos: The proto of accumulate grad's input sharding\n            output_sharding_protos: same as above\n            hooked_proto: The proto of variables from forward to backward\n        \"\"\"\n\n        # Compile with search to get sharding annotations.\n        other_kwargs = {\n            \"logical_mesh\": logical_mesh,\n            \"return_mode\": \"stages\",\n            \"as_option\": autosharding_option,\n            \"num_micro_batches\": num_micro_batches,\n            \"memory_budget_per_device\": None,\n        }\n        try:\n            # pylint: disable=unbalanced-tuple-unpacking\n            module_names, hlos, stage_plan = (run_auto_sharding_pass(\n                config.hlo, **other_kwargs))\n        except RuntimeError as e:\n            logger.warning(f\"Compilation error (auto-sharding pass) \"\n                           f\"for stage {stage_id} : {e}\")\n            return stage_id, None\n\n        # Read input/output shardings\n        hlo_dict = dict(zip(module_names, hlos))\n\n        assert (sum(\n            name.endswith(APPLY_GRAD_MARKER_SUFFIX) for name in config.names) <=\n                1), (\"Only one apply grad module is allowed in a single stage.\")\n\n        acc_grad_module_compile_outputs = []\n        apply_grad_input_sharding_protos = None\n\n        for module_id, module_name in enumerate(config.names):\n            hlo = hlo_dict[module_name]\n            setup_computation_alias(hlo, config.module_donate_invars[module_id])\n            module = hlo.get_module()\n       
     if module_name.endswith(APPLY_GRAD_MARKER_SUFFIX):\n                apply_grad_input_sharding_protos, _ = (\n                    get_input_output_sharding_proto(module,\n                                                    logical_mesh.num_devices))\n            else:\n                acc_grad_outvars_indices = (\n                    config.module_acc_grad_outvars_indices[module_id])\n                rewrite_for_grad_acc = len(acc_grad_outvars_indices) > 0\n                (input_sharding_protos,\n                 output_sharding_proto) = get_input_output_sharding_proto(\n                     module, logical_mesh.num_devices)\n\n                # Compile accumulate_grad part to fully optimized\n                try:\n                    optimized_hlo = run_spmd_partitioner_pass(\n                        hlo,\n                        logical_mesh.num_devices,\n                        rewrite_for_grad_acc=rewrite_for_grad_acc,\n                        rewrite_grad_acc_indices=acc_grad_outvars_indices)\n                except IndexError as e:\n                    logger.warning(f\"Compilation error (spmd partitioner pass) \"\n                                   f\"for stage {stage_id} : {e}\")\n                    return stage_id, None\n                acc_grad_module_compile_outputs.append(\n                    ModuleCompileOutput(optimized_hlo, input_sharding_protos,\n                                        output_sharding_proto))\n\n        return stage_id, CompileOutput(acc_grad_module_compile_outputs,\n                                       stage_plan,\n                                       apply_grad_input_sharding_protos)\n\n    @staticmethod\n    def run_auto_sharding_pass(stage_id, hlo, other_kwargs):\n        \"\"\"Run auto-sharding pass on a WrappedHlo.\"\"\"\n        assert other_kwargs[\"return_mode\"] == \"stages\"\n        # pylint: disable=unbalanced-tuple-unpacking\n        hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass(\n         
   hlo, **other_kwargs)\n        return stage_id, (hlo_stage_names, hlo_stages, stage_plan)\n\n\nclass CompileWorkerPool(BaseWorkerPoolWrapper):\n    \"\"\"A pool of CompileWorker for distributed compilation.\"\"\"\n\n    def __init__(self, num_cpus, debug_mode=False):\n        super().__init__()\n        worker_cls = ray.remote(num_cpus=1)(CompileWorker)\n        self.actors = [worker_cls.remote() for _ in range(num_cpus)]\n        self.pool = ActorPool(self.actors)\n        self.local_worker = CompileWorker() if debug_mode else None\n\n    def local_get(self, fn, *value):\n        \"\"\"Debug use function.\n\n        This function submits the work to local worker instead of a remote ray\n        actor to help with debug.\n        \"\"\"\n        return fn(self.local_worker, *value)\n\n\nclass ProfileWorker:\n    \"\"\"A ray actor to profile a WrappedHlo on a given mesh.\n\n    It requests gpu resources from ray. When an exception is caught, it restarts\n    the whole mesh.\n    \"\"\"\n\n    def __init__(self, virtual_mesh: VirtualPhysicalMesh):\n        self.mesh = virtual_mesh.get_physical_mesh()\n        self.virtual_mesh = virtual_mesh\n\n    def _profile_impl(self, stage_id, compiled_module_output, stage_plan,\n                      profile_config):\n        \"\"\"Implementation of profile function.\n\n        The profiler first compiles the WrappedHLO into Mesh Executable, then\n        profiles the executable and computes the maximal number of stages\n        following up this stage.\n\n        Args:\n            stage_id: the stage id of the proto.\n            compiled_module_output: Compiled WrappedHlo, input sharding,\n                spec and output sharding spec.\n            stage_plan: The compiled sharding strategy from the auto sharding\n                pass.\n            profile_config: Profile config of the module.\n\n        Returns:\n            stage_id: the input stage id.\n            cost (float): the time to run the profiled stage.\n       
     max_stage: maximal number of stages following up this stage.\n            debug_info: other profiled outputs for debug use. This includes\n                peak memory during the computation, the total available memory,\n                the input intermediate size and input initial size.\n        \"\"\"\n        input_avals = profile_config.invar_avals\n        output_avals = profile_config.outvar_avals\n        donated_invars = profile_config.donated_invars\n        input_shardings = compiled_module_output.input_sharding_protos\n        output_sharding = compiled_module_output.output_sharding_proto\n        hlo = compiled_module_output.hlo\n        hlo_module = hlo.get_module()\n        if input_shardings is not None:\n            hlo_module.set_spmd_parameters_shardings(\n                [xe.HloSharding(x) for x in input_shardings])\n            hlo_module.set_spmd_output_sharding(xe.HloSharding(output_sharding))\n        executable = PartialGradAccMeshDriverExecutable(self.mesh, hlo,\n                                                        stage_plan, input_avals,\n                                                        output_avals,\n                                                        donated_invars)\n\n        # Run profiling\n        self.mesh.reset_memory_stats()\n        peak_memory = executable.get_total_allocation_size()\n        available_memory = self.mesh.get_available_memory()\n        cost = executable.profile_with_dummy_inputs(skip_grad_sync=True)\n        del executable\n\n        return stage_id, cost, peak_memory, available_memory\n\n    def profile(self, stage_id, compiled_output, stage_plan, profile_info):\n        \"\"\"Run profiling on this profile worker.\n\n        If the RayActorError is catched, it retries until profile_maximum_retry\n        is reached. Otherwise, it directly returns. 
In both cases, the mesh\n        restarts.\n        \"\"\"\n        for _ in range(global_config.profile_maximum_retry):\n            try:\n                return self._profile_impl(stage_id, compiled_output, stage_plan,\n                                          profile_info)\n            except RayActorError as e:\n                logger.warning(f\"Meet ray actor error in profiling: {e}\")\n                self.restart(forced=True)\n            except RuntimeError as e:\n                logger.warning(f\"Meet runtime error in profiling: {e}\")\n                self.restart(forced=True)\n                break\n            except AssertionError as e:\n                logger.warning(f\"Meet assertion error in profiling: {e}\")\n                self.restart(forced=True)\n                break\n        return stage_id, np.inf, np.inf, 0\n\n    def restart(self, forced):\n        \"\"\"Restart the physical mesh.\"\"\"\n        self.mesh.shutdown(forced=forced)\n        self.virtual_mesh.launched_physical_mesh = None\n        self.mesh = self.virtual_mesh.get_physical_mesh()\n\n\nclass ProfileWorkerPool(BaseWorkerPoolWrapper):\n    \"\"\"A pool of ProfileWorker for distributed profiling.\"\"\"\n\n    def __init__(self, virtual_meshes, placement_group):\n        super().__init__()\n        worker_cls = ray.remote(ProfileWorker)\n        self.actors = [\n            worker_cls.options(placement_group=placement_group).remote(mesh)\n            for mesh in virtual_meshes\n        ]\n        self.pool = ActorPool(self.actors)\n\n\nclass HloCostModelProfileWorker:\n    \"\"\"A ray actor to estimate the cost of WrappedHLO based on cost model.\"\"\"\n\n    def __init__(self, prof_result, num_devices, num_micro_batches):\n        self.backend = xb.get_backend(global_config.backend)\n        self.prof_result = prof_result\n        self.num_devices = num_devices\n        self.num_micro_batches = num_micro_batches\n\n    def profile(self, stage_id, compiled_module_output, 
stage_plan,\n                profile_config):\n        \"\"\"Use cost model to estimate cost on this profile worker.\"\"\"\n        try:\n            compiled = run_backend_compilation(\n                self.backend,\n                compiled_module_output.hlo,\n                stage_plan,\n                self.num_devices,\n                bypass_device_assignment_check=True)\n        except RuntimeError as e:\n            logger.warning(f\"Compilation error (backend codegen): {e}\")\n            return stage_id, np.inf, np.inf, 0\n\n        hlo_module = compiled.hlo_modules()[0]\n        grad_sync_channel_ids = \"\"\n        if profile_config.acc_grad_outvars_indices:\n            grad_sync_channel_ids = get_grad_sync_channel_ids(hlo_module)\n        peak_memory = compiled.total_allocation_size()\n        available_memory = self.prof_result.available_memory_per_device\n        cost = estimate_hlo_module_cost(hlo_module, self.prof_result,\n                                        self.num_micro_batches,\n                                        grad_sync_channel_ids)\n        del compiled\n\n        #with open(f\"/home/ubuntu/efs/alpa/benchmark/alpa/tmp/\"\n        #          f\"profile_stage_{stage_id}.hlo\", \"w\") as fout:\n        #    fout.write(hlo_module.to_string())\n\n        return stage_id, cost, peak_memory, available_memory\n\n\nclass HloCostModelProfileWorkerPool(BaseWorkerPoolWrapper):\n    \"\"\"A pool of HloCostModelProfileWorker for distributed profiling.\n\n    Instead of doing real measurements, this class uses a HLO instruction\n    cost model to estimate the cost.\n    \"\"\"\n\n    def __init__(self, num_cpus, placement_group, prof_result, mesh_num_devices,\n                 num_micro_batches):\n        super().__init__()\n        num_gpus = get_num_available_gpus(placement_group)\n        gpu_per_cpu = 1\n        while gpu_per_cpu * num_cpus > num_gpus:\n            gpu_per_cpu /= 2\n        env_vars = {\"XLA_FLAGS\": 
\"--xla_gpu_autotune_level=0\"}\n        worker_cls = ray.remote(num_cpus=0,\n                                num_gpus=gpu_per_cpu)(HloCostModelProfileWorker)\n        self.actors = [\n            worker_cls.options(\n                runtime_env={\n                    \"env_vars\": env_vars\n                },\n                placement_group=placement_group,\n            ).remote(prof_result, mesh_num_devices, num_micro_batches)\n            for _ in range(num_cpus)\n        ]\n        self.pool = ActorPool(self.actors)\n\n\ndef compile_all(stages, num_micro_batches, default_as_option, profile_results):\n    \"\"\"\n    Compile all input stages.\n    \"\"\"\n    num_cpus = int(\n        min(max(ray.available_resources()[\"CPU\"] // 2, 1), len(stages)))\n\n    compile_workers = CompileWorkerPool(num_cpus)\n    num_compiled_stages = 0\n    for i, (stage_idx, stage_config, auto_sharding_config) in enumerate(stages):\n        if (stage_idx in profile_results and\n                profile_results[stage_idx].fully_profiled()):\n            continue\n        logical_mesh, autosharding_option_dict = auto_sharding_config\n        compile_workers.submit(\n            lambda w, v: w.compile_stage_for_profiling.remote(*v),\n            (i, stage_config.compile_config, logical_mesh,\n             dataclasses.replace(default_as_option, **\n                                 autosharding_option_dict), num_micro_batches))\n        num_compiled_stages += 1\n\n    compiled_outputs = [None] * len(stages)\n    for _ in tqdm.tqdm(range(num_compiled_stages)):\n        try:\n            i, compiled_output = compile_workers.get_next_unordered()\n        except TimeoutError:\n            logger.warning(\"Compile worker timeout\")\n            continue\n        except RayActorError as e:\n            logger.warning(f\"A Compile worker died unexpectedly: {e}\")\n            continue\n        compiled_outputs[i] = compiled_output\n        stage_idx, stage_config, auto_sharding_config = 
stages[i]\n        logical_mesh_shape = compiled_output.stage_plan.logical_mesh_shape\n        apply_in_shardings = compiled_output.apply_grad_input_sharding_protos\n        if apply_in_shardings is not None:\n            (initial_var_names,\n             initial_var_sizes) = compute_apply_grad_invar_size(\n                 apply_in_shardings, stage_config.apply_grad_config,\n                 logical_mesh_shape)\n        else:\n            initial_var_names = ()\n            initial_var_sizes = ()\n        if stage_idx not in profile_results:\n            profile_results[stage_idx] = StageProfileResult(\n                stage_config.n_modules, initial_var_names, initial_var_sizes)\n        else:\n            original_initial_size_dict = dict(\n                zip(profile_results[stage_idx].initial_var_names,\n                    profile_results[stage_idx].initial_var_sizes))\n            new_initial_size_dict = dict(\n                zip(initial_var_names, initial_var_sizes))\n            assert original_initial_size_dict == new_initial_size_dict, (\n                f\"Initial sizes mismatch between loaded result and newly \"\n                f\"compiled result: {original_initial_size_dict} \"\n                f\"vs {new_initial_size_dict}.\")\n\n    compile_workers.shutdown()\n    return compiled_outputs\n\n\ndef generate_module_profile_result(raw_result: Tuple,\n                                   profile_config: ModuleProfileConfig,\n                                   compile_output: ModuleCompileOutput,\n                                   logical_mesh_shape: Tuple[int, ...]):\n    compute_costs, peak_memory, available_memory = raw_result\n    invar_sizes = get_sharded_size_by_proto(\n        compile_output.input_sharding_protos, profile_config.invar_avals,\n        logical_mesh_shape, False)\n    outvar_sizes = get_sharded_size_by_proto(\n        [compile_output.output_sharding_proto], profile_config.outvar_avals,\n        logical_mesh_shape)\n    
donate_invar_sizes = [\n        size\n        for donated, size in zip(profile_config.donated_invars, invar_sizes)\n        if donated\n    ]\n    temp_buffer_size = (peak_memory - sum(invar_sizes) - sum(outvar_sizes) +\n                        sum(donate_invar_sizes))\n\n    return ModuleProfileResult(\n        compute_cost=np.mean(compute_costs),\n        peak_memory=peak_memory,\n        temp_buffer_size=temp_buffer_size,\n        invar_names=tuple(profile_config.invar_names),\n        outvar_names=tuple(profile_config.outvar_names),\n        invar_sizes=invar_sizes,\n        outvar_sizes=outvar_sizes,\n        donated_invars=tuple(profile_config.donated_invars),\n        acc_grad_invars_indices=tuple(profile_config.acc_grad_invars_indices),\n        acc_grad_outvars_indices=tuple(profile_config.acc_grad_outvars_indices),\n        available_memory=available_memory,\n    )\n\n\ndef profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes,\n                num_micro_batches, auto_stage_option, profile_results):\n    \"\"\"Profile all compiled outputs on given meshes.\n\n    This function launches a profile worker pool and submits given tasks.\n    \"\"\"\n    placement_group = retrieve_placement_group()\n\n    if auto_stage_option.use_hlo_cost_model:\n        num_cpus = int(\n            min(max(ray.available_resources()[\"CPU\"] // 2, 1), len(stages)))\n        mesh_num_devices = meshes[0].num_devices\n        prof_database = ProfilingResultDatabase()\n        prof_database.load(auto_stage_option.profiling_database_filename)\n        prof_result = prof_database.query(\"default\", meshes[0].shape)\n        profile_workers = HloCostModelProfileWorkerPool(num_cpus,\n                                                        placement_group,\n                                                        prof_result,\n                                                        mesh_num_devices,\n                                                        
num_micro_batches)\n    else:\n        profile_workers = ProfileWorkerPool(meshes, placement_group)\n\n    successful_compile_ct = 0\n    for i, (compiled_output, stage) in enumerate(zip(compiled_outputs, stages)):\n        if compiled_output is None:\n            continue\n        stage_idx, stage_config, _ = stage\n\n        for module_id, (acc_grad_module, profile_config) in enumerate(\n                zip(compiled_output.acc_grad_module_compile_outputs,\n                    stage_config.module_profile_configs)):\n            if profile_results[stage_idx].is_module_profiled(module_id):\n                continue\n            profile_workers.submit(lambda w, v: w.profile.remote(*v),\n                                   ((i, module_id), acc_grad_module,\n                                    compiled_output.stage_plan, profile_config))\n            successful_compile_ct += 1\n\n    pbar = tqdm.tqdm(range(successful_compile_ct))\n    for _ in pbar:\n        try:\n            ((i, module_id),\n             *module_raw_result) = profile_workers.get_next_unordered()\n        except TimeoutError:\n            profile_workers.shutdown(force=True)\n            logger.warning(\"After waiting for too long, \"\n                           \"all profile workers are forcely killed\")\n            return profile_results\n        except (RuntimeError, RayActorError):\n            profile_workers.shutdown(force=True)\n            logger.warning(\"Meet unexpected error, \"\n                           \"all profile workers are forcely killed\")\n            return profile_results\n        stage_idx, stage_config, _ = stages[i]\n        stage_compile_output = compiled_outputs[i]\n        module_profile_result = generate_module_profile_result(\n            module_raw_result, stage_config.module_profile_configs[module_id],\n            stage_compile_output.acc_grad_module_compile_outputs[module_id],\n            stage_compile_output.stage_plan.logical_mesh_shape)\n        
pbar.write(f\"result[{stage_idx}, {module_id}] \"\n                   f\"= {module_profile_result}\")\n        profile_results[stage_idx].add_module_profile_result(\n            module_id, module_profile_result)\n    profile_workers.shutdown()\n    return profile_results\n\n\ndef generate_training_stages_2d(layers,\n                                layer_flops_prefix_sum,\n                                accumulator_mapping,\n                                acc_grad_invars,\n                                acc_grad_outvars,\n                                apply_grad_layers,\n                                apply_grad_global_info,\n                                mesh_id,\n                                autosharding_configs,\n                                mesh_num_devices,\n                                cluster_size,\n                                stage_imbalance_tolerance=np.inf):\n    print(\"- Generate all stage infos (Jaxpr -> HLO)\")\n    assert len(layers) % 2 == 0\n    num_layers = len(layers) // 2\n    indices = list(range(2 * num_layers))\n    computation_source_ratio = mesh_num_devices / cluster_size\n    is_full_mesh = computation_source_ratio == 1\n    tot_flops = layer_flops_prefix_sum[2 * num_layers]\n    stages = []\n    for start in tqdm.tqdm(range(0, num_layers)):\n        for end in tqdm.tqdm(range(start, num_layers), leave=False):\n            if is_full_mesh and not (start == 0 and end == num_layers - 1):\n                continue\n            flops_ratio = (\n                layer_flops_prefix_sum[end + 1] - layer_flops_prefix_sum[start]\n                + layer_flops_prefix_sum[2 * num_layers - start] -\n                layer_flops_prefix_sum[2 * num_layers - end - 1]) / tot_flops\n            if (computation_source_ratio > flops_ratio *\n                (1 + stage_imbalance_tolerance) or\n                    computation_source_ratio < flops_ratio /\n                (1 + stage_imbalance_tolerance)):\n                continue\n            
forward_layer_indices = indices[start:end + 1]\n            backward_layer_indices = indices[2 * num_layers - end -\n                                             1:2 * num_layers - start]\n            selected_apply_grad_layers = [\n                apply_grad_layers[idx]\n                for idx in forward_layer_indices\n                if apply_grad_layers[idx] is not None\n            ]\n            stage_name = f\"stage_{start}_{end}\"\n            stage_config = generate_stage_info(\n                layers, [forward_layer_indices, backward_layer_indices],\n                accumulator_mapping, acc_grad_invars, acc_grad_outvars,\n                stage_name, selected_apply_grad_layers, apply_grad_global_info)\n            for config_idx, autosharding_config in enumerate(\n                    autosharding_configs):\n                if autosharding_config is not None:\n                    stage_indices = (start, end, mesh_id, config_idx)\n                    stages.append(\n                        (stage_indices, stage_config, autosharding_config))\n    return stages\n\n\ndef generate_inference_stages_2d(layers,\n                                 layer_flops_prefix_sum,\n                                 accumulator_mapping,\n                                 acc_grad_invars,\n                                 acc_grad_outvars,\n                                 apply_grad_layers,\n                                 apply_grad_global_info,\n                                 mesh_id,\n                                 autosharding_configs,\n                                 mesh_num_devices,\n                                 cluster_size,\n                                 stage_imbalance_tolerance=np.inf):\n    print(\"- Generate all stage infos (Jaxpr -> HLO)\")\n    num_layers = len(layers)\n    indices = list(range(2 * num_layers))\n    computation_source_ratio = mesh_num_devices / cluster_size\n    is_full_mesh = computation_source_ratio == 1\n    tot_flops = 
layer_flops_prefix_sum[num_layers]\n    stages = []\n    for start in tqdm.tqdm(range(0, num_layers)):\n        for end in tqdm.tqdm(range(start, num_layers), leave=False):\n            if is_full_mesh and not (start == 0 and end == num_layers - 1):\n                continue\n            flops_ratio = (layer_flops_prefix_sum[end + 1] -\n                           layer_flops_prefix_sum[start]) / tot_flops\n            if (computation_source_ratio > flops_ratio *\n                (1 + stage_imbalance_tolerance) or\n                    computation_source_ratio < flops_ratio /\n                (1 + stage_imbalance_tolerance)):\n                continue\n            forward_layer_indices = indices[start:end + 1]\n            selected_apply_grad_layers = [\n                apply_grad_layers[idx]\n                for idx in forward_layer_indices\n                if apply_grad_layers[idx] is not None\n            ]\n            assert len(selected_apply_grad_layers) == 0, (\n                \"Inference stage should not have apply_grad_layers\")\n            stage_name = f\"stage_{start}_{end}\"\n            stage_config = generate_stage_info(layers, [forward_layer_indices],\n                                               accumulator_mapping,\n                                               acc_grad_invars,\n                                               acc_grad_outvars, stage_name,\n                                               selected_apply_grad_layers,\n                                               apply_grad_global_info)\n            for config_idx, autosharding_config in enumerate(\n                    autosharding_configs):\n                if autosharding_config is not None:\n                    stage_indices = (start, end, mesh_id, config_idx)\n                    stages.append(\n                        (stage_indices, stage_config, autosharding_config))\n    return stages\n\n\ndef get_merged_stages_memory_stats(\n        profile_results: 
Sequence[StageProfileResult],\n        inference_mode: bool = False):\n    initial_var_sizes_dict = {}\n    for stage_result in profile_results:\n        for name, size in zip(stage_result.initial_var_names,\n                              stage_result.initial_var_sizes):\n            if name not in initial_var_sizes_dict:\n                initial_var_sizes_dict[name] = size\n            else:\n                assert initial_var_sizes_dict[name] == size, (\n                    f\"Apply grad invar {name} has different size accross \"\n                    f\"different stages: {initial_var_sizes_dict[name]} \"\n                    f\"vs. {size}.\")\n    initial_size = sum(initial_var_sizes_dict.values())\n    peak_memory = 0\n    available_memory = min(\n        result.available_memory for result in profile_results)\n    n_stages = len(profile_results)\n    n_modules = profile_results[0].n_modules\n    if inference_mode:\n        assert n_modules == 1, \"Inference mode should only have 1 module.\"\n        module_execution_orders = [list(range(n_stages))]\n    else:\n        assert n_modules == 2, (\"Only support forward and backward modules in \"\n                                \"training mode.\")\n        module_execution_orders = [\n            list(range(n_stages)),\n            list(range(n_stages - 1, -1, -1))\n        ]\n    assert all(result.n_modules == n_modules for result in profile_results)\n\n    # eliminate_time[var] = k means that the variable can be eliminated after\n    # stage k.\n    last_used_stage_no = {}\n    donation_mapping = {}\n    reverse_donation_mapping = {}\n    acc_grad_invars = OrderedSet()\n    acc_grad_outvars = OrderedSet()\n    stage_no = n_stages * n_modules\n    for module_id, stage_order in reversed(\n            list(enumerate(module_execution_orders))):\n        for stage_id in reversed(stage_order):\n            stage_no -= 1\n            module_result = profile_results[stage_id].module_profile_results[\n                
module_id]\n            for invar in module_result.invar_names:\n                if invar not in last_used_stage_no:\n                    last_used_stage_no[invar] = stage_no\n            for i, (invar, donated) in enumerate(\n                    zip(module_result.invar_names,\n                        module_result.donated_invars)):\n                if donated:\n                    # Note: here we assume that we always donate the i-th\n                    # invar to the i-th outvar. See rearrange_vars function.\n                    donation_mapping[invar] = module_result.outvar_names[i]\n                    reverse_donation_mapping[\n                        module_result.outvar_names[i]] = invar\n            for var_id in module_result.acc_grad_invars_indices:\n                acc_grad_invars.add(module_result.invar_names[var_id])\n            for var_id in module_result.acc_grad_outvars_indices:\n                acc_grad_outvars.add(module_result.outvar_names[var_id])\n\n    all_module_invars = []\n    for module_id, stage_order in enumerate(module_execution_orders):\n        module_invars = {}\n        in_module_vars = OrderedSet()\n        for stage_id in stage_order:\n            module_result = profile_results[stage_id].module_profile_results[\n                module_id]\n            for invar, size in zip(module_result.invar_names,\n                                   module_result.invar_sizes):\n                # If the variable is from another module instead of generated\n                # with in the module, it cannot be freed within the execution\n                # of a single module, but need to be freed after the module\n                # finishes.\n                if invar in in_module_vars:\n                    continue\n                if invar in module_invars:\n                    module_invars[invar] = max(module_invars[invar], size)\n                else:\n                    module_invars[invar] = size\n            for outvar in 
module_result.outvar_names:\n                in_module_vars.add(outvar)\n        all_module_invars.append(module_invars)\n\n    env = {}\n    intermediate_size = None\n    stage_no = -1\n    for module_id, stage_order in enumerate(module_execution_orders):\n        module_invars = all_module_invars[module_id]\n        env.update(module_invars)\n        for stage_id in stage_order:\n            stage_no += 1\n            module_result = profile_results[stage_id].module_profile_results[\n                module_id]\n            for invar, size in zip(module_result.invar_names,\n                                   module_result.invar_sizes):\n                if invar not in env:\n                    env[invar] = size\n                else:\n                    # env[invar] and size might be different because of\n                    # different sharding specs. We take the max for\n                    # estimation.\n                    env[invar] = max(env[invar], size)\n            for outvar, size in zip(module_result.outvar_names,\n                                    module_result.outvar_sizes):\n                assert outvar not in env\n                env[outvar] = size\n                if outvar in reverse_donation_mapping:\n                    assert reverse_donation_mapping[outvar] in env\n                    del env[reverse_donation_mapping[outvar]]\n            total_env_size = sum(env.values())\n            peak_memory = max(peak_memory,\n                              total_env_size + module_result.temp_buffer_size)\n            # Remove the variables that are no longer used and is generated\n            # within the module.\n            var_to_be_eliminated = []\n            for var in env:\n                if (var not in module_invars and var not in acc_grad_invars and\n                        var not in acc_grad_outvars and\n                    (var not in last_used_stage_no or\n                     last_used_stage_no[var] <= stage_no)):\n                    
var_to_be_eliminated.append(var)\n            for var in var_to_be_eliminated:\n                del env[var]\n        # Remove the variables that are no longer used\n        var_to_be_eliminated = []\n        for var in env:\n            if (var not in acc_grad_invars and var not in acc_grad_outvars and\n                (var not in last_used_stage_no or\n                 last_used_stage_no[var] <= stage_no)):\n                var_to_be_eliminated.append(var)\n        for var in var_to_be_eliminated:\n            del env[var]\n\n        # Record the variables that are not eliminated at the end of the\n        # last forward module.\n        if module_id == 0 and not inference_mode:\n            intermediate_size = sum(env.values())\n\n    for var in acc_grad_invars:\n        if var not in donation_mapping:\n            del env[var]\n\n    for var in acc_grad_outvars:\n        del env[var]\n\n    assert len(env) == 0, f\"Variables {env.keys()} are not eliminated.\"\n\n    if inference_mode:\n        max_stage = None\n    else:\n        max_stage = int((available_memory - peak_memory - initial_size) //\n                        max(intermediate_size, 1e-8) - 1)\n        max_stage = min(max(-1, max_stage), INFINITY_N_STAGES)\n\n    return (available_memory, peak_memory, initial_size, intermediate_size,\n            max_stage)\n\n\ndef interpret_profile_result_training_2d(\n        profile_results: Dict[Tuple[int, ...],\n                              StageProfileResult], num_layers: int,\n        num_submesh_choices: int, num_autosharding_configs: int):\n    all_compute_cost = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n    all_max_n_succ_stages = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        -1,\n        dtype=np.int64)\n\n    for index in np.ndindex(num_layers, num_layers, num_submesh_choices,\n                            
num_autosharding_configs):\n        if index not in profile_results:\n            continue\n        profile_result = profile_results[index]\n        all_compute_cost[index] = sum(\n            result.compute_cost\n            for result in profile_result.module_profile_results)\n        _, _, _, _, all_max_n_succ_stages[index] = (\n            get_merged_stages_memory_stats([profile_result]))\n\n    return all_compute_cost, all_max_n_succ_stages\n\n\ndef interpret_profile_result_inference_2d(\n        profile_results: Dict[Tuple[int, ...],\n                              StageProfileResult], num_layers: int,\n        num_submesh_choices: int, num_autosharding_configs: int):\n    all_compute_cost = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n    all_peak_memory = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n\n    for index in np.ndindex(num_layers, num_layers, num_submesh_choices,\n                            num_autosharding_configs):\n        if index not in profile_results:\n            continue\n        profile_result = profile_results[index]\n        assert len(profile_result.module_profile_results) == 1\n        all_compute_cost[index] = (\n            profile_result.module_profile_results[0].compute_cost)\n        all_peak_memory[index] = (\n            profile_result.module_profile_results[0].peak_memory)\n\n    return all_compute_cost, all_peak_memory\n\n\ndef generate_training_stages_1d(layers, accumulator_mapping, acc_grad_invars,\n                                acc_grad_outvars, apply_grad_layers,\n                                apply_grad_global_info, mesh_id,\n                                autosharding_configs):\n    print(\"- Generate all stage infos (Jaxpr -> HLO)\")\n    assert len(layers) % 2 == 0\n    num_layers = len(layers) // 2\n    stages = []\n    for l in 
tqdm.tqdm(range(0, num_layers)):\n        selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else\n                                      [apply_grad_layers[l]])\n        stage_name = f\"stage_{l}\"\n        stage_config = generate_stage_info(layers, [(l,),\n                                                    (2 * num_layers - l - 1,)],\n                                           accumulator_mapping, acc_grad_invars,\n                                           acc_grad_outvars, stage_name,\n                                           selected_apply_grad_layers,\n                                           apply_grad_global_info)\n        for config_idx, autosharding_config in enumerate(autosharding_configs):\n            if autosharding_config is not None:\n                stage_indices = (l, mesh_id, config_idx)\n                stages.append(\n                    (stage_indices, stage_config, autosharding_config))\n    return stages\n\n\ndef generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_invars,\n                                 acc_grad_outvars, apply_grad_layers,\n                                 apply_grad_global_info, mesh_id,\n                                 autosharding_configs):\n    print(\"- Generate all stage infos (Jaxpr -> HLO)\")\n    num_layers = len(layers)\n    stages = []\n    for l in tqdm.tqdm(range(0, num_layers)):\n        selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else\n                                      [apply_grad_layers[l]])\n        assert len(selected_apply_grad_layers) == 0, (\n            \"Inference stage should not have apply_grad_layers\")\n        stage_name = f\"stage_{l}\"\n        stage_config = generate_stage_info(layers, [(l,)], accumulator_mapping,\n                                           acc_grad_invars, acc_grad_outvars,\n                                           stage_name,\n                                           selected_apply_grad_layers,\n                 
                          apply_grad_global_info)\n        for config_idx, autosharding_config in enumerate(autosharding_configs):\n            if autosharding_config is not None:\n                stage_indices = (l, mesh_id, config_idx)\n                stages.append(\n                    (stage_indices, stage_config, autosharding_config))\n    return stages\n\n\ndef interpret_profile_result_training_1d(\n        profile_results: Dict[Tuple[int, ...],\n                              StageProfileResult], num_layers: int,\n        num_submesh_choices: int, num_autosharding_configs: int):\n    all_compute_cost = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n    all_max_n_succ_stages = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        -1,\n        dtype=np.int64)\n\n    for start in range(num_layers):\n        for end in range(start, num_layers):\n            for submesh_choice in range(num_submesh_choices):\n                for config_idx in range(num_autosharding_configs):\n                    if any(\n                        (l, submesh_choice, config_idx) not in profile_results\n                            for l in range(start, end + 1)):\n                        continue\n                    selected_profile_results = [\n                        profile_results[(l, submesh_choice, config_idx)]\n                        for l in range(start, end + 1)\n                    ]\n                    all_compute_cost[\n                        start, end, submesh_choice, config_idx] = sum(\n                            result.compute_cost\n                            for profile_result in selected_profile_results\n                            for result in profile_result.module_profile_results)\n                    (_, _, _, _, all_max_n_succ_stages[start, end,\n                                                       submesh_choice,\n       
                                                config_idx]\n                    ) = get_merged_stages_memory_stats(selected_profile_results)\n    return all_compute_cost, all_max_n_succ_stages\n\n\ndef interpret_profile_result_inference_1d(\n        profile_results: Dict[Tuple[int, ...],\n                              StageProfileResult], num_layers: int,\n        num_submesh_choices: int, num_autosharding_configs: int):\n    all_compute_cost = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n    all_peak_memory = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf,\n        dtype=np.float64)\n\n    for start in range(num_layers):\n        for end in range(start, num_layers):\n            for submesh_choice in range(num_submesh_choices):\n                for config_idx in range(num_autosharding_configs):\n                    if any(\n                        (l, submesh_choice, config_idx) not in profile_results\n                            for l in range(start, end + 1)):\n                        continue\n                    selected_profile_results = [\n                        profile_results[(l, submesh_choice, config_idx)]\n                        for l in range(start, end + 1)\n                    ]\n                    for result in selected_profile_results:\n                        assert len(result.module_profile_results) == 1\n                    all_compute_cost[\n                        start, end, submesh_choice, config_idx] = sum(\n                            profile_result.module_profile_results[0].\n                            compute_cost\n                            for profile_result in selected_profile_results)\n                    (available_memory, peak_memory, _, _,\n                     _) = get_merged_stages_memory_stats(\n                         selected_profile_results, inference_mode=True)\n       
             if peak_memory > available_memory:\n                        all_compute_cost[start, end, submesh_choice,\n                                         config_idx] = np.inf\n    return all_compute_cost, all_peak_memory\n\n\ndef distributed_profile_on_mesh(stages, meshes: Sequence[VirtualPhysicalMesh],\n                                num_micro_batches, default_as_option,\n                                auto_stage_option, profile_results):\n    timers(\"stage-construction-compilation\").start()\n\n    if len(stages) == 0:\n        # Suspend timers\n        timers(\"stage-construction-compilation\").stop()\n        return profile_results\n\n    print(\"- Compile all stages\")\n    try:\n        compiled_outputs = compile_all(stages, num_micro_batches,\n                                       default_as_option, profile_results)\n    except RayActorError as e:\n        logger.warning(f\"Compilation fatal error: {e}\")\n        timers(\"stage-construction-compilation\").stop()\n        return profile_results\n    timers(\"stage-construction-compilation\").stop()\n\n    print(\"- Profile all stages\")\n    # shape of compute_cost and max_n_succ_stages:\n    # (num_layers, num_layers, num_autosharding_configs)\n    timers(\"stage-construction-profiling\").start()\n    profile_results = profile_all(stages, compiled_outputs, meshes,\n                                  num_micro_batches, auto_stage_option,\n                                  profile_results)\n    timers(\"stage-construction-profiling\").stop()\n    return profile_results\n\n\ndef check_profile_results_consistent(stages,\n                                     profile_results: Dict[Tuple,\n                                                           StageProfileResult]):\n    for stage_idx, stage_config, _ in stages:\n        if stage_idx not in profile_results:\n            continue\n        profile_result = profile_results[stage_idx]\n        assert profile_result.n_modules == stage_config.n_modules\n     
   for module_profile_result, module_profile_config in (\n                profile_result.module_profile_results,\n                stage_config.module_profile_configs):\n            if module_profile_result is None:\n                continue\n            assert (module_profile_result.invar_names ==\n                    module_profile_config.invar_names)\n            assert (module_profile_result.outvar_names ==\n                    module_profile_config.outvar_names)\n            assert (module_profile_result.donated_invars ==\n                    module_profile_config.donated_invars)\n            assert (module_profile_result.required_outvars_indices ==\n                    module_profile_config.required_outvars_indices)\n\n\ndef _get_layer_flops_prefix_sum(layers):\n    layer_flops_prefix_sum = [0]\n    for layer in layers:\n        layer_flops = sum(eqn_flops(eqn) for eqn in layer.eqns)\n        layer_flops_prefix_sum.append(layer_flops_prefix_sum[-1] + layer_flops)\n    return layer_flops_prefix_sum\n\n\ndef get_compute_cost(\n        virtual_mesh: VirtualPhysicalMesh,\n        submesh_choices: Sequence[Tuple[int]],\n        autosharding_configs: Sequence[Sequence[Tuple[LogicalDeviceMesh,\n                                                      dict]]],\n        layers: Sequence[JaxPipelineComputation],\n        accumulator_mapping: Dict[Var, Var],\n        acc_grad_invars: Sequence[Var],\n        acc_grad_outvars: Sequence[Var],\n        apply_grad_layers: Sequence[JaxPipelineComputation],\n        apply_grad_global_info: Tuple,\n        num_micro_batches: int,\n        default_as_option: AutoShardingOption,\n        auto_stage_option: \"AutoStageOption\",\n        inference_mode: bool = False):\n    \"\"\"Get computation cost for each possible (stage, mesh) configuration.\n\n    This function enumerates all given submesh choices, then profiles compute\n    cost of all stage configuration under the submesh. 
For each submesh, it\n    slices the given mesh or the whole device cluster into submeshes to profile.\n\n    Args:\n        virtual_mesh: The whole virtual mesh. If profile_with_whole_ray_cluster\n            is turned off in global config, virtual_mesh is sliced into pieces\n            to run profiling. Otherwise, the whole device cluster is sliced for\n            profiling.\n        submesh_choices: All available submesh shape choices.\n        autosharding_configs: All auto sharding configs for each submesh.\n        layers: Layers for computing and accumulating gradients (forward +\n            backward).\n        accumulator_mapping: Donation mapping from accumulator to\n            accumulated results for all layers.\n        acc_grad_invars: Global input variables for all layers.\n        acc_grad_outvars: Global output variables for all layers.\n        apply_grad_layers: Apply gradient computations corresponding to each\n            forward layers.\n        apply_grad_global_info: Donation mapping and outvars for apply gradient\n            stages.\n        default_as_option: The default auto-sharding options.\n        auto_stage_option: The auto stage construction algorithm options.\n        inference_mode: Whether to run in inference mode.\n\n    Returns:\n        Two np.ndarray, each with shape (L, L, S, C), where L is the number of\n        forward layers, S is the number of submesh choices, and C is the maximal\n        number of autosharding configs for a submesh choice.\n        At index (i, j, s, c), the array stores the value under the condition:\n        the stage contains forward layers i, i+1, ... j and corresponding\n        backward layers, and runs under the s-th submesh and c-th auto sharding\n        config for the submesh.\n        compute_cost: The compute cost of all possible configurations.\n        max_n_succ_stages: The maximal number of succeeding stages. 
This\n            is calculated based on memory constraints.\n    \"\"\"\n    cluster_size = virtual_mesh.num_devices\n    layer_flops_prefix_sum = _get_layer_flops_prefix_sum(layers)\n    if inference_mode:\n        num_layers = len(layers)\n    else:\n        assert len(layers) % 2 == 0\n        num_layers = len(layers) // 2\n    num_submesh_choices = len(submesh_choices)\n    num_autosharding_configs = len(autosharding_configs[0])\n\n    if auto_stage_option.cached_profile_result is not None:\n        with open(auto_stage_option.cached_profile_result, \"rb\") as f:\n            profile_results = pickle.load(f)\n    else:\n        profile_results = {}\n    print(\"-\" * 20 + \" Automatic stage clustering \" + \"-\" * 20)\n    print(f\"submesh_choices: {submesh_choices}\")\n\n    # Reverse submesh_choices to test larger meshes first\n    for mesh_id, submesh in reversed(list(enumerate(submesh_choices))):\n        print(f\"- Profiling for submesh {mesh_id} {submesh}:\")\n        num_hosts, num_devices_per_host = submesh\n        tic = time()\n        if global_config.profile_with_whole_ray_cluster:\n            whole_cluster_virtual_mesh = get_global_cluster(\n            ).get_virtual_physical_mesh()\n            sliced_virtual_meshes = (\n                whole_cluster_virtual_mesh.slice_profiling_submeshes(\n                    num_hosts, num_devices_per_host))\n        else:\n            sliced_virtual_meshes = virtual_mesh.slice_profiling_submeshes(\n                num_hosts, num_devices_per_host)\n\n        if auto_stage_option.layer_profile_mode == \"composition\":\n            if inference_mode:\n                stages = generate_inference_stages_2d(\n                    layers, layer_flops_prefix_sum, accumulator_mapping,\n                    acc_grad_invars, acc_grad_outvars, apply_grad_layers,\n                    apply_grad_global_info, mesh_id,\n                    autosharding_configs[mesh_id],\n                    
sliced_virtual_meshes[0].num_devices, cluster_size,\n                    auto_stage_option.stage_imbalance_tolerance)\n            else:\n                stages = generate_training_stages_2d(\n                    layers, layer_flops_prefix_sum, accumulator_mapping,\n                    acc_grad_invars, acc_grad_outvars, apply_grad_layers,\n                    apply_grad_global_info, mesh_id,\n                    autosharding_configs[mesh_id],\n                    sliced_virtual_meshes[0].num_devices, cluster_size,\n                    auto_stage_option.stage_imbalance_tolerance)\n        elif auto_stage_option.layer_profile_mode == \"individual\":\n            if inference_mode:\n                stages = generate_inference_stages_1d(\n                    layers, accumulator_mapping, acc_grad_invars,\n                    acc_grad_outvars, apply_grad_layers, apply_grad_global_info,\n                    mesh_id, autosharding_configs[mesh_id])\n            else:\n                stages = generate_training_stages_1d(\n                    layers, accumulator_mapping, acc_grad_invars,\n                    acc_grad_outvars, apply_grad_layers, apply_grad_global_info,\n                    mesh_id, autosharding_configs[mesh_id])\n        else:\n            raise ValueError(f\"Unknown layer profile mode: \"\n                             f\"{auto_stage_option.layer_profile_mode}\")\n\n        check_profile_results_consistent(stages, profile_results)\n\n        profile_results = distributed_profile_on_mesh(\n            stages, sliced_virtual_meshes, num_micro_batches, default_as_option,\n            auto_stage_option, profile_results)\n\n        toc = time()\n        print(f\"Profiling for submesh {mesh_id} {submesh} takes {toc - tic:.2f}\"\n              f\" seconds\")\n        print(\"-\" * 50)\n\n    timestamp = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n    profile_result_file_name = (f\"profile-results-{timestamp}.npy\")\n    np.save(profile_result_file_name, 
profile_results)\n    global last_compute_cost_file_name\n    last_compute_cost_file_name = profile_result_file_name\n    print(f\"Profile result saved to: {profile_result_file_name}\")\n    print(\"-\" * 70)\n\n    if auto_stage_option.layer_profile_mode == \"composition\":\n        if inference_mode:\n            compute_cost, _ = interpret_profile_result_inference_2d(\n                profile_results, num_layers, num_submesh_choices,\n                num_autosharding_configs)\n            max_n_succ_stages = None\n        else:\n            (compute_cost,\n             max_n_succ_stages) = interpret_profile_result_training_2d(\n                 profile_results, num_layers, num_submesh_choices,\n                 num_autosharding_configs)\n    elif auto_stage_option.layer_profile_mode == \"individual\":\n        if inference_mode:\n            compute_cost, _ = interpret_profile_result_inference_1d(\n                profile_results, num_layers, num_submesh_choices,\n                num_autosharding_configs)\n            max_n_succ_stages = None\n        else:\n            (compute_cost,\n             max_n_succ_stages) = interpret_profile_result_training_1d(\n                 profile_results, num_layers, num_submesh_choices,\n                 num_autosharding_configs)\n    else:\n        raise ValueError(f\"Unknown layer profile mode: \"\n                         f\"{auto_stage_option.layer_profile_mode}\")\n\n    return compute_cost, max_n_succ_stages\n\n\ndef select_module_layers(layers: Sequence[JaxPipelineComputation],\n                         layer_indices: Sequence[int],\n                         accumulator_mapping: Dict[Var, Var],\n                         acc_grad_outvars: Sequence[Var]):\n    \"\"\"\n    For each module, select the layers and get the accumulator mapping and\n    required outvars for each module.\n\n    Args:\n        layers: all layers.\n        layer_indices: a list of layer ids within the module.\n        accumulator_mapping: the 
mapping from accumulator input to output,\n            used to determine the donation.\n        acc_grad_invars: the invars of the accumulator gradient layers.\n        acc_grad_outvars: the outvars of the accumulator gradient layers.\n\n    Returns:\n        module: a list of layers that belong to the module.\n        module_accumulator_mappings: accumulator mapping for the module.\n        module_required_outvars: required outvars for the module.\n    \"\"\"\n    reversed_accumulator_mapping = {\n        v: k for k, v in accumulator_mapping.items()\n    }\n\n    gensym_fn = gensym([layer.closed_jaxpr().jaxpr for layer in layers])\n    num_layers = len(layers)\n    local_used = OrderedSet()\n    new_layers = []\n    module_required_outvars = OrderedSet()\n    module_accumulator_mapping = {}\n    used_by_other_layers_set = OrderedSet(acc_grad_outvars)\n    for layer_id in reversed(range(num_layers)):\n        layer = layers[layer_id]\n        if layer_id not in layer_indices:\n            used_by_other_layers_set.update(layer.invars)\n            continue\n        layer_donation, new_layer = (\n            get_local_donation_mapping_and_add_missing_invars(\n                layer, reversed_accumulator_mapping, gensym_fn))\n        for invar in layer_donation:\n            assert (invar not in local_used and\n                    invar not in used_by_other_layers_set)\n\n        required_outvars = [\n            var for var in new_layer.outvars if var in used_by_other_layers_set\n        ]\n        module_accumulator_mapping.update(layer_donation)\n        module_required_outvars.update(required_outvars)\n        local_used.update(new_layer.invars)\n        new_layers.append(new_layer)\n    return (reversed(new_layers), module_accumulator_mapping,\n            module_required_outvars)\n\n\ndef split_sharding_specs(layers: Sequence[JaxPipelineComputation],\n                         mixed_jaxpr: ClosedJaxpr, in_sharding_specs,\n                         
out_sharding_specs):\n    \"\"\"\n    Split sharding specs of layers.\n\n    Some intermediate sharding specs are missed,\n    but they are not across meshes so this does not matter.\n    \"\"\"\n    in_sharding_dict = dict(zip(mixed_jaxpr.jaxpr.invars, in_sharding_specs))\n    out_sharding_dict = dict(zip(mixed_jaxpr.jaxpr.outvars, out_sharding_specs))\n    layer_in_sharding_specs = []\n    layer_out_sharding_specs = []\n    for layer in layers:\n        layer_in_sharding_specs.append(\n            [in_sharding_dict.get(var, None) for var in layer.invars])\n        layer_out_sharding_specs.append(\n            [out_sharding_dict.get(var, None) for var in layer.outvars])\n    return layer_in_sharding_specs, layer_out_sharding_specs\n\n\ndef generate_stage_info(all_layers, selected_indices,\n                        global_accumulator_mapping, acc_grad_invars,\n                        acc_grad_outvars, name, apply_grad_layers,\n                        apply_grad_info):\n    \"\"\"Combine selected layers together for profiling.\"\"\"\n    modules = []\n    module_accumulator_mappings = []\n    module_required_outvars = []\n    for layer_indices in selected_indices:\n        module, module_accumulator_mapping, required_outvars = (\n            select_module_layers(all_layers, layer_indices,\n                                 global_accumulator_mapping, acc_grad_outvars))\n        modules.append(module)\n        module_accumulator_mappings.append(module_accumulator_mapping)\n        module_required_outvars.append(required_outvars)\n\n    n_modules = len(modules)\n    module_jaxprs = [\n        [layer.closed_jaxpr() for layer in layers] for layers in modules\n    ]\n\n    module_names = [f\"{name}_acc_grad_{i}\" for i in range(n_modules)]\n    module_merged_jaxprs = []\n    module_profile_configs = []\n\n    all_modules_donation_mapping = {}\n    all_modules_donate_invars = []\n    all_modules_outvars = OrderedSet()\n    all_modules_acc_grad_outvars_indices = []\n    
acc_grad_invars_set = OrderedSet(acc_grad_invars)\n    acc_grad_outvars_set = OrderedSet(acc_grad_outvars)\n    for module_name, jaxprs, accumulator_mapping, required_outvars in zip(\n            module_names, module_jaxprs, module_accumulator_mappings,\n            module_required_outvars):\n        merged_jaxpr = merge_marked_jaxprs_with_named_call(\n            jaxprs, required_outvars, accumulator_mapping, module_name)\n        outvars_set = set(merged_jaxpr.jaxpr.outvars)\n        is_donated = tuple(invar in accumulator_mapping and\n                           accumulator_mapping[invar] in outvars_set\n                           for invar in merged_jaxpr.jaxpr.invars)\n        acc_grad_invars_indices = tuple(\n            i for i, outvar in enumerate(merged_jaxpr.jaxpr.invars)\n            if outvar in acc_grad_invars_set)\n        acc_grad_outvars_indices = tuple(\n            i for i, outvar in enumerate(merged_jaxpr.jaxpr.outvars)\n            if outvar in acc_grad_outvars_set)\n        invar_names = tuple(repr(var) for var in merged_jaxpr.jaxpr.invars)\n        outvar_names = tuple(repr(var) for var in merged_jaxpr.jaxpr.outvars)\n        invar_avals = tuple(var.aval for var in merged_jaxpr.jaxpr.invars)\n        outvar_avals = tuple(var.aval for var in merged_jaxpr.jaxpr.outvars)\n        profile_config = ModuleProfileConfig(invar_names, outvar_names,\n                                             invar_avals, outvar_avals,\n                                             is_donated,\n                                             acc_grad_invars_indices,\n                                             acc_grad_outvars_indices)\n        module_merged_jaxprs.append(merged_jaxpr)\n        module_profile_configs.append(profile_config)\n        all_modules_donate_invars.append(is_donated)\n        all_modules_donation_mapping.update(accumulator_mapping)\n        all_modules_outvars.update(merged_jaxpr.jaxpr.outvars)\n        
all_modules_acc_grad_outvars_indices.append(acc_grad_outvars_indices)\n\n    if len(apply_grad_layers) > 0:\n        apply_grad_donation, apply_grad_outvars = apply_grad_info\n        apply_grad_module_name = \"_\".join([name, APPLY_GRAD_MARKER_SUFFIX])\n        merged_apply = merge_marked_jaxprs_with_named_call(\n            [layer.closed_jaxpr() for layer in apply_grad_layers],\n            apply_grad_outvars, apply_grad_donation, name + \"_apply\")\n        outvars_set = set(merged_apply.jaxpr.outvars)\n        is_donated = tuple(invar in apply_grad_donation and\n                           apply_grad_donation[invar] in outvars_set\n                           for invar in merged_apply.jaxpr.invars)\n        apply_only_invars = OrderedSet(merged_apply.jaxpr.invars)\n        for module_jaxpr in module_merged_jaxprs:\n            apply_only_invars = apply_only_invars.difference(\n                module_jaxpr.jaxpr.invars)\n            apply_only_invars = apply_only_invars.difference(\n                module_jaxpr.jaxpr.outvars)\n        apply_info = ApplyGradConfig(merged_apply.jaxpr.invars,\n                                     apply_only_invars)\n        module_names.append(apply_grad_module_name)\n        module_merged_jaxprs.append(merged_apply)\n        all_modules_donate_invars.append(is_donated)\n        all_modules_donation_mapping.update(apply_grad_donation)\n        all_modules_outvars.update(merged_apply.jaxpr.outvars)\n    else:\n        apply_info = None\n\n    all_modules_merged_jaxpr, all_modules_is_donated = (\n        merge_unmarked_with_call(module_merged_jaxprs, module_names,\n                                 all_modules_outvars,\n                                 all_modules_donation_mapping))\n    hlo = jaxpr_to_hlo(name, all_modules_merged_jaxpr, all_modules_is_donated)\n    compile_config = CompileConfig(hlo, module_names, all_modules_donate_invars,\n                                   all_modules_acc_grad_outvars_indices)\n    stage_config = 
StageConfig(n_modules, compile_config,\n                               module_profile_configs, apply_info)\n    return stage_config\n\n\ndef create_collective_group(src_mesh: PhysicalDeviceMesh,\n                            dst_mesh: PhysicalDeviceMesh) -> CollectiveGroup:\n    \"\"\"Create a dummy collective group for profiling.\"\"\"\n    cg = CollectiveGroup(\n        OrderedSet(src_mesh.device_strs + dst_mesh.device_strs), src_mesh,\n        dst_mesh)\n    cg.instantiate()\n    return cg\n\n\ndef dummy_resharding_send_recv_strategy(spec: ReshardingTaskSpec):\n    \"\"\"Generates a dummy sharding strategy for profiling.\"\"\"\n    src_loads = {src: 0 for src in spec.src.device_mesh.device_strs}\n    dst_loads = {dst: 0 for dst in spec.dst.device_mesh.device_strs}\n    return (\n        CrossMeshCommunicator._generate_send_recv_resharding_strategy_by_loads(  # pylint: disable=protected-access\n            spec, src_loads, dst_loads))\n\n\ndef dummy_resharding_broadcast_strategy(spec: ReshardingTaskSpec):\n    \"\"\"Generates a dummy sharding strategy for profiling.\"\"\"\n    src_loads = {src: 0 for src in spec.src.device_mesh.device_strs}\n    dst_loads = {dst: 0 for dst in spec.dst.device_mesh.device_strs}\n    return (\n        CrossMeshCommunicator._generate_broadcast_resharding_strategy_by_loads(  # pylint: disable=protected-access\n            spec, src_loads, dst_loads))\n\n\n# FIXME(Hao): this function is broken by recent updates. Use with caution.\ndef profile_layer_communication_cost(\n        src: JaxPipelineComputation, dst: JaxPipelineComputation,\n        src_outvar_sharding_spec, dst_invar_sharding_spec,\n        src_mesh: VirtualPhysicalMesh, dst_mesh: VirtualPhysicalMesh,\n        collective_group: CollectiveGroup):\n    \"\"\"Profile communication cost for given two stages.\n\n    It ignores the global load balance, but instead only consider the balance of\n    the task. 
However, as the communication is sequential and SPMD, this does\n    not hurt much.\n    \"\"\"\n    src_outvars = {v: idx for idx, v in enumerate(src.outvars)}\n\n    backup_use_dummy_value = global_config.use_dummy_value_for_benchmarking\n    global_config.use_dummy_value_for_benchmarking = True\n    tasks = []\n    src_phy_mesh = collective_group.src_mesh\n    for idx, invar in enumerate(dst.invars):\n        if invar in src_outvars:\n            out_sharding_spec = src_outvar_sharding_spec[src_outvars[invar]]\n            in_sharding_spec = dst_invar_sharding_spec[idx]\n            src_array = VirtualDistributedArray(device_mesh=src_mesh,\n                                                aval=invar.aval,\n                                                sharding_spec=out_sharding_spec)\n            dst_array = VirtualDistributedArray(device_mesh=dst_mesh,\n                                                aval=invar.aval,\n                                                sharding_spec=in_sharding_spec)\n            task_spec = ReshardingTaskSpec(src_array, dst_array, [])\n            # create resharding strategy, ignore global load balance\n            if global_config.resharding_mode == \"send_recv\":\n                strategy = dummy_resharding_send_recv_strategy(task_spec)\n            else:\n                strategy = dummy_resharding_broadcast_strategy(task_spec)\n            task_spec.set_resharding_strategy(strategy)\n            # create distributed array as dummy inputs\n            input_indices = pxla.spec_to_indices(invar.aval.shape,\n                                                 out_sharding_spec)\n            remote_ref = _shard_device_array(jnp.zeros_like(invar.aval),\n                                             src_phy_mesh, input_indices)\n            DistributedArray(src_phy_mesh, invar.aval, in_sharding_spec,\n                             remote_ref, input_indices)\n            if global_config.resharding_mode == \"send_recv\":\n               
 task = SymbolicReshardingTask(task_spec, collective_group,\n                                              collective_group.src_mesh,\n                                              collective_group.dst_mesh)\n            else:\n                task = SymbolicBroadcastReshardingTask(\n                    task_spec, collective_group, collective_group.src_mesh,\n                    collective_group.dst_mesh)\n            tasks.append(task)\n\n    for task in tasks:\n        task.put_send_recv_tasks()\n    src_phy_mesh.sync_workers()\n    collective_group.dst_mesh.sync_workers()\n    results = []\n    for task in tasks:\n        results.append(task.do_prepared(task.src_array, True))\n\n    tot_cost = sum(max(result) for result in results)\n\n    global_config.use_dummy_value_for_benchmarking = backup_use_dummy_value\n    return tot_cost\n\n\ndef _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape):\n    \"\"\"Compute bytes of avals with given sharding proto and logical\n    mesh.\"\"\"\n\n    def get_byte(shape, dtype):\n        return np.prod(shape) * np.dtype(dtype).itemsize\n\n    if len(avals) == 0:\n        return ()\n\n    if np.prod(logical_mesh_shape) == 1:\n        return tuple(get_byte(aval.shape, aval.dtype) for aval in avals)\n\n    sharded_shapes = [\n        get_shard_shape(aval, spec)\n        for aval, spec in zip(avals, sharding_specs)\n    ]\n\n    return tuple(\n        get_byte(shape, aval.dtype)\n        for shape, aval in zip(sharded_shapes, avals))\n\n\ndef get_sharded_size_by_proto(serialized_proto,\n                              avals,\n                              logical_mesh_shape,\n                              tuple_proto=True):\n    \"\"\"Compute bytes of serialized proto.\"\"\"\n\n    if len(avals) == 0:\n        return ()\n\n    if np.prod(logical_mesh_shape) == 1:\n        sharding_specs = None\n    else:\n        if tuple_proto:\n            hlo_sharding = xe.HloSharding(serialized_proto[0])\n            sharding_specs = 
hlo_sharding_to_sharding_spec(\n                hlo_sharding, avals, logical_mesh_shape)\n        else:\n            sharding_specs = [\n                hlo_sharding_to_sharding_spec(xe.HloSharding(proto), aval,\n                                              logical_mesh_shape)\n                for (proto, aval) in zip(serialized_proto, avals)\n            ]\n    return _get_sharded_sizes(sharding_specs, avals, logical_mesh_shape)\n\n\ndef compute_apply_grad_invar_size(input_sharding_protos,\n                                  config: ApplyGradConfig, logical_mesh_shape):\n    \"\"\"Compute the size of parameters only used in apply gradient period.\n\n    These parameters are never used in compute gradient period but stored on\n    the GPU, so they take memory and influence max_n_succ_stages.\n    \"\"\"\n    if config.invars is None:\n        assert config.apply_grad_only_invars is None\n        return 0\n    avals = [v.aval for v in config.invars]\n    if np.prod(logical_mesh_shape) == 1:\n        selected_sharding_specs = None\n        ordered_selected_vars = list(config.apply_grad_only_invars)\n    else:\n        assert len(input_sharding_protos) == len(config.invars)\n        sharding_specs = [\n            hlo_sharding_to_sharding_spec(xe.HloSharding(sharding_proto), aval,\n                                          logical_mesh_shape)\n            for sharding_proto, aval in zip(input_sharding_protos, avals)\n        ]\n        ordered_selected_vars = []\n        selected_sharding_specs = []\n        for var, spec in zip(config.invars, sharding_specs):\n            if var in config.apply_grad_only_invars:\n                ordered_selected_vars.append(var)\n                selected_sharding_specs.append(spec)\n    ordered_selected_avals = [v.aval for v in ordered_selected_vars]\n    ordered_selected_names = [repr(v) for v in ordered_selected_vars]\n    return (ordered_selected_names,\n            _get_sharded_sizes(selected_sharding_specs, 
ordered_selected_avals,\n                               logical_mesh_shape))\n"
  },
  {
    "path": "alpa/serialization.py",
    "content": "\"\"\"\nSerialization utilities for Alpa.\nSupport DistributedArray and ReplicatedDistributedArray serialization in Alpa.\n\"\"\"\n\nimport logging\nimport os\nimport pickle\nfrom typing import Union\n\nfrom flax.serialization import to_state_dict, from_state_dict\nimport jax\nfrom jax._src.tree_util import tree_flatten, tree_leaves, tree_unflatten, PyTreeDef\nimport msgpack\nimport numpy as np\n\nfrom alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,\n                              get_global_virtual_physical_mesh,\n                              get_global_physical_mesh)\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\ndef _dfs_pytree(tree, prefix):\n    paths = []\n    if isinstance(tree, dict):\n        for k, v in tree.items():\n            paths += _dfs_pytree(v, prefix + \".\" + str(k))\n    elif isinstance(tree, (tuple, list)):\n        for i, v in enumerate(tree):\n            paths += _dfs_pytree(v, prefix + \".\" + str(i))\n    elif tree is not None:\n        # Leaf node\n        paths.append(prefix)\n    return paths\n\n\ndef _save_unsharded_array(ckpt_dir, arr):\n    os.makedirs(ckpt_dir, exist_ok=True)\n    shard_name = \"shard_0.0\"\n    metadata = {\n        \"global_shape\": arr.shape,\n        \"dtype\": arr.dtype,\n        \"shard_names\": [shard_name],\n        \"shard_indices\": None,\n    }\n    with open(os.path.join(ckpt_dir, shard_name), \"wb\") as datafile:\n        np.save(datafile, arr)\n    with open(os.path.join(ckpt_dir, \"metadata_0\"), \"wb\") as metafile:\n        pickle.dump(metadata, metafile)\n\n\ndef load_sharded_array(ckpt_dir, metadatas):\n    \"\"\"\n        Used by MeshHostWorker.load_tensor to first load the entire shared\n        array from disk.\n    \"\"\"\n    assert len(metadatas) > 0\n    with open(os.path.join(ckpt_dir, metadatas[0]), \"rb\") as metafile:\n        meta = pickle.load(metafile)\n    if meta[\"shard_indices\"] is None:\n        return 
np.load(os.path.join(ckpt_dir, meta[\"shard_names\"][0]))\n    entire_array = np.empty(meta[\"global_shape\"], meta[\"dtype\"])\n    for metadata in metadatas:\n        with open(os.path.join(ckpt_dir, metadata), \"rb\") as metafile:\n            meta = pickle.load(metafile)\n        for shard_name, shard_indice in zip(meta[\"shard_names\"],\n                                            meta[\"shard_indices\"]):\n            entire_array[shard_indice] = np.load(\n                os.path.join(ckpt_dir, shard_name))\n    return entire_array\n\n\ndef save_checkpoint(ckpt_dir: Union[str, os.PathLike],\n                    target: PyTreeDef,\n                    step: int,\n                    local_cache_dir: Union[str, os.PathLike, None] = None):\n    \"\"\"\n        Save a checkpoint of the `target` to `ckpt_dir`.\n\n        If you want to save a model which has been parallelized on multiple\n        nodes by alpa, `ckpt_dir` should be a shared filesystem path.\n        It is also recommended to provide a `local_cache_dir` on local disk\n        to speed up the saving process because `save_checkpoint` will return\n        as soon as each node has saved its shard of the model into\n        `local_cache_dir`. The DaemonMoveWorkers will then move these local\n        shards into `ckpt_dir` in the background.\n\n        If you just want to save an unparallelized model or the model is\n        parallelized on a single node, `ckpt_dir` should be a normal\n        path on local disk, and the `local_cache_dir` should be None.\n\n        Args:\n           ckpt_dir: the directory where this checkpoint will be saved.\n           target: serializable flax object, usually a trainState.\n           step: training step number or other metric number.\n           local_cache_dir: If not None, `ckpt_dir` should be a\n           shared filesystem path, and this function will return as soon as\n           the shards have been saved to this local directory. 
DaemonMoveWorkers\n           will move these shards into `ckpt_dir` in the background.\n    \"\"\"\n    # create directories if not exist\n    os.makedirs(ckpt_dir, exist_ok=True)\n    if local_cache_dir is not None:\n        os.makedirs(local_cache_dir, exist_ok=True)\n\n    target = to_state_dict(target)\n    flat_dirs = _dfs_pytree(target, \"state\")\n    flat_target, target_tree = tree_flatten(target)\n    flat_metadata = []\n    assert (len(flat_dirs) == len(flat_target))\n    for arr_dir, x in zip(flat_dirs, flat_target):\n        arr_path = os.path.join(ckpt_dir, arr_dir)\n        if local_cache_dir is None:\n            arr_cache_path = None\n        else:\n            arr_cache_path = os.path.join(local_cache_dir, arr_dir)\n        if isinstance(x, (DistributedArray, ReplicatedDistributedArray,\n                          np.ndarray, jax.xla.DeviceArray)):\n            if isinstance(x, DistributedArray):\n                x.save(arr_path, arr_cache_path)\n            elif isinstance(x, ReplicatedDistributedArray):\n                x.replica.save(arr_path, arr_cache_path)\n            elif isinstance(x, (np.ndarray, jax.xla.DeviceArray)):\n                _save_unsharded_array(arr_path, x)\n            flat_metadata.append(arr_dir)\n        else:\n            flat_metadata.append(x)\n\n    metapath = os.path.join(ckpt_dir, f\"checkpoint_{step}\")\n    metadata = tree_unflatten(target_tree, flat_metadata)\n    with open(metapath, \"wb\") as metafile:\n        metafile.write(msgpack.packb(metadata))\n\n\ndef restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int,\n                       placement_specs: PyTreeDef):\n    \"\"\"\n        Restore the specified checkpoint from `ckpt_dir` and reshard it\n        according to the `placement_specs`.\n\n        Args:\n            ckpt_dir: directory of checkpoints to restore from. 
If you\n            do not have a shared filesystem, each host needs a copy of\n            the checkpoint on its local disk at the same path.\n            step: step number to load.\n            placement_specs: shardingSpec and deviceMesh placement info\n            for loading.\n    \"\"\"\n    metapath = os.path.join(ckpt_dir, f\"checkpoint_{step}\")\n    with open(metapath, \"rb\") as metafile:\n        metadata = from_state_dict(placement_specs,\n                                   msgpack.unpackb(metafile.read()))\n\n    state_paths, state_tree = tree_flatten(metadata)\n    flat_info = tree_leaves(placement_specs)\n    flat_load_state = []\n    mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group\n    physical_mesh = get_global_physical_mesh()\n\n    assert mesh_group is not None or physical_mesh is not None\n\n    for path, info in zip(state_paths, flat_info):\n        if info is None:\n            logger.warning(\"Variable is not used, skip loading it\")\n            flat_load_state.append(None)\n        elif mesh_group is None:\n            dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path),\n                                             info.aval, physical_mesh,\n                                             info.sharding_specs[0])\n            flat_load_state.append(dist_arr)\n        elif len(info.mesh_ids) == 1:\n            dist_arr = DistributedArray.load(os.path.join(ckpt_dir,\n                                                          path), info.aval,\n                                             mesh_group[info.mesh_ids[0]],\n                                             info.sharding_specs[0])\n            flat_load_state.append(dist_arr)\n        else:\n            meshes, arrays = [], []\n            for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):\n                meshes.append(mesh_group[mesh_id])\n                dist_arr = DistributedArray.load(os.path.join(ckpt_dir,\n                             
                                 path), info.aval,\n                                                 mesh_group[mesh_id], spec)\n                arrays.append(dist_arr)\n            flat_load_state.append(ReplicatedDistributedArray(meshes, arrays))\n\n    return tree_unflatten(state_tree, flat_load_state)\n"
  },
  {
    "path": "alpa/serve/__init__.py",
    "content": "\"\"\"Alpa serving backend\"\"\"\nfrom alpa.serve.controller import CONTROLLER_NAME, run_controller\n"
  },
  {
    "path": "alpa/serve/controller.py",
    "content": "#pylint: disable=missing-class-docstring, raise-missing-from\n\"\"\"Central controller\"\"\"\nimport asyncio\nfrom collections import defaultdict\nimport dataclasses\nimport logging\nimport os\nimport pickle\nimport socket\nimport time\nfrom typing import Callable, List, Dict, Optional, Tuple, Any, Union\n\nimport ray\nfrom ray.actor import ActorHandle\nfrom ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy\nfrom starlette.middleware.cors import CORSMiddleware\nimport uvicorn\n\nfrom alpa.api import init\nfrom alpa.serve.http_util import (HTTPRequestWrapper, receive_http_body,\n                                  Response, set_socket_reuse_port, ASGIHandler,\n                                  build_starlette_request, new_port,\n                                  RelayException, make_error_response)\n\nlogger = logging.getLogger(__file__)\n\nCONTROLLER_NAME = \"controller\"\nMAX_REPLICA_FAILURE_RETRIES = 10\nDISCONNECT_ERROR_CODE = \"disconnection\"\nSOCKET_REUSE_PORT_ENABLED = (os.environ.get(\"SERVE_SOCKET_REUSE_PORT_ENABLED\",\n                                            \"1\") == \"1\")\n\n\n@dataclasses.dataclass\nclass CreateInfo:\n    model_def: Any\n    init_args: Optional[List]\n    init_kwargs: Optional[Dict]\n\n    def append_init_args(self,\n                         init_args: Optional[List] = None,\n                         init_kwargs: Optional[Dict] = None):\n        return CreateInfo(\n            self.model_def,\n            self.init_args + init_args if init_args else self.init_args,\n            dict(self.init_kwargs).update(init_kwargs)\n            if init_kwargs else self.init_kwargs,\n        )\n\n\n@dataclasses.dataclass\nclass ModelInfo:\n    create_info: CreateInfo\n    managers: List[ActorHandle]\n    next_pt: int\n\n\n@ray.remote(num_cpus=1)\nclass DeviceMeshGroupManager:\n\n    def __init__(self, virtual_mesh_shape: Optional[Tuple[int]] = None):\n        if virtual_mesh_shape:\n            
init(cluster=\"ray\",\n                 num_nodes=virtual_mesh_shape[0],\n                 num_devices_per_node=virtual_mesh_shape[1])\n        else:\n            init(cluster=\"ray\")\n\n        # Dict[str, object]\n        self.replicas = {}\n\n    def create_replica(self, name: str, create_info: CreateInfo):\n        assert name not in self.replicas\n\n        model_def, args, kwargs = (create_info.model_def, create_info.init_args,\n                                   create_info.init_kwargs)\n        args = args or []\n        kwargs = kwargs or {}\n        self.replicas[name] = model_def(*args, **kwargs)\n\n    def delete_replica(self, name: str):\n        assert name in self.replicas\n        del self.replicas[name]\n\n    async def handle_request(self, name: str, request_wrapper: bytes):\n        request_wrapper = pickle.loads(request_wrapper)\n        request = build_starlette_request(request_wrapper)\n        try:\n            response = await self.replicas[name].handle_request(request)\n            return response\n        except Exception as e:  # pylint: disable=broad-except\n            return RelayException(e)\n\n\n@ray.remote(num_cpus=0)\nclass Controller:\n\n    def __init__(self,\n                 host: str,\n                 port: int,\n                 root_path: str,\n                 ssl_keyfile: Optional[str] = None,\n                 ssl_certfile: Optional[Union[str, os.PathLike]] = None):\n        self.host = host\n        self.port = port\n        self.root_path = root_path\n        self.ssl_keyfile = ssl_keyfile\n        self.ssl_certfile = ssl_certfile\n\n        self.manager_lock = defaultdict(asyncio.Lock)\n\n        # Dict[str -> ModelInfo]\n        self.model_info = {}\n        self.mesh_group_managers = {}\n\n        # Launch http server\n        self.setup_complete = asyncio.Event()\n        self.http_server_task = asyncio.get_event_loop().create_task(\n            self.run_http_server())\n\n    async def launch_mesh_group_manager(\n 
           self,\n            group_id: int,\n            virtual_mesh_shape: Optional[Tuple[int]] = None,\n            num_gpus: int = 0):\n        assert group_id not in self.mesh_group_managers, (\n            f\"Mesh group {group_id} is already launched\")\n        self.mesh_group_managers[group_id] = (DeviceMeshGroupManager.options(\n            name=f\"mesh_group_manager_{group_id}\",\n            num_gpus=num_gpus).remote(virtual_mesh_shape))\n\n    async def register_model(self,\n                             name: str,\n                             model_def: Callable,\n                             init_args: Optional[List] = None,\n                             init_kwargs: Optional[Dict] = None,\n                             override: bool = False):\n        async with self.manager_lock[name]:\n            if name in self.model_info:\n                if override:\n                    for manager in self.model_info[name].managers:\n                        await manager.delete_replica.remote(name)\n                else:\n                    raise ValueError(f\"Model {name} is already registered\")\n\n            self.model_info[name] = ModelInfo(\n                CreateInfo(model_def, init_args, init_kwargs), [], 0)\n\n    async def create_replica(self,\n                             name: str,\n                             mesh_group_id: int,\n                             append_init_args: Optional[List] = None,\n                             append_init_kwargs: Optional[Dict] = None):\n        async with self.manager_lock[name]:\n            assert mesh_group_id in self.mesh_group_managers, (\n                f\"Group {mesh_group_id} does not exist\")\n            model_info = self.model_info[name]\n            manager = self.mesh_group_managers[mesh_group_id]\n            assert manager not in model_info.managers\n            create_info = model_info.create_info.append_init_args(\n                append_init_args, append_init_kwargs)\n\n            
logger.info(\n                f\"Create replica of model={name} on mesh={mesh_group_id}\")\n            await manager.create_replica.remote(name, create_info)\n            model_info.managers.append(manager)\n\n    async def handle_asgi(self, scope, receive, send):\n        assert scope[\"type\"] == \"http\"\n        scope[\"tstamp\"] = time.time()\n\n        # Receive request\n        http_body_bytes = await receive_http_body(scope, receive, send)\n        request_wrapper = HTTPRequestWrapper(scope, http_body_bytes)\n        request = build_starlette_request(request_wrapper)\n        request_wrapper = pickle.dumps(request_wrapper)\n\n        # Route\n        try:\n            obj = await request.json()\n\n            assert \"model\" in obj, \"Model name is not specified in the request.\"\n            name = obj[\"model\"]\n\n            assert name in self.model_info, (\n                f\"Model '{name}' is not registered.\")\n            model_info = self.model_info[name]\n            assert model_info.managers, (\n                f\"No replica of model '{name}' is created.\")\n            manager = model_info.managers[model_info.next_pt]\n            model_info.next_pt = (model_info.next_pt + 1) % len(\n                model_info.managers)\n\n            response = await manager.handle_request.remote(\n                name, request_wrapper)\n            if isinstance(response, RelayException):\n                response = make_error_response(response)\n                status_code = 400\n            else:\n                status_code = 200\n        except Exception as e:  # pylint: disable=broad-except\n            response = make_error_response(e)\n            status_code = 400\n\n        await Response(response,\n                       status_code=status_code).send(scope, receive, send)\n\n    def get_info(self):\n        return {\n            \"host\": self.host,\n            \"port\": self.port,\n            \"root_path\": self.root_path,\n        }\n\n    
##### HTTP related functions #####\n    async def ready(self):\n        \"\"\"Returns when HTTP proxy is ready to serve traffic.\n        Or throw exception when it is not able to serve traffic.\n        \"\"\"\n        done_set, _ = await asyncio.wait(\n            [\n                # Either the HTTP setup has completed.\n                # The event is set inside self.run.\n                self.setup_complete.wait(),\n                # Or self.run errored.\n                self.http_server_task,\n            ],\n            return_when=asyncio.FIRST_COMPLETED,\n        )\n\n        # Return None, or re-throw the exception from self.running_task.\n        return await done_set.pop()\n\n    async def run_http_server(self):\n        sock = socket.socket()\n        if SOCKET_REUSE_PORT_ENABLED:\n            set_socket_reuse_port(sock)\n\n        try:\n            sock.bind((self.host, self.port))\n        except OSError:\n            # The OS failed to bind a socket to the given host and port.\n            raise ValueError(\n                f\"Failed to bind HTTP proxy to '{self.host}:{self.port}'.\"\n                f\"Please make sure your http-host and http-port are \"\n                f\"specified correctly.\")\n\n        # Note(simon): we have to use lower level uvicorn Config and Server\n        # class because we want to run the server as a coroutine. 
The only\n        # alternative is to call uvicorn.run which is blocking.\n        app = ASGIHandler(self)\n        app = CORSMiddleware(\n            app,\n            allow_origins=[\"*\"],\n            allow_methods=[\"*\"],\n            allow_headers=[\"*\"],\n        )\n\n        config = uvicorn.Config(\n            app,\n            host=self.host,\n            port=self.port,\n            root_path=self.root_path,\n            lifespan=\"off\",\n            access_log=False,\n            ssl_keyfile=self.ssl_keyfile,\n            ssl_certfile=self.ssl_certfile,\n        )\n        server = uvicorn.Server(config=config)\n\n        # TODO(edoakes): we need to override install_signal_handlers here\n        # because the existing implementation fails if it isn't running in\n        # the main thread and uvicorn doesn't expose a way to configure it.\n        server.install_signal_handlers = lambda: None\n\n        self.setup_complete.set()\n        await server.serve(sockets=[sock])\n\n\ndef run_controller(host,\n                   port=None,\n                   root_path=\"/\",\n                   name=CONTROLLER_NAME,\n                   ssl_keyfile: Optional[str] = None,\n                   ssl_certfile: Optional[Union[str, os.PathLike]] = None):\n    controller = Controller.options(\n        name=name,\n        scheduling_strategy=NodeAffinitySchedulingStrategy(\n            node_id=ray.get_runtime_context().node_id,\n            soft=False,\n        )).remote(\n            host=host,\n            port=port or new_port(),\n            root_path=root_path,\n            ssl_keyfile=ssl_keyfile,\n            ssl_certfile=ssl_certfile,\n        )\n    ray.get(controller.ready.remote())\n    return controller\n"
  },
  {
    "path": "alpa/serve/http_util.py",
    "content": "# pylint: skip-file\n\"\"\"\nAdopted from\nhttps://github.com/ray-project/ray/blob/master/python/ray/serve/_private/http_util.py\nhttps://github.com/ray-project/ray/blob/master/python/ray/serve/_private/utils.py\n\"\"\"\nimport asyncio\nfrom dataclasses import dataclass\nimport inspect\nimport json\nimport random\nimport socket\nimport traceback\nfrom typing import Any, Dict, Type\n\nfrom fastapi.encoders import jsonable_encoder\nimport numpy as np\nimport starlette.responses\nimport starlette.requests\nfrom starlette.types import Send, ASGIApp\n\ntry:\n    import pandas as pd\nexcept ImportError:\n    pd = None\n\n\n@dataclass\nclass HTTPRequestWrapper:\n    scope: Dict[Any, Any]\n    body: bytes\n\n\ndef build_starlette_request(request_wrapper):\n    \"\"\"Build and return a Starlette Request from ASGI payload.\n\n    This function is intended to be used immediately before task invocation\n    happens.\n    \"\"\"\n    scope, serialized_body = request_wrapper.scope, request_wrapper.body\n\n    # Simulates receiving HTTP body from TCP socket.  In reality, the body has\n    # already been streamed in chunks and stored in serialized_body.\n    received = False\n\n    async def mock_receive():\n        nonlocal received\n\n        # If the request has already been received, starlette will keep polling\n        # for HTTP disconnect. We will pause forever. 
The coroutine should be\n        # cancelled by starlette after the response has been sent.\n        if received:\n            block_forever = asyncio.Event()\n            await block_forever.wait()\n\n        received = True\n        return {\n            \"body\": serialized_body,\n            \"type\": \"http.request\",\n            \"more_body\": False\n        }\n\n    return starlette.requests.Request(scope, mock_receive)\n\n\nclass Response:\n    \"\"\"ASGI compliant response class.\n\n    It is expected to be called in async context and pass along\n    `scope, receive, send` as in ASGI spec.\n\n    >>> from ray.serve.http_util import Response\n    >>> scope, receive = ... # doctest: +SKIP\n    >>> await Response({\"k\": \"v\"}).send(scope, receive, send) # doctest: +SKIP\n    \"\"\"\n\n    def __init__(self, content=None, status_code=200):\n        \"\"\"Construct a HTTP Response based on input type.\n\n        Args:\n            content: Any JSON serializable object.\n            status_code (int, optional): Default status code is 200.\n        \"\"\"\n        self.status_code = status_code\n        self.raw_headers = []\n\n        if content is None:\n            self.body = b\"\"\n            self.set_content_type(\"text\")\n        elif isinstance(content, bytes):\n            self.body = content\n            self.set_content_type(\"text\")\n        elif isinstance(content, str):\n            self.body = content.encode(\"utf-8\")\n            self.set_content_type(\"text-utf8\")\n        else:\n            # Delayed import since utils depends on http_util\n            self.body = json.dumps(\n                jsonable_encoder(content,\n                                 custom_encoder=serve_encoders)).encode()\n            self.set_content_type(\"json\")\n\n    def set_content_type(self, content_type):\n        if content_type == \"text\":\n            self.raw_headers.append([b\"content-type\", b\"text/plain\"])\n        elif content_type == 
\"text-utf8\":\n            self.raw_headers.append(\n                [b\"content-type\", b\"text/plain; charset=utf-8\"])\n        elif content_type == \"json\":\n            self.raw_headers.append([b\"content-type\", b\"application/json\"])\n        else:\n            raise ValueError(\"Invalid content type {}\".format(content_type))\n\n    async def send(self, scope, receive, send):\n        await send({\n            \"type\": \"http.response.start\",\n            \"status\": self.status_code,\n            \"headers\": self.raw_headers,\n        })\n        await send({\"type\": \"http.response.body\", \"body\": self.body})\n\n\nasync def receive_http_body(scope, receive, send):\n    body_buffer = []\n    more_body = True\n    while more_body:\n        message = await receive()\n        assert message[\"type\"] == \"http.request\"\n\n        more_body = message[\"more_body\"]\n        body_buffer.append(message[\"body\"])\n\n    return b\"\".join(body_buffer)\n\n\nclass RawASGIResponse(ASGIApp):\n    \"\"\"Implement a raw ASGI response interface.\n\n    We have to build this because starlette's base response class is\n    still too smart and perform header inference.\n    \"\"\"\n\n    def __init__(self, messages):\n        self.messages = messages\n\n    async def __call__(self, _scope, _receive, send):\n        for message in self.messages:\n            await send(message)\n\n    @property\n    def status_code(self):\n        return self.messages[0][\"status\"]\n\n\nclass ASGIHTTPSender(Send):\n    \"\"\"Implement the interface for ASGI sender to save data from varisous\n    asgi response type (fastapi, starlette, etc.)\n    \"\"\"\n\n    def __init__(self) -> None:\n        self.messages = []\n\n    async def __call__(self, message):\n        assert message[\"type\"] in (\"http.response.start\", \"http.response.body\")\n        self.messages.append(message)\n\n    def build_asgi_response(self) -> RawASGIResponse:\n        return 
RawASGIResponse(self.messages)\n\n\ndef make_fastapi_class_based_view(fastapi_app, cls: Type) -> None:\n    \"\"\"Transform the `cls`'s methods and class annotations to FastAPI routes.\n\n    Modified from\n    https://github.com/dmontagu/fastapi-utils/blob/master/fastapi_utils/cbv.py\n\n    Usage:\n    >>> from fastapi import FastAPI\n    >>> app = FastAPI() # doctest: +SKIP\n    >>> class A: # doctest: +SKIP\n    ...     @app.route(\"/{i}\") # doctest: +SKIP\n    ...     def func(self, i: int) -> str: # doctest: +SKIP\n    ...         return self.dep + i # doctest: +SKIP\n    >>> # just running the app won't work, here.\n    >>> make_fastapi_class_based_view(app, A) # doctest: +SKIP\n    >>> # now app can be run properly\n    \"\"\"\n    # Delayed import to prevent circular imports in workers.\n    from fastapi import Depends, APIRouter\n    from fastapi.routing import APIRoute\n\n    def get_current_servable_instance():\n        from ray import serve\n\n        return serve.get_replica_context().servable_object\n\n    # Find all the class method routes\n    class_method_routes = [\n        route for route in fastapi_app.routes if\n        # User defined routes must all be APIRoute.\n        isinstance(route, APIRoute)\n        # We want to find the route that's bound to the `cls`.\n        # NOTE(simon): we can't use `route.endpoint in inspect.getmembers(cls)`\n        # because the FastAPI supports different routes for the methods with\n        # same name. 
See #17559.\n        and (cls.__qualname__ in route.endpoint.__qualname__)\n    ]\n\n    # Modify these routes and mount it to a new APIRouter.\n    # We need to do this (instead of modifying in place) because we want to use\n    # the later fastapi_app.include_router to re-run the dependency analysis\n    # for each routes.\n    new_router = APIRouter()\n    for route in class_method_routes:\n        fastapi_app.routes.remove(route)\n\n        # This block just adds a default values to the self parameters so that\n        # FastAPI knows to inject the object when calling the route.\n        # Before: def method(self, i): ...\n        # After: def method(self=Depends(...), *, i):...\n        old_endpoint = route.endpoint\n        old_signature = inspect.signature(old_endpoint)\n        old_parameters = list(old_signature.parameters.values())\n        if len(old_parameters) == 0:\n            # TODO(simon): make it more flexible to support no arguments.\n            raise RayServeException(\n                \"Methods in FastAPI class-based view must have ``self`` as \"\n                \"their first argument.\")\n        old_self_parameter = old_parameters[0]\n        new_self_parameter = old_self_parameter.replace(\n            default=Depends(get_current_servable_instance))\n        new_parameters = [new_self_parameter] + [\n            # Make the rest of the parameters keyword only because\n            # the first argument is no longer positional.\n            parameter.replace(kind=inspect.Parameter.KEYWORD_ONLY)\n            for parameter in old_parameters[1:]\n        ]\n        new_signature = old_signature.replace(parameters=new_parameters)\n        setattr(route.endpoint, \"__signature__\", new_signature)\n        setattr(route.endpoint, \"_serve_cls\", cls)\n        new_router.routes.append(route)\n    fastapi_app.include_router(new_router)\n\n    routes_to_remove = list()\n    for route in fastapi_app.routes:\n        if not isinstance(route, 
APIRoute):\n            continue\n\n        # If there is a response model, FastAPI creates a copy of the fields.\n        # But FastAPI creates the field incorrectly by missing the outer_type_.\n        if route.response_model:\n            original_resp_fields = route.response_field.outer_type_.__fields__\n            cloned_resp_fields = (\n                route.secure_cloned_response_field.outer_type_.__fields__)\n            for key, field in cloned_resp_fields.items():\n                field.outer_type_ = original_resp_fields[key].outer_type_\n\n        # Remove endpoints that belong to other class based views.\n        serve_cls = getattr(route.endpoint, \"_serve_cls\", None)\n        if serve_cls is not None and serve_cls != cls:\n            routes_to_remove.append(route)\n    fastapi_app.routes[:] = [\n        r for r in fastapi_app.routes if r not in routes_to_remove\n    ]\n\n\ndef set_socket_reuse_port(sock: socket.socket) -> bool:\n    \"\"\"Mutate a socket object to allow multiple process listening on the same port.\n\n    Returns:\n        success: whether the setting was successful.\n    \"\"\"\n    try:\n        # These two socket options will allow multiple process to bind to the\n        # same port. Kernel will evenly load balance among the port listeners.\n        # Note: this will only work on Linux.\n        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        if hasattr(socket, \"SO_REUSEPORT\"):\n            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)\n        # In some Python binary distribution (e.g., conda py3.6), this flag\n        # was not present at build time but available in runtime. 
But\n        # Python relies on compiler flag to include this in binary.\n        # Therefore, in the absence of socket.SO_REUSEPORT, we try\n        # to use `15` which is value in linux kernel.\n        # https://github.com/torvalds/linux/blob/master/tools/include/uapi/asm-generic/socket.h#L27\n        else:\n            sock.setsockopt(socket.SOL_SOCKET, 15, 1)\n        return True\n    except Exception as e:\n        logger.debug(\n            f\"Setting SO_REUSEPORT failed because of {e}. SO_REUSEPORT is disabled.\"\n        )\n        return False\n\n\ndef new_port(lower_bound=10000, upper_bound=65535, denylist=None):\n    if not denylist:\n        denylist = set()\n    port = random.randint(lower_bound, upper_bound)\n    retry = 0\n    while port in denylist:\n        if retry > 100:\n            break\n        port = random.randint(lower_bound, upper_bound)\n        retry += 1\n    if retry > 100:\n        raise ValueError(\"Failed to find a new port from the range \"\n                         f\"{lower_bound}-{upper_bound}. 
Denylist: {denylist}\")\n    return port\n\n\nclass _ServeCustomEncoders:\n    \"\"\"Group of custom encoders for common types that's not handled by FastAPI.\"\"\"\n\n    @staticmethod\n    def encode_np_array(obj):\n        assert isinstance(obj, np.ndarray)\n        if obj.dtype.kind == \"f\":  # floats\n            obj = obj.astype(float)\n        if obj.dtype.kind in {\"i\", \"u\"}:  # signed and unsigned integers.\n            obj = obj.astype(int)\n        return obj.tolist()\n\n    @staticmethod\n    def encode_np_scaler(obj):\n        assert isinstance(obj, np.generic)\n        return obj.item()\n\n    @staticmethod\n    def encode_exception(obj):\n        assert isinstance(obj, Exception)\n        return str(obj)\n\n    @staticmethod\n    def encode_pandas_dataframe(obj):\n        assert isinstance(obj, pd.DataFrame)\n        return obj.to_dict(orient=\"records\")\n\n\nserve_encoders = {\n    np.ndarray: _ServeCustomEncoders.encode_np_array,\n    np.generic: _ServeCustomEncoders.encode_np_scaler,\n    Exception: _ServeCustomEncoders.encode_exception,\n}\n\nif pd is not None:\n    serve_encoders[pd.DataFrame] = _ServeCustomEncoders.encode_pandas_dataframe\n\n\nclass ASGIHandler:\n\n    def __init__(self, controller):\n        self.controller = controller\n\n    async def __call__(self, scope, receive, send):\n        \"\"\"Implements the ASGI protocol.\n\n        See details at:\n            https://asgi.readthedocs.io/en/latest/specs/index.html.\n        \"\"\"\n        await self.controller.handle_asgi(scope, receive, send)\n\n\nclass RelayException:\n\n    def __init__(self, e):\n        self.e = str(e)\n        self.stacktrace = \"\".join(traceback.format_tb(e.__traceback__))\n\n\ndef make_error_response(e):\n    if isinstance(e, RelayException):\n        msg = str(e.e)\n        stacktrace = e.stacktrace\n    else:\n        msg = str(e)\n        stacktrace = \"\".join(traceback.format_tb(e.__traceback__))\n\n    return {\"type\": \"error\", \"message\": 
msg, \"stacktrace\": stacktrace}\n"
  },
  {
    "path": "alpa/serve/run.py",
    "content": "\"\"\"Run a controller.\"\"\"\nimport argparse\n\nimport ray\n\nfrom alpa.serve.controller import run_controller\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--host\", type=str, default=\"localhost\")\n    parser.add_argument(\"--port\", type=int)\n    parser.add_argument(\"--root-path\", type=str, default=\"/\")\n    args = parser.parse_args()\n\n    ray.init(address=\"auto\", namespace=\"alpa_serve\")\n    controller = run_controller(args.host, args.port, args.root_path)\n\n    while True:\n        pass\n"
  },
  {
    "path": "alpa/shard_parallel/__init__.py",
    "content": ""
  },
  {
    "path": "alpa/shard_parallel/auto_sharding.py",
    "content": "\"\"\"Use the auto sharding pass in XLA.\n\nThe compilation passes and status of an HloModule:\n\nUNOPTIMIZED\n  |\n  |  spmd_simplification passes\n  |\n  |  auto_sharding pass\n  V\nSHARDING_ANNOTATED\n  |\n  |  spmd partitioner pass\n  V\nSPMD_PARTITIONED\n  |\n  |  HLO optimization passes\n  V\nFULLY_OPTIMIZED\n\"\"\"\nimport dataclasses\nimport logging\nimport multiprocessing\nimport os\nimport time\nimport traceback\nfrom typing import Sequence, Optional, Union, Tuple\nimport warnings\n\nimport numpy as np\nfrom jax._src.lib import xla_client as xc, xla_extension as xe\nfrom jax.core import ShapedArray\nfrom jax.interpreters import pxla\n\nfrom alpa.global_env import global_config\nfrom alpa.parallel_plan import StagePlan\nfrom alpa.timer import timers\nfrom alpa.util import check_arithmetic_sequence, get_compile_options, XlaPassContext\nfrom alpa.wrapped_hlo import HloStatus, WrappedHlo\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n# A constant to represent infinity\nINFINITY_COST = 1e13\n\n\n@dataclasses.dataclass\nclass AutoShardingOption:\n    \"\"\"Options of the auto-sharding solver.\"\"\"\n    # Whether enable auto-sharding. 
If it is False, then the solver\n    # does not run ILP but only uses the ShardingPropagation pass.\n    enable_auto_sharding: bool = True\n    # Whether to allow all-gather during re-sharding.\n    allow_all_gather: bool = True\n    # Whether to allow all-to-all during re-sharding.\n    allow_all_to_all: bool = True\n    # Whether to allow replicated parameters.\n    allow_replicated_parameters: bool = True\n    # Whether to forcibly generate data-parallel.\n    force_data_parallel: bool = False\n    # Forcibly map the batch dimension to a mesh dimension.\n    force_batch_dim_to_mesh_dim: Optional[int] = None\n    # Whether to forcibly generate a strategy similar to ZeRO optimizer stage 3.\n    force_zero_stage_3: bool = False\n    # The threshold of all-gather combiner if force_zero_stage_3 is true.\n    force_zero_stage_3_all_gather_threshold: int = 1 << 25\n    # Prefer reduce-scatter over all-reduce.\n    prefer_reduce_scatter: bool = False\n    # Allow mixed 1d mesh and 2d mesh shape.\n    allow_mixed_mesh_shape: bool = False\n    # Allow replicated dot computation.\n    allow_recompute_heavy_op: bool = False\n    # If it is not empty, forcibly use a simple heuristic instead of the ILP\n    # solver.\n    force_simple_heuristic: str = \"\"\n    # The threshold of all-reduce combiner in bytes.\n    all_reduce_threshold: int = 1 << 60\n\n\nclass LogicalDeviceMesh:\n    \"\"\"A logical view of a physical mesh. The logical view is used in the\n    auto-sharding pass.\n\n    A physical mesh can have multiple logical views. (e.g., a 2x8 physical mesh\n    can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its\n    own latency and bandwidth. 
We use alpha-beta model to model the\n    communication cost.\n    \"\"\"\n\n    def __init__(self, physical_mesh, id_mesh, mesh_alpha=None, mesh_beta=None):\n        self.physical_mesh = physical_mesh\n        self.id_mesh = np.array(id_mesh)\n        self.flatten_ids = tuple(int(x) for x in self.id_mesh.flatten())\n\n        # coefficient for alpha-beta communication model\n        if mesh_alpha is None:\n            mesh_alpha = [1] * len(self.id_mesh.shape)\n        if mesh_beta is None:\n            mesh_beta = [1] * len(self.id_mesh.shape)\n        self.mesh_alpha = tuple(mesh_alpha)\n        self.mesh_beta = tuple(mesh_beta)\n\n    @property\n    def shape(self):\n        return self.id_mesh.shape\n\n    @property\n    def num_devices(self):\n        return np.prod(self.id_mesh.shape)\n\n    def flatten(self):\n        \"\"\"\n        Flatten the logical mesh into an effective 1d logical mesh,\n        \"\"\"\n        return LogicalDeviceMesh(\n            self.physical_mesh, self.id_mesh.reshape(-1, 1),\n            [max(self.mesh_alpha), max(self.mesh_alpha)],\n            [min(self.mesh_beta), min(self.mesh_beta)])\n\n    def all_gather_cost(self, num_bytes, mesh_dim):\n        num_devices = self.id_mesh.shape[mesh_dim]\n        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *\n                (num_devices - 1) / num_devices * num_bytes + 0.1)\n\n    def all_reduce_cost(self, num_bytes, mesh_dim):\n        num_devices = self.id_mesh.shape[mesh_dim]\n        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * 2 *\n                (num_devices - 1) / num_devices * num_bytes + 0.01)\n\n    def reduce_scatter_cost(self, num_bytes, mesh_dim):\n        num_devices = self.id_mesh.shape[mesh_dim]\n        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *\n                (num_devices - 1) / num_devices * num_bytes + 0.001)\n\n    def all_to_all_cost(self, num_bytes, mesh_dim):\n        num_devices = 
self.id_mesh.shape[mesh_dim]\n        penalty_factor = num_devices / 2.0\n        return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] *\n                (num_devices - 1) / num_devices / num_devices * num_bytes *\n                penalty_factor + 0.001)\n\n    def make_tile_spec(self, array, tensor_dims, mesh_dims):\n        shape = array.shape\n        sharding = [\n            pxla.NoSharding(),\n        ] * len(shape)\n        mesh_mapping = [\n            None,\n        ] * len(self.id_mesh.shape)\n\n        for i, (tensor_dim, mesh_dim) in enumerate(zip(tensor_dims, mesh_dims)):\n            sharding[tensor_dim] = pxla.Chunked([self.id_mesh.shape[mesh_dim]],)\n            mesh_mapping[mesh_dim] = pxla.ShardedAxis(i)\n\n        for i, mapping in enumerate(mesh_mapping):\n            if mapping is None:\n                mesh_mapping[i] = pxla.Replicated(self.id_mesh.shape[i])\n\n        return pxla.ShardingSpec(sharding, mesh_mapping)\n\n    def __hash__(self):\n        return hash((self.flatten_ids, self.id_mesh.shape, self.mesh_alpha,\n                     self.mesh_beta))\n\n    def __eq__(self, other):\n        return ((self.flatten_ids, self.id_mesh.shape, self.mesh_alpha,\n                 self.mesh_beta) == (other.flatten_ids, other.id_mesh.shape,\n                                     other.mesh_alpha, other.mesh_beta))\n\n\ndef run_auto_sharding_pass(\n        hlo: WrappedHlo,\n        logical_mesh: LogicalDeviceMesh,\n        return_mode: str,\n        num_micro_batches: int,\n        as_option: AutoShardingOption,\n        rewrite_for_grad_acc: bool = False,\n        rewrite_grad_acc_indices: Optional[Sequence[int]] = None,\n        memory_budget_per_device: Optional[float] = None):\n    \"\"\"Run the auto-sharding pass to annotate sharding specs for an XLA\n    Computation.\n\n    Args:\n      hlo: The hlo module got by tracing the jax function,\n        whose status should be UNOPTIMIZED.\n      logical_mesh: The logical device mesh.\n      
return_mode: The mode of return value.\n        The choices are {\"single\", \"stages\", \"stages_and_hook\"}.\n        If it is \"single\", return a single WrappedHlo, whose status is\n          SHARDING_ANNOTATED.\n        If it is \"stages\", return WrappedHlo of multiple pipeline stages,\n          whose statuses are SHARDING_ANNOTATED.\n        If it is \"stages_and_hook\", return WrappedHlos of multiple pipeline\n          stages and the hooked hlo sharding. The statuses of the returned\n          WrappedHlos are SHARDING_ANNOTATED.\n      num_micro_batches: The number of micro batches\n        if gradient accumulation is used. If this is set, the cost of all-reduce\n        for gradient synchronization is divided by this number.\n      as_option: The options of the auto-sharding solver.\n      rewrite_for_grad_acc: Whether to do rewriting for gradient accumulation.\n      rewrite_grad_acc_indices: The indices of tensors in output that are\n        gradients.\n      memory_budget_per_device: The memory budget per device in bytes.\n    \"\"\"\n    # pylint: disable=unused-argument\n    # Set compile options\n    if memory_budget_per_device is None:\n        memory_budget_per_device = -1\n    assert hlo.is_unoptimized()\n\n    multiple_stages = return_mode in [\"stages\", \"stages_and_hook\"]\n    num_devices = logical_mesh.num_devices\n    build_random_seed = global_config.compile_random_seed\n    compile_options = get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n        parameter_is_tupled_arguments=False,\n        build_random_seed=build_random_seed,\n        spmd_propagation_to_outputs=hlo.is_manually_annotated)\n\n    # Set configs for force_zero_stage_3\n    if as_option.force_zero_stage_3:\n        # Generate a strategy similar to ZeRO stage 3\n        force_data_parallel = True\n        prefer_reduce_scatter = 
True\n        reduce_scatter_aggressive_partition = True\n        all_gather_threshold = as_option.force_zero_stage_3_all_gather_threshold\n    else:\n        # Use default settings\n        force_data_parallel = as_option.force_data_parallel\n        prefer_reduce_scatter = as_option.prefer_reduce_scatter\n        reduce_scatter_aggressive_partition = False\n        all_gather_threshold = 1 << 60\n\n    # Set configs for force_data_parallel\n    if force_data_parallel:\n        # Forcibly generate data-parallel strategy\n        allow_all_gather = False\n        allow_all_to_all = False\n\n        logical_mesh = logical_mesh.flatten()\n        force_batch_dim_to_mesh_dim = 0\n    else:\n        # Use default settings\n        allow_all_gather = as_option.allow_all_gather\n        allow_all_to_all = as_option.allow_all_to_all\n\n        if as_option.force_batch_dim_to_mesh_dim is None:\n            # Automatically set force_batch_dim_to_mesh_dim\n            if logical_mesh.shape[0] > 1 and logical_mesh.shape[1] > 1:\n                # In 2d mesh, force the batch tensor dim to match the first\n                # mesh dim\n                force_batch_dim_to_mesh_dim = 0\n            else:\n                force_batch_dim_to_mesh_dim = -1\n        else:\n            force_batch_dim_to_mesh_dim = as_option.force_batch_dim_to_mesh_dim\n\n    # Set configs for reduce-scatter\n    reduce_scatter_grad_acc_friendly = (num_micro_batches is not None and\n                                        num_micro_batches > 1)\n\n    # Set configs for gradient accumulation rewrite pass\n    if rewrite_for_grad_acc and rewrite_grad_acc_indices is None:\n        rewrite_grad_acc_indices = tuple(\n            range(len(hlo.program_shape().result_shape().tuple_shapes())))\n\n    # Temporarily disable this.\n    grad_acc_num_micro_batches = None\n\n    with XlaPassContext({\n            # Auto-sharding solver options\n            \"auto_sharding::enable\":\n                
as_option.enable_auto_sharding,\n            \"auto_sharding::memory_budget_per_device\":\n                memory_budget_per_device,\n            \"auto_sharding::force_all_gather_cost\":\n                not allow_all_gather,\n            \"auto_sharding::all_gather_cost\":\n                INFINITY_COST,\n            \"auto_sharding::force_all_to_all_cost\":\n                not allow_all_to_all,\n            \"auto_sharding::all_to_all_cost\":\n                INFINITY_COST,\n            \"auto_sharding::allow_replicated_parameters\":\n                as_option.allow_replicated_parameters,\n            \"auto_sharding::prefer_reduce_scatter\":\n                prefer_reduce_scatter,\n            \"auto_sharding::reduce_scatter_grad_acc_friendly\":\n                reduce_scatter_grad_acc_friendly,\n            \"auto_sharding::reduce_scatter_aggressive_partition\":\n                reduce_scatter_aggressive_partition,\n            \"auto_sharding::batch_matmul_always_split_batch\":\n                True,\n            \"auto_sharding::allow_recompute_heavy_op\":\n                as_option.allow_recompute_heavy_op,\n            \"auto_sharding::allow_mixed_mesh_shape\":\n                as_option.allow_mixed_mesh_shape,\n            \"auto_sharding::grad_acc_num_micro_batches\":\n                grad_acc_num_micro_batches or 1,\n            \"auto_sharding::force_batch_dim_to_mesh_dim\":\n                force_batch_dim_to_mesh_dim,\n            \"auto_sharding::force_simple_heuristic\":\n                as_option.force_simple_heuristic,\n\n            # Device mesh\n            \"auto_sharding::device_mesh_ids\":\n                logical_mesh.flatten_ids,\n            \"auto_sharding::device_mesh_shape\":\n                tuple(logical_mesh.shape),\n            \"auto_sharding::device_mesh_alpha\":\n                tuple(float(x) for x in logical_mesh.mesh_alpha),\n            \"auto_sharding::device_mesh_beta\":\n                tuple(float(x) for x in 
logical_mesh.mesh_beta),\n            \"auto_sharding::device_mesh_prof_result\":\n                getattr(logical_mesh.physical_mesh, \"prof_result\", None),\n\n            # Gradient accumulation rewrite\n            \"auto_sharding::rewrite_for_grad_acc\":\n                rewrite_for_grad_acc,\n            \"auto_sharding::rewrite_indices\":\n                rewrite_grad_acc_indices,\n\n            # Communication combiner options\n            \"combiner::all_gather_threshold\":\n                all_gather_threshold,\n            \"combiner::all_reduce_threshold\":\n                as_option.all_reduce_threshold,\n\n            # Debug options\n            \"auto_sharding::simplify_graph\":\n                True,\n            \"auto_sharding::print_strategy\":\n                os.environ.get(\"ALPA_DEBUG_PRINT_AS_STRATEGY\", \"False\").lower()\n                in [\"true\", \"1\"],\n            \"auto_sharding::force_strategy\":\n                False,\n            \"auto_sharding::force_strategy_inst_indices\": [],\n            \"auto_sharding::force_strategy_stra_names\": [],\n    }):\n        timers(\"auto-sharding\").start()\n        xe.run_auto_sharding(hlo.get_module(), compile_options)\n        timers(\"auto-sharding\").stop()\n    hlo.status = HloStatus.SHARDING_ANNOTATED\n\n    if multiple_stages:\n        hlo_stage_names, hlo_stages = get_auto_sharded_hlo_stages()\n        hooked_proto = get_hooked_sharding_protos()\n        hlo_stages = [\n            WrappedHlo(stage, HloStatus.SHARDING_ANNOTATED)\n            for stage in hlo_stages\n        ]\n\n    stage_plan = StagePlan(build_random_seed, logical_mesh.shape,\n                           all_gather_threshold, as_option.all_reduce_threshold,\n                           as_option, last_s_val, last_objective)\n\n    if return_mode == \"single\":\n        return hlo, stage_plan\n    elif return_mode == \"stages\":\n        return hlo_stage_names, hlo_stages, stage_plan\n    elif return_mode == 
\"stages_and_hook\":\n        return hlo_stage_names, hlo_stages, hooked_proto, stage_plan\n    else:\n        raise ValueError(\"Invalid return mode: \" + return_mode)\n\n\ndef run_spmd_partitioner_pass(\n        hlo: WrappedHlo,\n        num_devices: int,\n        rewrite_for_grad_acc: bool = False,\n        rewrite_grad_acc_indices: Optional[Sequence[int]] = None):\n    \"\"\"Run SPMD partitioner pass on a sharding annotated HLO Module.\n\n    Args:\n      hlo: The wrapped HLO module, whose status should be SHARDING_ANNOTATED.\n      num_devices: The total number of devices.\n      rewrite_for_grad_acc: Whether to do rewriting for gradient accumulation.\n      rewrite_grad_acc_indices: The indices of tensors in output that are\n        gradients.\n    \"\"\"\n    assert hlo.is_sharding_annotated(), hlo.status\n    compile_options = get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n        parameter_is_tupled_arguments=False,\n        build_random_seed=global_config.compile_random_seed)\n\n    if rewrite_for_grad_acc and rewrite_grad_acc_indices is None:\n        rewrite_grad_acc_indices = tuple(\n            range(len(hlo.program_shape().result_shape().tuple_shapes())))\n\n    with XlaPassContext({\n            # Gradient accumulation rewrite\n            \"auto_sharding::rewrite_for_grad_acc\": rewrite_for_grad_acc,\n            \"auto_sharding::rewrite_indices\": rewrite_grad_acc_indices,\n    }):\n        xe.run_spmd_partitioner(hlo.get_module(), compile_options)\n    hlo.status = HloStatus.SPMD_PARTITIONED\n\n    return hlo\n\n\ndef run_backend_compilation(backend: xe.Client,\n                            hlo: WrappedHlo,\n                            stage_plan: StagePlan,\n                            num_devices: int,\n                            bypass_device_assignment_check: bool = False):\n    \"\"\"Compile a spmd 
partitioned Hlo Module to an XLA executable.\n\n    Args:\n      backend: The XLA backend client.\n      hlo: The Wrapped input HLO.\n      stage_plan: The auto-sharding strategy solution.\n      num_devices: The total number of devices.\n      bypass_device_assignment_check: Whether to compile without exact devices.\n    \"\"\"\n    assert hlo.is_spmd_partitioned() or hlo.is_sharding_annotated()\n    compile_options = get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=hlo.is_sharding_annotated(),\n        parameter_is_tupled_arguments=False,\n        build_random_seed=stage_plan.build_random_seed)\n\n    with XlaPassContext({\n            # Build options\n            \"build_option::bypass_device_assignment_check\":\n                bypass_device_assignment_check,\n\n            # Communication combiner options\n            \"combiner::all_gather_threshold\":\n                stage_plan.all_gather_threshold,\n            \"combiner::all_reduce_threshold\":\n                stage_plan.all_reduce_threshold,\n            \"done-event::enable\":\n                global_config.enable_overlapping,\n    }):\n        compiled = backend.compile(hlo.get_computation(), compile_options)\n\n    return compiled\n\n\ndef get_input_output_sharding_specs(\n    hlo_module: xe.HloModule, avals: Sequence[ShapedArray],\n    out_avals: Sequence[ShapedArray], num_devices: int,\n    logical_mesh_shape: Sequence[int]\n) -> Tuple[Sequence[pxla.ShardingSpec], Sequence[pxla.ShardingSpec]]:\n    \"\"\"Get the sharding specs of input/output tensors from an HloModule.\n\n    Args:\n      hlo_module: The sharded HLO module.\n      avals: The abstract values of input tensors.\n      out_avals: The abstract values of output tensors.\n      num_devices: The total number of devices.\n      logical_mesh_shape: The shape of logical mesh.\n\n    Returns:\n      
input_sharding_specs: The sharding specs of input tensors.\n      output_sharding_specs: The sharding specs of output tensors.\n    \"\"\"\n    if num_devices != 1:\n        input_shardings = hlo_module.spmd_parameters_shardings()\n        input_sharding_specs = [\n            hlo_sharding_to_sharding_spec(proto, aval, logical_mesh_shape)\n            for (proto, aval) in zip(input_shardings, avals)\n        ]\n        output_shardings = hlo_module.spmd_output_sharding()\n        output_sharding_specs = hlo_sharding_to_sharding_spec(\n            output_shardings, out_avals, logical_mesh_shape)\n    else:\n        # The spmd partition related code will be bypassed if\n        # num_partitions == 1.\n        # Assume all sharding specs are replicated.\n        input_sharding_specs = [\n            make_replicated_spec(aval, logical_mesh_shape) for aval in avals\n        ]\n        output_sharding_specs = [\n            make_replicated_spec(aval, logical_mesh_shape) for aval in out_avals\n        ]\n    return input_sharding_specs, output_sharding_specs\n\n\ndef _hlo_sharding_to_sharding_spec_no_tuple(\n        proto: xc.OpSharding, aval: ShapedArray,\n        logical_mesh: Sequence[int]) -> pxla.ShardingSpec:\n    \"\"\"The internal function of hlo_sharding_to_sharding_spec.\"\"\"\n    sharding_type, tile_assignment_dimensions, tile_assignment_devices = (\n        proto.type, proto.tile_assignment_dimensions,\n        proto.tile_assignment_devices)\n\n    sharding = []\n    mesh_mapping = []\n    if sharding_type == xc.OpSharding.Type.OTHER:\n        tile_assignment = np.array(tile_assignment_devices).reshape(\n            tile_assignment_dimensions)\n\n        tile_dims = []\n        for i in range(len(tile_assignment_dimensions)):\n            if tile_assignment_dimensions[i] != 1:\n                tile_dims.append(i)\n\n        tile_dims_delta = []\n        success = True\n        for dim in tile_dims:\n            indices = tuple(0 if i != dim else slice(None)\n 
                           for i in range(tile_assignment.ndim))\n            device_ids = tile_assignment[indices]\n            delta = check_arithmetic_sequence(device_ids)\n            if delta is None:\n                success = False\n                break\n            tile_dims_delta.append(delta)\n\n        if success:\n            tile_dims_order = list(range(len(tile_dims)))\n            tile_dims_order.sort(key=lambda i: -tile_dims_delta[i])\n\n            ct = 0\n            for i in range(len(aval.shape)):\n                if tile_assignment_dimensions[i] == 1:\n                    sharding.append(pxla.NoSharding())\n                else:\n                    sharding.append(\n                        pxla.Chunked([tile_assignment_dimensions[i]]))\n                    mesh_mapping.append(pxla.ShardedAxis(ct))\n                    ct += 1\n\n            if len(tile_dims) > len(mesh_mapping):\n                # replicate on the last tile dim\n                mesh_mapping.append(\n                    pxla.Replicated(tile_assignment_dimensions[-1]))\n\n            mesh_mapping = [mesh_mapping[idx] for idx in tile_dims_order]\n        else:\n            # The normal path fails, because one tensor dim is chunked into\n            # multiple parts. 
We only handle a special case here.\n            assert len(aval.shape) == 1, \"Only support 1d case\"\n            assert len(tile_assignment_dimensions) == len(aval.shape)\n            for col in range(len(tile_assignment_devices)):\n                if tile_assignment_devices[col] == 1:\n                    break\n            sharding = (pxla.Chunked(\n                (tile_assignment_dimensions[0] // col, col)),)\n            mesh_mapping = (pxla.ShardedAxis(1), pxla.ShardedAxis(0))\n    elif sharding_type == xc.OpSharding.Type.REPLICATED:\n        sharding = (pxla.NoSharding(),) * len(aval.shape)\n        mesh_mapping = (pxla.Replicated(np.prod(logical_mesh.shape)),)\n    else:\n        raise NotImplementedError(\"Type: \" + str(sharding_type))\n\n    return pxla.ShardingSpec(sharding, mesh_mapping)\n\n\ndef hlo_sharding_to_sharding_spec(\n        hlo_sharding: \"xe.HloSharding\", aval: Union[Sequence[ShapedArray],\n                                                    ShapedArray],\n        logical_mesh_shape: Sequence[int]) -> pxla.ShardingSpec:\n    \"\"\"Convert hlo sharding to sharding spec.\"\"\"\n    logical_mesh = LogicalDeviceMesh(\n        None,\n        np.arange(np.prod(logical_mesh_shape)).reshape(logical_mesh_shape))\n    proto = hlo_sharding.to_proto()\n    sharding_type, tuple_shardings = proto.type, proto.tuple_shardings\n    if sharding_type == xc.OpSharding.Type.TUPLE:\n        avals = aval\n        return [\n            _hlo_sharding_to_sharding_spec_no_tuple(shard, aval, logical_mesh)\n            for (shard, aval) in zip(tuple_shardings, avals)\n        ]\n    else:\n        return _hlo_sharding_to_sharding_spec_no_tuple(proto, aval,\n                                                       logical_mesh)\n\n\ndef make_replicated_spec(\n        aval: ShapedArray,\n        logical_mesh_shape: Sequence[int]) -> pxla.ShardingSpec:\n    \"\"\"Make a replicated ShardingSpec.\"\"\"\n    sharding = (pxla.NoSharding(),) * len(aval.shape)\n    
mesh_mapping = (pxla.Replicated(np.prod(logical_mesh_shape)),)\n    return pxla.ShardingSpec(sharding, mesh_mapping)\n\n\ndef call_solver_serialized_args(*args):\n    \"\"\"Call the solver with serialized arguments and handle python errors.\"\"\"\n    info = \"\"\n    try:\n        ret = _call_solver_serialized_args(*args)\n    except AssertionError:\n        ret = None\n        info = str(traceback.format_exc()[:-1])\n    except Exception:  # pylint: disable=broad-except\n        ret = None\n        info = str(traceback.format_exc()[:-1])\n\n    if ret is None:\n        print(info)\n\n    return ret\n\n\n# The last solution vector of auto sharding.\nlast_s_val = None\n\n# The last objective value of the best ILP solution.\nlast_objective = None\n\n\n# pylint: disable=import-outside-toplevel\ndef _call_solver_serialized_args(N,\n                                 M,\n                                 s_len_np,\n                                 s_follow_np,\n                                 E_np,\n                                 A_np,\n                                 L_np,\n                                 c_np,\n                                 d_np,\n                                 m_np,\n                                 r_np,\n                                 v_np,\n                                 s_init_np=None):\n    \"\"\"Call the solver with serialized arguments.\"\"\"\n    # pylint: disable=invalid-name\n    global last_s_val, last_objective\n\n    import pulp\n    from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus\n    tic = time.time()\n\n    for x in [s_len_np, E_np, A_np, L_np, c_np, d_np, m_np, r_np, v_np]:\n        assert isinstance(x, np.ndarray)\n    assert len(s_len_np) == N, \"s_len_np\"\n\n    # Dump arguments for re-solving\n    # pickle.dump([N, M, s_len_np, s_follow_np, E_np, A_np, L_np,\n    #              c_np, d_np, m_np, r_np, v_np, s_init_np],\n    #              open(\"args.pkl\", \"wb\"))\n    # TODO(lmzheng): 
cache the ILP solution.\n\n    def get_non_zero_index(binary_vector):\n        \"\"\"Get the index of non-zero item in a vector.\"\"\"\n        ct = 0\n        ret = None\n        for i, elem in enumerate(binary_vector):\n            if pulp.value(elem):\n                ret = i\n                ct += 1\n\n        assert ct == 1\n        return ret\n\n    # 0. Unpack flatten numpy arrays\n    s_len = s_len_np\n    s_follow = s_follow_np\n\n    E = E_np.reshape((-1, 2))  # noqa\n    r = []\n    pt = 0\n    edge_set = set()\n    for (i, j) in E:\n        prod_length = s_len[i] * s_len[j]\n\n        if (i, j) in edge_set:\n            raise ValueError(f\"Duplicated edges: {(i, j)}\")\n\n        edge_set.add((i, j))\n        r.append(r_np[pt:pt + prod_length])\n        pt += prod_length\n    assert pt == len(r_np)\n\n    A = A_np.reshape((-1, 2))  # noqa\n    v = []\n    pt = 0\n    for (i, j) in A:\n        prod_length = s_len[i] * s_len[j]\n        v.append(v_np[pt:pt + prod_length])\n        pt += prod_length\n    assert pt == len(v_np)\n\n    L = []  # noqa\n    pt = N\n    for i in range(N):\n        length = L_np[i]\n        L.append(L_np[pt:pt + length])\n        pt += length\n    assert pt == len(L_np)\n\n    c = []\n    d = []\n    m = []\n    pt = 0\n    for i in range(N):\n        length = s_len[i]\n        c.append(c_np[pt:pt + length])\n        d.append(d_np[pt:pt + length])\n        m.append(m_np[pt:pt + length])\n        pt += length\n    assert pt == len(c_np), f\"{pt} == {len(c_np)}\"\n    assert pt == len(d_np), f\"{pt} == {len(d_np)}\"\n    assert pt == len(m_np), f\"{pt} == {len(m_np)}\"\n\n    # 1. 
Create variables\n    s = []\n    e = []\n\n    num_nodes = 0\n    reverse_follow_backpatch = []\n    for i in range(N):\n        if s_follow[i] < 0:\n            if s_len[i] == 1:\n                s.append([1])\n            else:\n                num_nodes += 1\n                s.append(\n                    LpVariable.matrix(f\"s[{i}]\", (range(s_len[i]),),\n                                      cat=\"Binary\"))\n        else:\n            if s_follow[i] < len(s):\n                s.append(s[s_follow[i]])\n            else:\n                s.append(None)\n                reverse_follow_backpatch.append(i)\n\n    for i in reverse_follow_backpatch:\n        s[i] = s[s_follow[i]]\n\n    num_edges = 0\n    for (idx, (i, j)) in enumerate(E):\n        if len(s[i]) == 1:\n            e.append(s[j])\n        elif len(s[j]) == 1:\n            e.append(s[i])\n        else:\n            num_edges += 1\n            e.append(\n                LpVariable.matrix(f\"e[{i},{j}]\",\n                                  (range(len(s[i]) * len(s[j])),),\n                                  cat=\"Binary\"))\n        assert len(e[idx]) == len(r[idx])\n\n    # 2. Set initial value for warm start\n    if s_init_np is not None:\n        s_init = s_init_np.reshape((-1, 3))\n        for (idx, value, fix) in s_init:\n            for i in range(len(s[idx])):\n                s[idx][i].setInitialValue(i == value)\n                if fix:\n                    s[idx][i].fixValue()\n\n    # 3. Objective\n    prob = LpProblem(\"myProblem\", LpMinimize)\n    # compute cost\n    obj = 0\n    for i in range(N):\n        obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])\n\n    # communication cost\n    for i in range(len(E)):\n        obj += lpDot(e[i], r[i])\n\n    prob += obj\n\n    # 4. Constraints\n    # (a). 
specified by `cat=\"Binary\"`\n\n    # (b)\n    for i in range(N):\n        if s_follow[i] < 0:\n            prob += lpSum(s[i]) == 1\n\n    # (c)\n    if M > 0:\n        for t in range(N):\n            mem = 0\n            for i in L[t]:\n                mem += lpSum(s[i][j] * m[i][j] for j in range(len(s[i])))\n            prob += mem <= M\n\n    # (d). specified by `cat=\"Binary\"`\n\n    for (idx, (i, j)) in enumerate(E):\n        if s_len[i] == 1 or s_len[j] == 1:\n            continue\n\n        # (e)\n        prob += lpSum(e[idx]) == 1\n\n        # (f)\n        for row in range(len(s[i])):\n            C = len(s[j])  # noqa\n            prob += lpSum(\n                e[idx][row * C + col] for col in range(0, C)) <= s[i][row]\n\n        # (g)\n        for col in range(len(s[j])):\n            R = len(s[i])  # noqa\n            C = len(s[j])  # noqa\n            prob += lpSum(\n                e[idx][row * C + col] for row in range(0, R)) <= s[j][col]\n\n    # (h)\n    alias_set = set()\n    for (idx, (i, j)) in enumerate(A):\n        R = len(s[i])  # noqa\n        C = len(s[j])  # noqa\n        if (i, j) in alias_set:\n            raise ValueError(f\"Duplicated edges: {(i, j)}\")\n\n        alias_set.add((i, j))\n        alias_set.add((j, i))\n\n        for row in range(len(s[i])):\n            for col in range(len(s[j])):\n                if v[idx][row * C + col] > 0.5:\n                    prob += s[i][row] + s[j][col] <= 1\n\n    verbose = False\n\n    msg = verbose\n    time_limit = 600\n    assert \"PULP_CBC_CMD\" in pulp.listSolvers(onlyAvailable=True), (\n        \"Please install ILP solvers by 'sudo apt install coinor-cbc'\")\n\n    solver = pulp.PULP_CBC_CMD(mip=True,\n                               msg=msg,\n                               timeLimit=time_limit,\n                               threads=multiprocessing.cpu_count())\n    prob.solve(solver)\n\n    status = prob.status\n    objective = pulp.value(prob.objective)\n    objective = 
float(objective) if objective is not None else -1.0\n    if verbose:\n        print(f\"ILP Status: {LpStatus[status]}\\tObjective: {objective}\\t\"\n              f\"Time: {time.time() - tic}\")\n        print(f\"#nodes: {num_nodes},  #edges: {num_edges}\")\n\n    if prob.status in [pulp.LpStatusInfeasible]:\n        raise RuntimeError(\n            \"Cannot run the function under the given memory budget. \"\n            \"Please increase the memory budget.\")\n\n    # Get and check results\n    s_val = np.full((N,), -1, dtype=np.int32)\n    for i in range(N):\n        s_val[i] = get_non_zero_index(s[i])\n\n    e_val = np.full((len(E),), -1, dtype=np.int32)\n    for (idx, (i, j)) in enumerate(E):\n        e_val[idx] = get_non_zero_index(e[idx])\n        i_spec_index = e_val[idx] // len(s[j])\n        j_spec_index = e_val[idx] % len(s[j])\n        assert i_spec_index == s_val[i], f\"e_val[{i}][{j}]\"\n        assert j_spec_index == s_val[j], f\"e_val[{i}][{j}]\"\n        if verbose and r[idx][e_val[idx]] > 0:\n            print(f\"Edge cost {(i, j)} : {r[idx][e_val[idx]]}\")\n\n    last_s_val = s_val\n    last_objective = objective\n\n    if objective > INFINITY_COST:\n        warnings.warn(\"Detect unexpected behaviors in the auto-sharding pass.\")\n\n    return s_val, e_val, objective, status\n\n\n# Auto-sharded pipeline stages.\n# These global variables are used to receive values from XLA c++ passes.\nauto_sharded_hlo_stage_names: Sequence[str] = []\nauto_sharded_hlo_stages: Sequence[xe.HloModule] = []\n\nhooked_sharding_protos = None\n\n\ndef set_auto_sharded_hlo_stages(stages: Tuple[Sequence[str],\n                                              Sequence[xe.HloModule]]):\n    \"\"\"Set the sliced auto-sharded stages. 
This is called in XLA\n    SliceAutoShardedStages pass.\"\"\"\n    hlo_module_names, hlo_modules = stages\n    global auto_sharded_hlo_stage_names, auto_sharded_hlo_stages\n    auto_sharded_hlo_stage_names = hlo_module_names\n    auto_sharded_hlo_stages = hlo_modules\n\n\ndef set_hooked_sharding_protos(protos: Sequence[bytes]):\n    global hooked_sharding_protos\n    hooked_sharding_protos = protos\n\n\ndef get_auto_sharded_hlo_stages(\n) -> Tuple[Sequence[str], Sequence[xe.HloModule]]:\n    \"\"\"Get the sliced hlo stages from the SliceAutoShardedStages pass.\"\"\"\n    return auto_sharded_hlo_stage_names, auto_sharded_hlo_stages\n\n\ndef get_hooked_sharding_protos() -> bytes:\n    return hooked_sharding_protos\n"
  },
  {
    "path": "alpa/shard_parallel/compile_executable.py",
    "content": "\"\"\"Compile executables for shard parallelism.\"\"\"\nimport hashlib\nimport inspect\nfrom typing import Callable, Sequence, Optional, Union\n\nimport numpy as np\nfrom jax import linear_util as lu\nfrom jax._src import traceback_util\nfrom jax._src.lib import xla_extension as xe\nfrom jax.core import (Jaxpr, ClosedJaxpr, Literal, gensym, get_aval,\n                      raise_to_shaped, AbstractValue)\nfrom jax.lax import add_p, div_p\nfrom jax.tree_util import PyTreeDef\n\nfrom alpa.device_mesh import LogicalDeviceMesh, PhysicalDeviceMesh\nfrom alpa.global_env import global_config\nfrom alpa.mesh_executable import (NormalMeshDriverExecutable,\n                                  GradAccMeshDriverExecutable)\nfrom alpa.pipeline_parallel.apply_grad import APPLY_GRAD_MARKER_SUFFIX\nfrom alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,\n                                               run_spmd_partitioner_pass,\n                                               AutoShardingOption)\nfrom alpa.shard_parallel.manual_sharding import (ManualShardingOption,\n                                                 get_manual_sharding_spec)\nfrom alpa.util import (jaxpr_to_hlo, new_jaxpr_eqn, setup_computation_alias,\n                       trace_jaxpr_with_micro_batch,\n                       undefined_sharding_spec_proto, OrderedSet)\n\ntraceback_util.register_exclusion(__file__)\n\n\ndef get_compute_key(fun: lu.WrappedFun, in_tree: PyTreeDef,\n                    donated_invars: Sequence[bool],\n                    *aval: Sequence[AbstractValue]):\n    \"\"\"Return a unique string as the query key of a computation definition.\"\"\"\n    # pylint: disable=unused-argument\n    # Algorithm:\n    # Concatenate the definition location, source code,\n    # input arguments specification to a string.\n    # Then compute a hash value of this string.\n    #\n    # TODO(lmzheng): use jaxpr or hlo instead of source code?\n\n    location = str(fun.f).split(\"at\", 
maxsplit=1)[0]\n    source_code = inspect.getsource(fun.f)\n    donated_invars = str(donated_invars)\n    aval = \"\".join(x.str_short() for x in aval)\n\n    string = location + source_code + donated_invars + aval\n    hash_key = hashlib.md5(string.encode(encoding=\"utf-8\")).hexdigest()\n    return hash_key\n\n\ndef compile_shard_executable(\n    fun: lu.WrappedFun,\n    in_tree: PyTreeDef,\n    out_tree_thunk: Callable,\n    static_argnums: Sequence[int],\n    donated_invars: Sequence[bool],\n    batch_invars: Sequence[bool],\n    device_mesh: Union[PhysicalDeviceMesh, LogicalDeviceMesh],\n    num_micro_batches: Optional[int],\n    as_option: AutoShardingOption,\n    ms_option: ManualShardingOption,\n    *avals: Sequence[AbstractValue],\n):\n    \"\"\"Compile an executable with auto-sharding pass.\"\"\"\n    if isinstance(device_mesh, PhysicalDeviceMesh):\n        physical_mesh = device_mesh\n        logical_mesh_choices = [physical_mesh.get_logical_mesh()]\n    elif isinstance(device_mesh, LogicalDeviceMesh):\n        physical_mesh = device_mesh.physical_mesh\n        logical_mesh_choices = [device_mesh]\n    else:\n        raise ValueError(\"Invalid value of devices\")\n\n    if num_micro_batches is None:\n        return shard_parallel_internal(fun, in_tree, out_tree_thunk,\n                                       static_argnums, donated_invars,\n                                       physical_mesh, logical_mesh_choices,\n                                       as_option, ms_option, *avals)\n    else:\n        if global_config.backend == \"tpu\":\n            raise NotImplementedError(\n                \"Gradient accumulation for tpu is not supported\")\n        return shard_parallel_internal_gradient_accumulation(\n            fun, in_tree, out_tree_thunk, static_argnums, donated_invars,\n            batch_invars, physical_mesh, logical_mesh_choices,\n            num_micro_batches, as_option, ms_option, *avals)\n\n\ndef shard_parallel_internal(\n        fun: 
lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable,\n        static_argnums: Sequence[int], donated_invars: Sequence[bool],\n        physical_mesh: PhysicalDeviceMesh,\n        logical_mesh_choices: Sequence[LogicalDeviceMesh],\n        as_option: AutoShardingOption, ms_option: ManualShardingOption,\n        *avals: Sequence[AbstractValue]):\n    \"\"\"\n    Compile an executable with auto-sharding pass.\n\n    Args:\n      fun: The wrapped jax function to be compiled.\n      in_tree: The pytree of input arguments.\n      out_tree_thunk: The thunk to produce output pytree.\n      donated_invars: Whether to donate input parameters.\n      physical_mesh: The physical device mesh.\n      logical_mesh_choices: The candidates of logical mesh shape.\n        If there is only one choice, use the given one. If there are multiple\n        choices, we will try all of them and pick the best.\n      as_option: The options of auto-sharding solver.\n      avals: The input abstract values.\n    \"\"\"\n    # pylint: disable=unused-argument\n    # Trace to get jaxpr\n    closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, [False] * len(avals), 1,\n                                                   avals)\n    out_avals = [v.aval for v in closed_jaxpr.jaxpr.outvars]\n\n    # Convert jaxpr to XLA HLO\n    name = f\"{fun.__name__}_shard_parallel\"\n    hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars)\n    # Set user specified sharding specs.\n    if ms_option:\n        if as_option.enable_auto_sharding:\n            raise NotImplementedError(\"hybrid auto sharding is unsupported\")\n        in_sharding_proto, out_sharding_proto = get_manual_sharding_spec(\n            ms_option, logical_mesh_choices[0].shape, in_tree, out_tree_thunk(),\n            avals, out_avals)\n        if in_sharding_proto is not None:\n            hlo.set_input_shardings(in_sharding_proto)\n            hlo.is_manually_annotated = True\n        if out_sharding_proto is not None:\n            
hlo.set_output_shardings(out_sharding_proto)\n            hlo.is_manually_annotated = True\n    flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())\n\n    # Compile a XLA executable\n    hlo, stage_plan = run_auto_sharding_pass(hlo, logical_mesh_choices[0],\n                                             \"single\", 1, as_option)\n    # This is a walkaround because XLA GpuCompiler has some issue\n    if global_config.backend == \"gpu\":\n        hlo = run_spmd_partitioner_pass(hlo,\n                                        np.prod(logical_mesh_choices[0].shape))\n\n    # Compile a mesh executable\n    return NormalMeshDriverExecutable(physical_mesh,\n                                      hlo,\n                                      stage_plan,\n                                      avals,\n                                      out_avals,\n                                      donated_invars,\n                                      static_argnums=static_argnums,\n                                      in_tree=in_tree,\n                                      out_tree=out_tree_thunk(),\n                                      flop_count=flop_count)\n\n\ndef shard_parallel_internal_gradient_accumulation(\n        fun: lu.WrappedFun, in_tree: PyTreeDef, out_tree_thunk: Callable,\n        static_argnums: Sequence[int], donated_invars: Sequence[bool],\n        batch_invars: Sequence[bool], physical_mesh: PhysicalDeviceMesh,\n        logical_mesh_choices: Sequence[LogicalDeviceMesh],\n        num_micro_batches: int, as_option: AutoShardingOption,\n        ms_option: ManualShardingOption, *raw_avals: Sequence[AbstractValue]):\n    \"\"\"Compile a gradient accumulation executable with auto-sharding pass.\"\"\"\n    # pylint: disable=unused-argument\n    # Split the batch dimension\n    closed_jaxpr, _ = trace_jaxpr_with_micro_batch(fun, batch_invars,\n                                                   num_micro_batches, raw_avals)\n\n    (closed_jaxpr, 
accumulate_grad_invar_indices, apply_grad_invar_indices,\n     num_grads) = (add_gradient_accumulation(closed_jaxpr, num_micro_batches))\n    in_avals = [x.aval for x in closed_jaxpr.jaxpr.invars[:-num_grads]]\n    out_avals = [x.aval for x in closed_jaxpr.jaxpr.outvars]\n    grad_avals = [x.aval for x in closed_jaxpr.jaxpr.invars[-num_grads:]]\n\n    # Run auto-sharding and slice the combined HLO into two HLO: accumulate_grad\n    # and apply_grad\n    donated_invars = donated_invars + (False,) * num_grads\n    name = f\"{fun.__name__}_shard_parallel\"\n    hlo = jaxpr_to_hlo(name, closed_jaxpr, donated_invars)\n    flop_count = xe.hlo_module_count_flop_dot_conv_only(hlo.get_module())\n    flop_count *= num_micro_batches\n\n    # Set user specified sharding specs.\n    if ms_option:\n        if as_option.enable_auto_sharding:\n            raise NotImplementedError(\"hybrid auto sharding is unsupported\")\n        in_sharding_proto, out_sharding_proto = get_manual_sharding_spec(\n            ms_option, logical_mesh_choices[0].shape, in_tree, out_tree_thunk(),\n            in_avals, out_avals)\n        grad_sharding_proto = [undefined_sharding_spec_proto()] * num_grads\n        if in_sharding_proto is not None:\n            in_sharding_proto += tuple(grad_sharding_proto)\n            hlo.set_input_shardings(in_sharding_proto)\n            hlo.is_manually_annotated = True\n        if out_sharding_proto is not None:\n            hlo.set_output_shardings(out_sharding_proto)\n            hlo.is_manually_annotated = True\n\n    # pylint: disable=unbalanced-tuple-unpacking\n    hlo_stage_names, hlo_stages, stage_plan = run_auto_sharding_pass(\n        hlo, logical_mesh_choices[0], \"stages\", num_micro_batches, as_option)\n    assert len(hlo_stages) == 2\n\n    if hlo_stage_names[0].endswith(APPLY_GRAD_MARKER_SUFFIX):\n        hlo_stage_names[0], hlo_stages[0], hlo_stage_names[1], hlo_stages[1] = (\n            hlo_stage_names[1], hlo_stages[1], hlo_stage_names[0],\n      
      hlo_stages[0])\n    assert hlo_stage_names[1].endswith(APPLY_GRAD_MARKER_SUFFIX)\n\n    # Compile these two HLOs separately to get two XLA executables\n    accumulate_grad, apply_grad = hlo_stages\n\n    ## donate old_grad to make the gradient accumulation in-place\n    tmp_donate_invars = ((False,) * len(accumulate_grad_invar_indices) +\n                         (True,) * num_grads)\n    setup_computation_alias(accumulate_grad, tmp_donate_invars)\n\n    ## donate old opt_state and params to make the weight update in-place\n    tmp_donate_invars = (\n        tuple(donated_invars[i] for i in apply_grad_invar_indices) +\n        (False,) * num_grads)\n    setup_computation_alias(apply_grad, tmp_donate_invars)\n\n    accumulate_grad = run_spmd_partitioner_pass(accumulate_grad,\n                                                physical_mesh.num_devices,\n                                                rewrite_for_grad_acc=True)\n    apply_grad = run_spmd_partitioner_pass(apply_grad,\n                                           physical_mesh.num_devices)\n\n    # Compile them to a single mesh executable\n    return GradAccMeshDriverExecutable(physical_mesh,\n                                       accumulate_grad,\n                                       apply_grad,\n                                       stage_plan,\n                                       in_avals,\n                                       out_avals,\n                                       grad_avals,\n                                       donated_invars,\n                                       batch_invars,\n                                       accumulate_grad_invar_indices,\n                                       apply_grad_invar_indices,\n                                       num_micro_batches,\n                                       in_tree=in_tree,\n                                       out_tree=out_tree_thunk(),\n                                       flop_count=flop_count)\n\n\ndef 
filter_used_vars(all_vars, eqns):\n    \"\"\"Return the vars in all_vars that are used by eqns.\n\n    The returned vars preserve their original order in all_vars.\n    \"\"\"\n    used_vars = OrderedSet()\n    for eqn in eqns:\n        used_vars.update(x for x in eqn.invars if not isinstance(x, Literal))\n    return [var for var in all_vars if var in used_vars]\n\n\ndef filter_pass_through_vars(in_vars, out_vars):\n    in_vars_set = set(x for x in in_vars if not isinstance(x, Literal))\n    return [x for x in out_vars if x in in_vars_set]\n\n\ndef clone_vars(var_list, gensym_func: Callable):\n    \"\"\"Clone variables.\"\"\"\n    return [gensym_func(x.aval) for x in var_list]\n\n\ndef add_gradient_accumulation(raw_jaxpr, num_micro_batches):\n    \"\"\"Add gradient accumulation logics into the raw jaxpr.\n\n    Signatures of functions:\n        raw_jaxpr(param, opt_state, batch) -> [new_param, new_opt_state]\n\n        The original_jaxpr can be split into:\n        \"compute_grad(param, batch) -> out_grad\"\n        \"apply_grad(param, opt_state, in_grad) -> [new_param, new_opt_state]\"\n\n        We then derive accumulate_grad from compute_grad:\n        \"accumulate_grad(param, batch, old_grad) -> new_grad\"\n\n        The returned jaxpr is composed by [\n            pipeline_marker_start\n            accumulate_grad\n            pipeline_marker_end\n\n            pipeline_marker_start\n            apply_grad\n            pipeline_marker_end\n        ], with the signature\n        \"new_jaxpr(param, opt_state, batch, grad) -> [new_param, new_opt_state]\"\n    \"\"\"\n    # pylint: disable=import-outside-toplevel\n    from alpa.pipeline_parallel.primitive_def import pipeline_p\n\n    global_invars = OrderedSet(raw_jaxpr.jaxpr.invars)\n    gensym_func = gensym([raw_jaxpr.jaxpr])\n\n    # Find the gradient separator marker.\n    # This separator partitions orginal_jaxpr into two part:\n    # compute_grad and apply_grad\n    marker_eqn = None\n    marker_pos = 0\n    
for pos, eqn in enumerate(raw_jaxpr.jaxpr.eqns):\n        if eqn.primitive is pipeline_p and eqn.params[\"mark_type\"] == \"grad\":\n            marker_eqn = eqn\n            marker_pos = pos\n            break\n    assert marker_eqn is not None, \"Must have exactly one gradient marker\"\n    compute_grad_eqns = raw_jaxpr.jaxpr.eqns[:marker_pos]\n    apply_grad_eqns = raw_jaxpr.jaxpr.eqns[marker_pos + 1:]\n\n    # Build the new jaxpr with gradient accumulation and pipeline marker\n    global_invar_substitute = {}\n    combined_eqns = []\n\n    # Create vars for gradient accumulation\n    out_grad_vars = marker_eqn.invars\n    old_grad_vars = clone_vars(out_grad_vars, gensym_func)\n    new_grad_vars = clone_vars(out_grad_vars, gensym_func)\n    num_grads = len(out_grad_vars)\n\n    # Wrap all invars of accumulate_grad\n    old_invars = filter_used_vars(raw_jaxpr.jaxpr.invars,\n                                  compute_grad_eqns) + old_grad_vars\n    new_invars = clone_vars(old_invars, gensym_func)\n    combined_eqns.append(\n        new_jaxpr_eqn(new_invars, old_invars, pipeline_p, {\n            \"mark_type\": \"start\",\n            \"name\": \"accumulate_grad\"\n        }))\n    global_invar_substitute.update(zip(old_invars, new_invars))\n    accumulate_grad_invars = new_invars\n\n    # Append eqns of compute_grad\n    combined_eqns.extend(raw_jaxpr.jaxpr.eqns[:marker_pos])\n\n    # Append eqns of gradient accumulation\n    for i in range(len(out_grad_vars)):\n        combined_eqns.append(\n            new_jaxpr_eqn([old_grad_vars[i], out_grad_vars[i]],\n                          [new_grad_vars[i]], add_p, {}))\n\n    # Wrap all outvars of accumulate_grad\n    inter_grad_vars = [gensym_func(x.aval) for x in out_grad_vars]\n    combined_eqns.append(\n        new_jaxpr_eqn(new_grad_vars, inter_grad_vars, pipeline_p, {\n            \"mark_type\": \"end\",\n            \"name\": \"accumulate_grad\"\n        }))\n\n    # Wrap all invars of apply_grad\n    in_grad_vars 
= marker_eqn.outvars\n    old_invars = (filter_used_vars(raw_jaxpr.jaxpr.invars, apply_grad_eqns) +\n                  filter_pass_through_vars(raw_jaxpr.jaxpr.invars,\n                                           raw_jaxpr.jaxpr.outvars) +\n                  in_grad_vars)\n    new_invars = []\n    for var in old_invars:\n        if var in global_invars:\n            if var in global_invar_substitute:\n                new_invars.append(global_invar_substitute[var])\n            else:\n                new_var = gensym_func(var.aval)\n                global_invar_substitute[var] = new_var\n                new_invars.append(new_var)\n        else:\n            new_invars.append(inter_grad_vars[in_grad_vars.index(var)])\n    apply_grad_invars = new_invars\n    combined_eqns.append(\n        new_jaxpr_eqn(new_invars, old_invars, pipeline_p, {\n            \"mark_type\": \"start\",\n            \"name\": APPLY_GRAD_MARKER_SUFFIX\n        }))\n\n    # Append eqns for gradient reduction\n    for i in range(num_grads):\n        tmp_var = old_invars[-(i + 1)]\n        literal_val = np.array(num_micro_batches, tmp_var.aval.dtype)\n        combined_eqns.append(\n            new_jaxpr_eqn([\n                tmp_var,\n                Literal(literal_val, raise_to_shaped(get_aval(literal_val))),\n            ], [tmp_var], div_p, {}))\n    # TODO(lmzheng): This breaks the SSA form of the combined_eqns\n    # But I find jax can convert this non-SSA jaxpr to HLO correctly,\n    # so I leave this issue as todo. To fix this, we should substitute\n    # all grad vars in these equations with new vars.\n\n    # Append eqns of apply_grad\n    combined_eqns.extend(apply_grad_eqns)\n    # TODO(lmzheng): The param vars are used in both compute_grad and\n    #   apply_grad, so there will be some duplicated intermediate vars in\n    #   compute_grad_eqns and apply_grad_eqns. This breaks the SSA form of the\n    #   combined_eqns. 
But I find jax can convert this non-SSA jaxpr to HLO\n    #   correctly, so I leave this issue as todo. To fix this, we should\n    #   substitute all param vars in these equations with new vars.\n\n    # Wrap all outvars of apply_grad\n    old_outvars = raw_jaxpr.jaxpr.outvars\n    new_outvars = [gensym_func(x.aval) for x in old_outvars]\n    combined_eqns.append(\n        new_jaxpr_eqn(old_outvars, new_outvars, pipeline_p, {\n            \"mark_type\": \"end\",\n            \"name\": APPLY_GRAD_MARKER_SUFFIX\n        }))\n\n    # Make the new jaxpr\n    combined_jaxpr = ClosedJaxpr(\n        Jaxpr(raw_jaxpr.jaxpr.constvars, [\n            global_invar_substitute.get(x, x)\n            for x in (raw_jaxpr.jaxpr.invars + old_grad_vars)\n        ], new_outvars, combined_eqns), raw_jaxpr.consts)\n\n    # The indices of the arguments in global arguments.\n    # TODO(lmzheng): this step is O(n^2)\n    accumulate_grad_invar_indices = [\n        combined_jaxpr.jaxpr.invars.index(var)\n        for var in accumulate_grad_invars[:-num_grads]\n    ]\n    apply_grad_invar_indices = [\n        combined_jaxpr.jaxpr.invars.index(var)\n        for var in apply_grad_invars[:-num_grads]\n    ]\n    return (combined_jaxpr, accumulate_grad_invar_indices,\n            apply_grad_invar_indices, num_grads)\n"
  },
  {
    "path": "alpa/shard_parallel/manual_sharding.py",
    "content": "\"\"\"User specified manual sharding strategy following pjit's api.\"\"\"\nimport dataclasses\nfrom typing import Any, Optional, OrderedDict, Sequence, Tuple, Union\n\nfrom jax._src.lib import xla_client as xc\nfrom jax._src.tree_util import _replace_nones\nfrom jax._src.util import safe_zip\nfrom jax.experimental.pjit import (_is_unspecified, _is_auto, _is_from_gda,\n                                   _prepare_axis_resources, get_array_mapping,\n                                   _UNSPECIFIED, PartitionSpec,\n                                   ParsedPartitionSpec)\nfrom jax.interpreters import mlir, pxla\nfrom jax.tree_util import tree_unflatten, tree_flatten, tree_map\n\nfrom alpa.util import undefined_sharding_spec_proto\n\n\n@dataclasses.dataclass\nclass ManualShardingOption:\n    \"\"\"Options to manually set shardings in pjit convention.\"\"\"\n    mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None\n    submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None\n    # According to pjit, None means replicated.\n    in_axis_resources: Any = _UNSPECIFIED\n    out_axis_resources: Any = _UNSPECIFIED\n    # To enable data parallel for multiple pipeline stages, where the input\n    # activation is not a global invar. Currently defined by (dim_name, dim_idx)\n    # TODO: a better design to allow only applying this rule to a subset of\n    # intermediate, because some pipeline communicated tensors do not have a\n    # batch dim. e.g. the time vector in diffusion generated at the first stage.\n    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None\n\n\n@dataclasses.dataclass\nclass ParsedManualShardingOption:\n    \"\"\"Options \"\"\"\n    mesh_axis_names: Tuple[pxla.MeshAxisName, ...] = None\n    submesh_axis_names: Tuple[Tuple[pxla.MeshAxisName, ...], ...] = None\n    # Parsed and flatten status\n    in_parsed_pspec: Tuple[ParsedPartitionSpec, ...] = None\n    out_parsed_pspec: Tuple[ParsedPartitionSpec, ...] 
= None\n    pipeline_intermediate_axes: Sequence[Tuple[str, int]] = None\n\n\ndef _parsed_pspec_to_hlo_sharding(\n    mesh_shape,\n    mesh_axis_names,\n    parsed_pspec,\n    num_dimensions: int,\n    axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None\n) -> xc.OpSharding:\n    \"\"\"\n    TODO(yonghao): support auto(see how pxla.py lowers it)\n\n    This function inlines _create_mesh_pspec_sharding_from_parsed_pspec and\n    _process_in_axis_resources. It skips some checks there including\n    _is_unspecified_or_from_gda_or_auto, pjit_check_aval_sharding. It also skips\n    the local-global translation because we always assume alpa handles jaxprs at\n    the driver side.\n    \"\"\"\n    if _is_unspecified(parsed_pspec):\n        return undefined_sharding_spec_proto()\n    if _is_from_gda(parsed_pspec):\n        raise NotImplementedError(\"alpa does not support global device array.\")\n    if _is_auto(parsed_pspec):\n        raise NotImplementedError(\"\")\n\n    array_mapping = get_array_mapping(parsed_pspec)\n    sharding_spec = pxla.new_mesh_sharding_specs(mesh_shape, mesh_axis_names)(\n        num_dimensions, array_mapping)\n    # Used in `with_sharding_constraint`.\n    special_axes = {}\n    # Manual axes is only used with xmap.\n    # TODO: check whether this manual is conflict with what we use for the\n    # unspecified type(pjit uses REPLICATED as unspecified)\n    if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):\n        axis_names = mesh_axis_names\n        for manual_axis in axis_ctx.manual_axes:\n            special_axes[axis_names.index(\n                manual_axis)] = xc.OpSharding.Type.MANUAL\n    op_sharding = sharding_spec.sharding_proto(special_axes=special_axes)\n    return op_sharding\n\n\ndef _flatten_axes(treedef, axis_tree):\n    \"\"\"Flatten the axis tree and consider None as an effective value.\"\"\"\n    proxy = object()\n    dummy = tree_unflatten(treedef, [object()] * 
treedef.num_leaves)\n\n    axes = []\n\n    def add_leaves(i, x):\n        axes.extend([i] * len(tree_flatten(x)[0]))\n\n    tree_map(add_leaves, _replace_nones(proxy, axis_tree), dummy)\n    axes = [None if a is proxy else a for a in axes]\n    assert len(axes) == treedef.num_leaves\n    return axes\n\n\ndef _prepare_axis_and_flatten(axis_resources, tree, name):\n    parsed_axis_resources, _, _, any_auto = _prepare_axis_resources(\n        axis_resources, name)\n    if any_auto:\n        raise NotImplementedError(\n            \"auto mode in manual partition is unsupported.\")\n    axis_flat = tuple(_flatten_axes(tree, parsed_axis_resources))\n    if any(_is_unspecified(in_axis) for in_axis in axis_flat):\n        assert all(_is_unspecified(in_axis) for in_axis in axis_flat)\n    return axis_flat\n\n\ndef get_flatten_axis_resources(sharding_option: ManualShardingOption, in_tree,\n                               out_tree) -> ParsedManualShardingOption:\n    \"\"\"Flatten axis resources for pipeline parallel to dispatch.\"\"\"\n    if sharding_option is None:\n        return None\n\n    # process input\n    if _is_unspecified(sharding_option.in_axis_resources):\n        in_axis_flat = None\n    else:\n        in_axis_flat = _prepare_axis_and_flatten(\n            sharding_option.in_axis_resources, in_tree, \"in_axis_resources\")\n\n    # process output\n    if _is_unspecified(sharding_option.out_axis_resources):\n        out_axis_flat = None\n    else:\n        out_axis_flat = _prepare_axis_and_flatten(\n            sharding_option.out_axis_resources, out_tree, \"out_axis_resources\")\n    return ParsedManualShardingOption(\n        sharding_option.mesh_axis_names, sharding_option.submesh_axis_names,\n        in_axis_flat, out_axis_flat, sharding_option.pipeline_intermediate_axes)\n\n\ndef parsed_spec_to_opsharding(axes, avals, mesh_shape, mesh_axis_names):\n    \"\"\"Translate axis(a sequence of ParsedPartitionSpec) into OpShardings\"\"\"\n    if axes is None:\n     
   return None\n\n    named_mesh_shape = OrderedDict(\n        (name, size) for name, size in safe_zip(mesh_axis_names, mesh_shape))\n    op_shardings = tuple(\n        _parsed_pspec_to_hlo_sharding(named_mesh_shape, mesh_axis_names, axis,\n                                      len(aval.shape))\n        for axis, aval in safe_zip(axes, avals))\n    return op_shardings\n\n\ndef get_manual_sharding_spec(\n        sharding_option: ManualShardingOption, mesh_shape, in_tree, out_tree,\n        in_avals, out_avals) -> Tuple[Tuple[xc.OpSharding, ...], xc.OpSharding]:\n    \"\"\"Create input and output sharding spec from user's in_axis_resources.\"\"\"\n    parsed_resources = get_flatten_axis_resources(sharding_option, in_tree,\n                                                  out_tree)\n    if parsed_resources is None:\n        return None, None\n    assert parsed_resources.mesh_axis_names is not None\n    mesh_axis_names = sharding_option.mesh_axis_names\n    in_op_shardings = parsed_spec_to_opsharding(\n        parsed_resources.in_parsed_pspec, in_avals, mesh_shape, mesh_axis_names)\n    out_op_shardings = parsed_spec_to_opsharding(\n        parsed_resources.out_parsed_pspec, out_avals, mesh_shape,\n        mesh_axis_names)\n    return in_op_shardings, out_op_shardings\n\n\ndef get_intermediate_parsed_spec(intermediate_dims,\n                                 dim_len,\n                                 allow_unconstrained_dims=False):\n    axes = [None] * dim_len\n    for (name, dim) in intermediate_dims:\n        axes[dim] = name\n    pspec = PartitionSpec(*axes)\n    parsed_pspec = ParsedPartitionSpec.from_user_input(\n        pspec,\n        \"intermediate specifications\",\n        allow_unconstrained_dims=allow_unconstrained_dims)\n    return parsed_pspec\n"
  },
  {
    "path": "alpa/test_install.py",
    "content": "\"\"\"Some basic tests to test installation.\"\"\"\nimport os\nimport unittest\n\nfrom alpa import (init, parallelize, ShardParallel, PipeshardParallel,\n                  AutoLayerOption, prefetch)\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.testing import assert_allclose, get_mlp_train_state_and_step\n\n\nclass InstallationTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n    def test_1_shard_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=4)\n\n        # Serial execution\n        expected_output = train_step(state, batch)\n\n        # Parallel execution\n        p_train_step = parallelize(train_step,\n                                   method=ShardParallel(num_micro_batches=2))\n        actual_output = p_train_step(state, batch)\n\n        # Check results\n        assert_allclose(expected_output, actual_output)\n\n    def test_2_pipeline_parallel(self):\n        init(cluster=\"ray\")\n\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=6)\n\n        # Serial execution\n        expected_output = train_step(state, batch)\n\n        # Parallel execution\n        layer_num = min(get_global_cluster().num_devices, 2)\n        p_train_step = parallelize(\n            train_step,\n            method=PipeshardParallel(\n                num_micro_batches=2,\n                layer_option=AutoLayerOption(layer_num=layer_num)))\n        actual_output = p_train_step(state, batch)\n\n        # Check results\n        prefetch(actual_output)\n        assert_allclose(expected_output, 
actual_output)\n\n\ndef suite():\n    s = unittest.TestSuite()\n    s.addTest(InstallationTest(\"test_1_shard_parallel\"))\n    s.addTest(InstallationTest(\"test_2_pipeline_parallel\"))\n    return s\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "alpa/testing.py",
    "content": "\"\"\"Utilities for testing.\"\"\"\nfrom functools import partial\nimport unittest\nfrom collections.abc import Iterable\nfrom typing import Callable, Optional\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.tree_util import tree_leaves\nfrom jax.experimental.maps import FrozenDict as FrozenDictJax\nimport numpy as np\nimport optax\nfrom flax import linen as nn\nfrom flax.core.frozen_dict import FrozenDict as FrozenDictFlax\n\nfrom alpa.api import init, shutdown, parallelize, value_and_grad\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer\nfrom alpa.model.model_util import FlaxBaseModelOutput, DynamicScale, TrainState\nfrom alpa.parallel_method import PipeshardParallel\nfrom alpa.pipeline_parallel.layer_construction import (AutoLayerOption,\n                                                       ManualLayerOption)\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nfrom alpa.pipeline_parallel.stage_construction import (UniformStageOption,\n                                                       StageOption)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\n\n\ndef assert_allclose(x, y, rtol=1e-4, atol=1e-4):\n    \"\"\"Assert the arrays in x and y are all close.\"\"\"\n    if isinstance(x, (dict, FrozenDictJax, FrozenDictFlax)):\n        assert isinstance(y, (dict, FrozenDictJax, FrozenDictFlax))\n        assert set(x.keys()) == set(y.keys())\n        for k in x.keys():\n            assert_allclose(x[k], y[k], rtol, atol)\n    elif isinstance(x, Iterable) and not hasattr(x, \"__array__\"):\n        assert isinstance(y, Iterable) and not hasattr(y, \"__array__\")\n        assert len(x) == len(y)\n        for x_elt, y_elt in zip(x, y):\n            assert_allclose(x_elt, y_elt, rtol, atol)\n    elif hasattr(x, \"__array__\") or np.isscalar(x):\n        assert hasattr(y, \"__array__\") or np.isscalar(y), f\"{y}\"\n        x = np.asarray(x)\n        y = np.asarray(y)\n        
np.testing.assert_allclose(x, y, rtol, atol)\n    elif isinstance(x, TrainState):\n        assert isinstance(y, TrainState)\n        assert_allclose(tree_leaves(x), tree_leaves(y), rtol, atol)\n    elif x == y:\n        return\n    else:\n        raise TypeError((type(x), type(y)))\n\n\nclass MLPModel(nn.Module):\n    \"\"\"An MLP model for testing.\"\"\"\n    num_layers: int\n    hidden_size: int\n    use_bias: bool = True\n    add_manual_pipeline_marker: bool = True\n\n    @nn.compact\n    def __call__(self, x):\n        for i in range(self.num_layers):\n            x = nn.Dense(self.hidden_size, use_bias=self.use_bias)(x)\n\n            if (self.add_manual_pipeline_marker and\n                    i == self.num_layers // 2 - 1):\n                mark_pipeline_boundary()\n        return x\n\n\ndef get_mlp_train_state_and_step(batch_size,\n                                 hidden_size,\n                                 num_layers=4,\n                                 use_bias=True,\n                                 add_manual_pipeline_marker=False):\n    # Init input batch\n    rngkey = jax.random.PRNGKey(0)\n    x = jax.random.normal(rngkey, (batch_size, hidden_size))\n    y = jax.random.normal(rngkey, (batch_size, hidden_size))\n    batch = {\"x\": x, \"y\": y}\n\n    # Init model and optimizer\n    model = MLPModel(num_layers=num_layers,\n                     hidden_size=hidden_size,\n                     use_bias=use_bias,\n                     add_manual_pipeline_marker=add_manual_pipeline_marker)\n    params = model.init(rngkey, batch[\"x\"])\n    tx = optax.sgd(learning_rate=1e-2, momentum=0.9)\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              dynamic_scale=None)\n\n    # Define train step\n    def train_step(state, batch):\n\n        def loss_func(params):\n            out = state.apply_fn(params, batch[\"x\"])\n            return 
jnp.mean((out - batch[\"y\"])**2)\n\n        val, grads = value_and_grad(loss_func)(state.params)\n        new_state = state.apply_gradients(grads=grads)\n        return new_state, val\n\n    return state, batch, train_step\n\n\nclass BertLayerModel(nn.Module):\n    \"\"\"A BERT model for testing.\"\"\"\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n    add_manual_pipeline_marker: bool = True\n\n    def setup(self):\n        # pylint: disable=attribute-defined-outside-init\n        self.layers = [\n            FlaxBertLayer(config=self.config, dtype=self.dtype)\n            for _ in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(self, x, attention_mask):\n        for i, layer in enumerate(self.layers):\n            layer_outputs = layer(x, attention_mask)\n            x = layer_outputs[0]\n\n            if self.add_manual_pipeline_marker and i != len(self.layers) - 1:\n                mark_pipeline_boundary()\n        return x\n\n\ndef get_bert_layer_train_state_and_step(batch_size, seq_len, num_layers,\n                                        hidden_size, num_heads,\n                                        clip_by_global_norm, use_dynamic_scale,\n                                        add_manual_pipeline_marker):\n    rngkey = jax.random.PRNGKey(0)\n    x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\n    y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\n    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)\n    batch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n\n    model = BertLayerModel(\n        config=BertConfig(hidden_size=hidden_size,\n                          intermediate_size=hidden_size * 4,\n                          num_attention_heads=num_heads,\n                          num_hidden_layers=num_layers),\n        add_manual_pipeline_marker=add_manual_pipeline_marker)\n    params = model.init(rngkey, batch[\"x\"], batch[\"attention_mask\"])\n\n    if 
clip_by_global_norm:\n        tx = optax.chain(optax.clip_by_global_norm(0.05),\n                         optax.adam(learning_rate=1e-2))\n    else:\n        tx = optax.adam(learning_rate=1e-2)\n\n    if use_dynamic_scale:\n        use_master_copy = False\n        dynamic_scale = DynamicScale()\n    else:\n        dynamic_scale = None\n        use_master_copy = False\n\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              dynamic_scale=dynamic_scale,\n                              use_master_copy=use_master_copy)\n\n    def train_step(state, batch):\n\n        def loss_func(params):\n            out = state.apply_fn(params, batch[\"x\"], batch[\"attention_mask\"])\n            loss = jnp.mean((out - batch[\"y\"])**2)\n            return loss\n\n        dynamic_scale = state.dynamic_scale\n        if dynamic_scale:\n            grad_fn = dynamic_scale.value_and_grad(loss_func)\n            dynamic_scale, is_fin, val, grads = grad_fn(state.params)\n        else:\n            grad_fn = value_and_grad(loss_func)\n            val, grads = grad_fn(state.params)\n\n        new_state = state.apply_gradients(grads=grads)\n\n        if dynamic_scale:\n            new_state = new_state.replace(\n                opt_state=jax.tree_map(partial(jnp.where, is_fin),\n                                       new_state.opt_state, state.opt_state),\n                params=jax.tree_map(partial(jnp.where, is_fin),\n                                    new_state.params, state.params),\n                master_copy=jax.tree_map(partial(jnp.where,\n                                                 is_fin), new_state.master_copy,\n                                         state.master_copy),\n                dynamic_scale=dynamic_scale)\n        return new_state, val\n\n    return state, batch, train_step\n\n\ndef create_train_state(rngkey, model, inputs):\n    params = 
model.init(rngkey, *inputs)\n    tx = optax.adam(learning_rate=1e-2)\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              dynamic_scale=None)\n    return state\n\n\ndef mlp_inference_step(state, batch):\n    out = state.apply_fn(state.params, batch[\"x\"])\n    loss = jnp.mean((out - batch[\"y\"])**2)\n    return out, loss\n\n\ndef bert_layer_collection_inference_step(state, batch):\n    out = state.apply_fn(state.params,\n                         batch[\"x\"],\n                         batch[\"attention_mask\"],\n                         output_attentions=True,\n                         output_hidden_states=True)\n    loss = jnp.mean((out.last_hidden_state - batch[\"y\"])**2)\n    # FIXME(yonghao): Otherwise, the first hidden state is an input,\n    #   but we do not support outputing an input(not batch-related\n    #   outputs).\n    out = FlaxBaseModelOutput(last_hidden_state=out.last_hidden_state,\n                              hidden_states=out.hidden_states[1:],\n                              attentions=out.attentions)\n    return out, loss\n\n\nclass PipelineBasicTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def run_mlp(self,\n                manual_pipeline_layer: bool = True,\n                use_remat: bool = False,\n                stage_option: Optional[StageOption] = None,\n                as_option: Optional[AutoShardingOption] = None,\n                do_numerical_test: bool = True):\n        method = PipeshardParallel(\n            num_micro_batches=4,\n            default_auto_sharding_option=as_option or AutoShardingOption(),\n            layer_option=ManualLayerOption(remat_layer=use_remat)\n            if manual_pipeline_layer else AutoLayerOption(\n                layer_num=2,\n                remat_mode=\"coarse_grained_remat\" if use_remat 
else \"none\"),\n            stage_option=stage_option or UniformStageOption())\n\n        # Init model\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=64,\n            hidden_size=16,\n            num_layers=4,\n            add_manual_pipeline_marker=manual_pipeline_layer)\n\n        # Compile\n        serial_train_step = train_step\n        parallel_train_step = parallelize(train_step, method=method)\n        executable = parallel_train_step.get_executable(state, batch)\n\n        # Run correctness test\n        if do_numerical_test:\n            expected_new_state = None\n            actual_new_state = None\n            for i in range(3):\n                if i > 0:\n                    state = expected_new_state\n                expected_new_state, expected_val = serial_train_step(\n                    state, batch)\n\n                if i > 0:\n                    state = actual_new_state\n                actual_new_state, actual_val = parallel_train_step(state, batch)\n\n                assert_allclose(expected_new_state.params,\n                                actual_new_state.params, 1e-3, 1e-3)\n                assert_allclose(expected_val, actual_val, 1e-3, 1e-3)\n\n        hlo_text = executable.get_hlo_text()\n        return hlo_text\n\n    def run_n_layer_bert(self,\n                         num_layers,\n                         batch_size=16,\n                         seq_len=256,\n                         hidden_size=512,\n                         num_heads=512 // 64,\n                         use_remat=False,\n                         clip_by_global_norm=False,\n                         use_dynamic_scale=False,\n                         inject_train_step=None,\n                         manual_pipeline_layer=True,\n                         stage_option: Optional[StageOption] = None,\n                         as_option: Optional[AutoShardingOption] = None,\n                         do_numerical_test: bool = 
True):\n        method = PipeshardParallel(\n            num_micro_batches=4,\n            default_auto_sharding_option=as_option or AutoShardingOption(),\n            layer_option=ManualLayerOption(remat_layer=use_remat)\n            if manual_pipeline_layer else AutoLayerOption(\n                layer_num=num_layers,\n                remat_mode=\"coarse_grained_remat\" if use_remat else \"none\"),\n            stage_option=stage_option or UniformStageOption())\n\n        # Init model\n        state, batch, train_step = get_bert_layer_train_state_and_step(\n            batch_size=batch_size,\n            seq_len=seq_len,\n            num_layers=num_layers,\n            hidden_size=hidden_size,\n            num_heads=num_heads,\n            clip_by_global_norm=clip_by_global_norm,\n            use_dynamic_scale=use_dynamic_scale,\n            add_manual_pipeline_marker=manual_pipeline_layer)\n        if inject_train_step is not None:\n            assert isinstance(inject_train_step, Callable)\n            train_step = inject_train_step\n\n        # Compile\n        serial_train_step = train_step\n        parallel_train_step = parallelize(train_step, method=method)\n        executable = parallel_train_step.get_executable(state, batch)\n\n        # Run correctness test\n        if do_numerical_test:\n            expected_new_state = None\n            actual_new_state = None\n            for i in range(1):\n                if i > 0:\n                    state = expected_new_state\n                expected_new_state, expected_val = serial_train_step(\n                    state, batch)\n\n                if i > 0:\n                    state = actual_new_state\n\n                actual_new_state, actual_val = parallel_train_step(state, batch)\n\n                assert_allclose(expected_new_state.params,\n                                actual_new_state.params, 1e-3, 1.5e-3)\n                assert_allclose(expected_val, actual_val, 1e-3, 1e-3)\n\n        hlo_text = 
executable.get_hlo_text()\n        return hlo_text\n\n\ndef data_loader_input_iter_func(start, end, batch_size):\n    \"\"\"A data loader function for testing.\"\"\"\n    dataset_x = np.arange(1024 * 32).reshape(-1, 32).astype(np.float32)\n    dataset_y = np.arange(1024).astype(np.int32)\n\n    num_batches = (end - start) // batch_size\n\n    for i in range(num_batches):\n        idx = start + i * batch_size\n        yield dataset_x[idx:idx + batch_size], dataset_y[idx:idx + batch_size]\n\n\nclass HloParser:\n    \"\"\"\n    Parse HLO text to check whether the parameters and outputs have the\n    correct sharding.\n    \"\"\"\n\n    @staticmethod\n    def get_param_line(text: str):\n        text = text[text.find(\"ENTRY\"):]\n        text = text[:text.find(\"\\n\")]\n        return text\n\n    @staticmethod\n    def get_root_line(text: str):\n        text = text[text.find(\"ENTRY\"):]\n        text = text[text.find(\"ROOT\"):]\n        text = text[:text.find(\"\\n\")]\n        return text\n\n    @staticmethod\n    def parse_param_shapes(text: str):\n        # the first one is \"ENTRY %xxx (\"\n        params = text.split(\"param\")[1:]\n        shapes = tuple(map(lambda x: x[x.find(\"f32\"):x.find(\"]\") + 1], params))\n        return shapes\n\n    @staticmethod\n    def parse_root_shapes(text: str):\n        tuple_shape = text[text.find(\"=\") + 2:text.find(\"tuple(\")]\n        # the last one is ')'\n        shapes = tuple_shape.split(\"0}\")[:-1]\n        shapes = tuple(map(lambda x: x[x.find(\"f32\"):x.find(\"{\")], shapes))\n        return shapes\n"
  },
  {
    "path": "alpa/timer.py",
    "content": "\"\"\"Global timer for profiling.\"\"\"\nfrom collections import namedtuple\nimport time\nfrom typing import Callable, Any\n\n\nclass _Timer:\n    \"\"\"An internal timer.\"\"\"\n\n    def __init__(self, name: str):\n        self.name = name\n        self.started = False\n        self.start_time = None\n\n        # start-stop timestamp pairs\n        self.start_times = []\n        self.stop_times = []\n        self.costs = []\n\n    def start(self, sync_func: Callable = None):\n        \"\"\"Start the timer.\"\"\"\n        assert not self.started, f\"timer {self.name} has already been started.\"\n        if sync_func:\n            sync_func()\n\n        self.start_time = time.time()\n        self.start_times.append(self.start_time)\n        self.started = True\n\n    def stop(self, sync_func: Callable = None):\n        \"\"\"Stop the timer.\"\"\"\n        assert self.started, f\"timer {self.name} is not started.\"\n        if sync_func:\n            sync_func()\n\n        stop_time = time.time()\n        self.costs.append(stop_time - self.start_time)\n        self.stop_times.append(stop_time)\n        self.started = False\n\n    def reset(self):\n        \"\"\"Reset timer.\"\"\"\n        self.started = False\n        self.start_time = None\n        self.start_times = []\n        self.stop_times = []\n        self.costs = []\n\n    def elapsed(self, mode: str = \"average\"):\n        \"\"\"Calculate the elapsed time.\"\"\"\n        if not self.costs:\n            return 0.0\n        if mode == \"average\":\n            return sum(self.costs) / len(self.costs)\n        elif mode == \"sum\":\n            return sum(self.costs)\n        else:\n            raise RuntimeError(\"Supported mode is: average | sum\")\n\n\nclass Timers:\n    \"\"\"A group of timers.\"\"\"\n\n    def __init__(self):\n        self.timers = {}\n\n    def __call__(self, name: str):\n        if name not in self.timers:\n            self.timers[name] = _Timer(name)\n        return 
self.timers[name]\n\n    def __contains__(self, name: str):\n        return name in self.timers\n\n\ntimers = Timers()\n\nEvent = namedtuple(\"Event\", (\"tstamp\", \"name\", \"info\"))\n\n\nclass Tracer:\n    \"\"\"An activity tracer.\"\"\"\n\n    def __init__(self):\n        self.events = []\n\n    def log(self, name: str, info: Any, sync_func: Callable = None):\n        if sync_func:\n            sync_func()\n\n        self.events.append(Event(time.time(), name, info))\n\n\ntracer = Tracer()\n"
  },
  {
    "path": "alpa/torch/__init__.py",
    "content": "\"\"\"Miscellaneous functions available in `alpa.torch.*` namespace.\"\"\"\n\ntry:\n    import torch\nexcept ImportError as e:\n    print(\"\"\"\n        Attempted to use Alpa-PyTorch frontend, but PyTorch is not installed.\n\n        Please follow instructions at \n        https://alpa-projects.github.io/install.html#pytorch-frontend-experimental\n        to install PyTorch and related dependencies.\"\"\")\n    raise e\n\nfrom typing import Any, Callable, Union, Tuple\nfrom functools import partial, wraps\n\nfrom packaging import version\nimport numpy as np\n\nimport alpa\nfrom alpa.device_mesh import DistributedArray\nfrom alpa.torch.nn import functionalize, meta_init\nfrom alpa.torch.ops.mapping import enable_dist_for_func\nfrom alpa.torch.tensor_utils import (make_shaped_array_from_pt_tensor,\n                                     initialize_with_zeros, to_format,\n                                     assert_format)\nfrom alpa.torch import trainer\n\n# If True, prints verbose log for debugging.\ndebug = False\n\n\ndef set_mode(new_mode: str):\n    \"\"\"This sets the current alpa.torch mode. 
Supports one of following:\n\n    \"local\":\n    - Pure PT eager mode on a single CPU/GPU\n    - Allows print in middle of graph\n    - No dist training\n\n    \"dist\":\n    - Graph mode by lowering PT programs to JAX and then run them with Alpa\n    - Doesn't allow print in middle of graph\n    - Supports dist training\n    \"\"\"\n    assert new_mode in [\"local\", \"dist\"]\n    if new_mode == \"dist\":\n        torch.local_mode = False\n    elif new_mode == \"local\":\n        torch.local_mode = True\n\n\ndef mode():\n    if torch.local_mode:\n        return \"local\"\n    else:\n        return \"dist\"\n\n\ndef functorch_value_and_grad(func: Callable,\n                             argnums: Union[int, Tuple[int, ...]] = 0,\n                             has_aux: bool = False) -> Callable:\n    \"\"\"The same implementation as functorch.grad_and_value,\n    but puts value first and grad second in output.\n    \"\"\"\n\n    @wraps(func)\n    def wrapper(*args, **kwargs):\n        # pylint: disable=import-outside-toplevel\n        # functorch imports based on PT version\n        if version.parse(torch.__version__) < version.parse(\"1.13\"):\n            from functorch._C import (_grad_increment_nesting,\n                                      _grad_decrement_nesting)\n            from functorch._src.eager_transforms import (\n                _wrap_all_tensors, _slice_argnums, _create_differentiable,\n                _as_tuple, _autograd_grad, _undo_create_differentiable)\n            from functorch._src.pytree_hacks import tree_map_\n\n        elif version.parse(torch.__version__) == version.parse(\"1.13\"):\n            from torch._C._functorch import (_grad_increment_nesting,\n                                             _grad_decrement_nesting)\n            from functorch._src.eager_transforms import (\n                _wrap_all_tensors, _slice_argnums, _create_differentiable,\n                _as_tuple, _autograd_grad, _undo_create_differentiable)\n            
from functorch._src.pytree_hacks import tree_map_\n\n        else:\n            from torch._C._functorch import (_grad_increment_nesting,\n                                             _grad_decrement_nesting)\n            from torch._functorch.eager_transforms import (\n                _wrap_all_tensors, _slice_argnums, _create_differentiable,\n                _as_tuple, _autograd_grad, _undo_create_differentiable)\n            from torch._functorch.pytree_hacks import tree_map_\n\n        from torch.utils._pytree import tree_flatten, tree_unflatten\n        level = _grad_increment_nesting()\n        try:\n            output, aux, grad_input = None, None, None\n            # See NOTE [grad and vjp interaction with no_grad]\n            with torch.enable_grad():\n                args = _wrap_all_tensors(args, level)\n                kwargs = _wrap_all_tensors(kwargs, level)\n                diff_args = _slice_argnums(args, argnums, as_tuple=False)\n                tree_map_(partial(_create_differentiable, level=level),\n                          diff_args)\n\n                output = func(*args, **kwargs)\n                if has_aux:\n                    if not (isinstance(output, tuple) and len(output) == 2):\n                        raise RuntimeError(\n                            \"value_and_grad(f)(*args): output of function f \"\n                            \"should be a tuple: (output, aux) \"\n                            \"if has_aux is True\")\n                    output, aux = output\n\n                if not isinstance(output, torch.Tensor):\n                    raise RuntimeError(\n                        \"value_and_grad(f)(*args): Expected f(*args) \"\n                        f\"to return a Tensor, got {type(output)}\")\n                if output.dim() != 0:\n                    raise RuntimeError(\n                        \"value_and_grad(f)(*args): Expected f(*args) \"\n                        \"to return a scalar Tensor, got tensor with \"\n          
              f\"{output.dim()} dims. Maybe you wanted to \"\n                        \"use the vjp or jacrev APIs instead?\")\n\n                flat_diff_args, spec = tree_flatten(diff_args)\n\n                # NB: need create_graph so that backward pass isn't run\n                # in no_grad mode\n                flat_outputs = _as_tuple(output)\n                flat_grad_input = _autograd_grad(flat_outputs,\n                                                 flat_diff_args,\n                                                 create_graph=True)\n                grad_input = tree_unflatten(flat_grad_input, spec)\n\n                grad_input = _undo_create_differentiable(grad_input, level)\n                output = _undo_create_differentiable(output, level)\n                if aux is not None:\n                    aux = _undo_create_differentiable(aux, level)\n\n            if has_aux:\n                return (output, aux), grad_input\n            return output, grad_input\n        finally:\n            _grad_decrement_nesting()\n\n    return wrapper\n\n\ndef value_and_grad(func, argnums=0, has_aux=False):\n    if mode() == \"local\":\n        return functorch_value_and_grad(func, argnums=argnums, has_aux=has_aux)\n    else:\n        return alpa.value_and_grad(func, argnums=argnums, has_aux=has_aux)\n"
  },
  {
    "path": "alpa/torch/nn/__init__.py",
    "content": "\"\"\"PyTorch module conversion related functions.\n\"\"\"\nimport copy\nfrom typing import List, Callable, Dict\nfrom collections import OrderedDict\n\nimport torch\nfrom torch import Tensor, nn\nfrom torch.fx.experimental.normalize import NormalizeOperators\nfrom torchdistx import deferred_init as torchdistx_deferred_init\nfrom torchdistx.fake import meta_like\n\nimport alpa.torch as atorch\nfrom alpa.torch.tensor_utils import make_shaped_array_from_pt_tensor\nfrom alpa.torch.nn.utils import (DONT_EXPAND_MODULES, extract_buffers,\n                                 extract_weights, named_buffers, named_members,\n                                 named_parameters, normalize)\n\nmapping_prefix = \"alpa_torch_ops_mapping\"\n\n\ndef fx_ir_to_alpa_func_code(fx_ir, alpa_func_name):\n    # TODO: maybe we can operate on FX IR node to clean up this impl\n\n    fx_ir_code_cleaned = \"\"\n    for line in fx_ir.code.strip().split(\"\\n\"):\n        line = line.replace(\";  \", \"\\n    \")\n        fx_ir_code_cleaned += line + \"\\n\"\n\n    if atorch.debug:\n        print(\"FX IR code (cleaned): \")\n        print(fx_ir_code_cleaned)\n\n    lines = fx_ir_code_cleaned.split(\"\\n\")\n    assert \"def forward(\" in lines[0]\n    signature_line = lines[0]\n    sig_args = signature_line.split(\"def forward(\")[1].split(\"):\")[0].split(\n        \", \")\n    sig_args = sig_args[1:]  # remove `self`\n    sig_args.insert(0, \"params\")\n    sig_args.insert(1, \"bufs\")\n    signature_line = f\"def {alpa_func_name}(\" + \", \".join(sig_args) + \"):\"\n\n    out_body_lines = []\n\n    bufs_set = set(fx_ir.buffers(recurse=True))\n    bufs_n_to_key = {}\n\n    for line in lines[1:]:\n        line = line.replace(\" : torch.Tensor\", \"\")\n        if \"self.\" in line:\n            if \"getattr(\" in line:\n                # Example line in IR:\n                # `... 
= getattr(self.layers, \"0\").encoder.self_attn.qkv.weight`\n                # For RHS, FQN in param dict should be:\n                # \"layers.0.encoder.self_attn.qkv.weight\"\n                attr_fqn_name_in_original_ir = line.split(\" = \")[1]\n                attr_fqn_name_in_param_dict = (\n                    line.split(\"getattr(self.\")[1].split(\"(\")[0].replace(\n                        ', \"', \".\").replace('\")', \"\"))\n            else:\n                # Example line in IR:\n                # `self_layers_0__w_attention = self.self_layers_0__w_attention`\n                # For RHS, FQN in param dict should be:\n                # \"self_layers_0__w_attention\"\n                attr_fqn_name_in_original_ir = line.split(\" = \")[1]\n                attr_fqn_name_in_param_dict = line.split(\"self.\")[1].split(\n                    \"(\")[0]\n            line_rhs = line.split(\" = \")[1]\n            try:\n                if \").\" in line_rhs:\n                    # Example line in IR:\n                    # `... 
= getattr(self.layers, \"0\").conv(reshape_7)`\n                    # Attribute access statement should be\n                    # `getattr(self.layers, \"0\").conv`\n                    attr_access_stmt = (\"_tmp_value = \" +\n                                        line_rhs.split(\").\")[0].replace(\n                                            \"self.\", \"locals()['fx_ir'].\") +\n                                        \").\" +\n                                        line_rhs.split(\").\")[1].split(\"(\")[0])\n                else:\n                    attr_access_stmt = \"_tmp_value = \" + line_rhs.replace(\n                        \"self.\", \"locals()['fx_ir'].\")\n            except IndexError as e:\n                print(line_rhs)\n                raise e\n            # pylint: disable=exec-used\n            exec(attr_access_stmt)\n            attr_value = locals()[\"_tmp_value\"]\n            if isinstance(attr_value, torch.nn.Module):\n                # Full list of NN modules that need this handling is at\n                # torchdynamo/torchdynamo/optimizations/normalize.py\n                # `DONT_EXPAND_MODULES`.\n                assert attr_value.__class__.__name__ in DONT_EXPAND_MODULES, \\\n                    \"unknown module: \" + str(attr_value.__class__.__name__)\n                call_args = line.split(\"self.\")[1].split(\"(\")[1].split(\n                    \")\")[0].split(\", \")\n                if attr_value.__class__.__name__ == \"Conv2d\":\n                    call_args += [\n                        f\"params['{attr_fqn_name_in_param_dict}.weight']\",\n                        f\"bias=params['{attr_fqn_name_in_param_dict}.bias']\",\n                        f\"stride={attr_value.stride}\",\n                        f\"padding={attr_value.padding}\",\n                        f\"dilation={attr_value.dilation}\",\n                        f\"groups={attr_value.groups}\",\n                    ]\n                    lhs = line.split(\" = \")[0]\n  
                  line = lhs + \" = \" + f\"torch.conv2d({', '.join(call_args)})\"\n                else:\n                    raise NotImplementedError\n            elif isinstance(attr_value, torch.nn.Parameter):  # Parameter\n                line = line.replace(f\"{attr_fqn_name_in_original_ir}\",\n                                    f\"params['{attr_fqn_name_in_param_dict}']\")\n            elif isinstance(attr_value, torch.Tensor):\n                if attr_value in bufs_set:  # Buffer\n                    # TODO: verify whether torch.fx.symbolic_trace\n                    # puts both buffer and non-buffer Tensors\n                    # (i.e. both `self.register_buffer(...)` and\n                    # `self.tensor = torch.tensor(...)`)\n                    # into buffers dict.\n                    # This code assumes so.\n                    line = line.replace(\n                        f\"{attr_fqn_name_in_original_ir}\",\n                        f\"bufs['{attr_fqn_name_in_param_dict}']\")\n                else:  # Const\n                    raise ValueError(\n                        \"We assume torch.fx treats non-buffer \"\n                        \"tensor attributes as buffers, \"\n                        \"but this assumption no longer holds true for \"\n                        \".{attr_fqn_name_in_param_dict}\")\n            else:  # Const\n                raise ValueError(\n                    \"non-module / non-tensor attribute is not supported, \"\n                    \"but found type of \"\n                    f\"'{attr_fqn_name_in_param_dict}' to be {type(attr_value)}\")\n\n        # Record all buffers' name and their correponding key in `bufs` dict\n        if \" = bufs['\" in line:\n            buf_name = line.split(\" = bufs['\")[0].strip()\n            buf_key = line.split(\" = bufs['\")[1].split(\"']\")[0]\n            bufs_n_to_key[buf_name] = buf_key\n\n        # Rewrite stateful modules / ops\n        if \"torch.nn.functional.batch_norm\" in 
line:\n            lhs = line.split(\" = torch.nn.functional.batch_norm\")[0]\n            call_args = line.split(\" = torch.nn.functional.batch_norm(\"\n                                  )[1].split(\")\")[0].split(\", \")\n            r_mean_arg_n = call_args[1]\n            assert \"running_mean\" in r_mean_arg_n\n            r_var_arg_n = call_args[2]\n            assert \"running_var\" in r_var_arg_n\n            line = (lhs + \", r_mean_new, r_var_new\" +\n                    \" = torch.nn.functional.batch_norm(\" +\n                    \", \".join(call_args) + \")\")\n            line += \"\\n\"\n            line += f\"    bufs['{bufs_n_to_key[r_mean_arg_n]}'] = r_mean_new\"\n            line += \"\\n\"\n            line += f\"    bufs['{bufs_n_to_key[r_var_arg_n]}'] = r_var_new\"\n\n        # Op lowering\n        if \"torch._C._nn.\" in line:\n            op_name = line.split(\"torch._C._nn.\")[1].split(\"(\")[0]\n            line = line.replace(f\"torch._C._nn.{op_name}\",\n                                f\"torch.nn.functional.{op_name}\")\n        if f\"{mapping_prefix}_torch_nn_functional_\" in line:\n            op_name = line.split(\n                f\"{mapping_prefix}_torch_nn_functional_\")[1].split(\"(\")[0]\n            line = line.replace(\n                f\"{mapping_prefix}_torch_nn_functional_{op_name}\",\n                f\"torch.nn.functional.{op_name}\")\n        if f\"{mapping_prefix}_torch_\" in line:\n            op_name = line.split(f\"{mapping_prefix}_torch_\")[1].split(\"(\")[0]\n            line = line.replace(f\"{mapping_prefix}_torch_{op_name}\",\n                                f\"torch.{op_name}\")\n        if \".dim()\" in line:\n            tensor_name = line.split(\" = \")[1].split(\".dim()\")[0]\n            line = line.replace(f\"{tensor_name}.dim()\",\n                                f\"len({tensor_name}.shape)\")\n        if \".size()\" in line:\n            tensor_name = line.split(\" = \")[1].split(\".size()\")[0]\n       
     line = line.replace(f\"{tensor_name}.size()\", f\"{tensor_name}.shape\")\n        if \".permute(\" in line:\n            tensor_name = line.split(\" = \")[1].split(\".permute(\")[0]\n            line = line.replace(f\"{tensor_name}.permute(\",\n                                f\"torch.permute({tensor_name}, (\") + \")\"\n        if \".expand(\" in line:\n            tensor_name = line.split(\" = \")[1].split(\".expand(\")[0]\n            line = line.replace(f\"{tensor_name}.expand(\",\n                                f\"torch.expand({tensor_name}, (\") + \")\"\n        if \".view(\" in line:\n            tensor_name = line.split(\" = \")[1].split(\".view(\")[0]\n            line = line.replace(f\"{tensor_name}.view(\",\n                                f\"torch.view({tensor_name}, (\") + \")\"\n        if \" @ \" in line:\n            lhs = line.split(\" = \")[0]\n            rhs = line.split(\" = \")[1]\n            line = lhs + \" = \" + \"torch.matmul(\" + rhs.replace(\" @ \",\n                                                               \", \") + \")\"\n\n        if \"return \" in line:\n            rhs_of_return = line.split(\"return \")[1]\n            output_args = rhs_of_return.split(\",\")\n            output_args.insert(0, \"bufs\")\n            line = line.split(\"return \")[0] + \"return \" + \", \".join(output_args)\n\n        out_body_lines.append(line)\n\n    # `alpa_func_code` is string form of a function that contains\n    # (mostly) PyTorch operations.\n    # \"mostly\" because ops like `torch.expand` and `torch.view` are not actually\n    # valid PyTorch ops and only work within `atorch.bind_ops()` context.\n    alpa_func_code = signature_line + \"\\n\" + \"\\n\".join(out_body_lines) + \"\\n\"\n    alpa_func_code = alpa_func_code.strip()\n\n    return alpa_func_code\n\n\n# Copied from torchdynamo/torchdynamo/optimizations/normalize.py\ndef normalize_ir_no_run(fx_ir):\n    normalize(fx_ir)\n    try:\n        fx_ir = 
NormalizeOperators(fx_ir).transform()\n    except AttributeError:\n        # log.exception(\"NormalizeOperators() failed\")\n        pass\n    # ShapeAliasingAndMutationProp(fx_ir).run(*example_inputs)\n    # fx_ir = Functionalization(fx_ir).transform()\n    fx_ir.recompile()\n    # record_graph_stats(fx_ir)\n    return fx_ir\n\n\n# Copied from functorch/functorch/_src/make_functional.py\ndef _del_nested_attr(obj: nn.Module, names: List[str]) -> None:\n    \"\"\"Deletes the attribute specified by the given list of names.\n    For example, to delete the attribute obj.conv.weight,\n    use _del_nested_attr(obj, ['conv', 'weight'])\n    \"\"\"\n    if len(names) == 1:\n        delattr(obj, names[0])\n    else:\n        _del_nested_attr(getattr(obj, names[0]), names[1:])\n\n\ndef _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:\n    \"\"\"Set the attribute specified by the given list of names to value.\n    For example, to set the attribute obj.conv.weight,\n    use _del_nested_attr(obj, ['conv', 'weight'], value)\n    \"\"\"\n    if len(names) == 1:\n        setattr(obj, names[0], value)\n    else:\n        _set_nested_attr(getattr(obj, names[0]), names[1:], value)\n\n\ndef _get_nested_attr(obj: nn.Module, names: List[str]) -> None:\n    if len(names) == 1:\n        return getattr(obj, names[0])\n    else:\n        return _get_nested_attr(getattr(obj, names[0]), names[1:])\n\n\ndef _swap_state(mod: nn.Module, names_map: Dict[str, List[str]], elems):\n    result = []\n    for (_, attr_names), elem in zip(names_map.items(), elems):\n        for i, attr_name in enumerate(attr_names):\n            if i == 0:\n                result.append(_get_nested_attr(mod, attr_name))\n            _del_nested_attr(mod, attr_name)\n            _set_nested_attr(mod, attr_name, elem)\n    return result\n\n\n# Adapted from `FunctionalModuleWithBuffers`\n# in functorch/functorch/_src/make_functional.py\nclass 
FunctionalModuleWithBuffersInInputAndOutput(torch.nn.Module):\n    \"\"\"Given a ``torch.nn.Module``, `create_from` extracts the\n    state (params and buffers) and returns a functional version of the model\n    ``func`` that can be invoked like a function.\n\n    Compared to `FunctionalModuleWithBuffers` in functorch, the returned\n    functional version of the model also has buffers in the output, since\n    buffer values can be changed with operations like batchnorm and should be\n    tracked as part of output.\n    \"\"\"\n\n    def __init__(self, stateless_model, param_names, buffer_names,\n                 param_names_map, buffer_names_map):\n        super().__init__()\n        self.stateless_model = stateless_model\n        self.param_names = param_names\n        self.buffer_names = buffer_names\n\n        self.all_names_map = dict(param_names_map)\n        self.all_names_map.update(buffer_names_map)\n\n    @staticmethod\n    def create_from(model, disable_autograd_tracking=False):\n        # TODO: We don't need to copy the model to create a stateless copy\n        model_copy = copy.deepcopy(model)\n        param_values, param_names, param_names_map = extract_weights(model_copy)\n        buffer_values, buffer_names, buffer_names_map = extract_buffers(\n            model_copy)\n        params = OrderedDict(zip(param_names, param_values))\n        buffers = OrderedDict(zip(buffer_names, buffer_values))\n        if disable_autograd_tracking:\n            for param in param_values:\n                param.requires_grad_(False)\n        return (\n            FunctionalModuleWithBuffersInInputAndOutput(model_copy, param_names,\n                                                        buffer_names,\n                                                        param_names_map,\n                                                        buffer_names_map),\n            params,\n            buffers,\n        )\n\n    def forward(self, params, buffers, *args, **kwargs):\n        
# Temporarily load the state back onto self.stateless_model\n        old_state = _swap_state(self.stateless_model, self.all_names_map,\n                                list(params.values()) + list(buffers.values()))\n        try:\n            return buffers, self.stateless_model(*args, **kwargs)\n        finally:\n            # Remove the loaded state on self.stateless_model\n            _swap_state(self.stateless_model, self.all_names_map, old_state)\n\n\ndef functionalize(module: torch.nn.Module):\n    \"\"\"Returns:\n        - `module_func`: a function that has the same logic as module.forward\n        but is callable with either PT or Alpa inputs. It:\n            - wraps the original inputs in a tuple\n            - takes `params` and `bufs` as extra inputs at the beginning of the\n              input list\n            - produces `bufs` as an extra output at the beginning of the\n              output list\n            - all calls are made compatible with Alpa, e.g.:\n                - replaces all unexpandable module calls (e.g. nn.Conv2d) with\n                  equivalent `torch.*` function calls\n                - replaces all torch.nn.functional calls that have in-place ops\n                  (e.g. F.batch_norm) with equivalent `atorch.*` function calls\n                  that have buffers as part of the output\n                - complex torch function calls (e.g. F.dropout) are decomposed\n                  and implemented with `torch.*` calls\n        - `params`: a dict of shape-only tensors representing the trainable\n           parameters of the module.\n           In PT format if \"local\", in Alpa format if \"dist\".\n        - `bufs`: a dict of shape-only tensors representing the no-gradient\n           parameters of the module.\n           In PT format if \"local\", in Alpa format if \"dist\".\n        - `name_map`: a dict mapping each parameter/buffer name in the original\n           module to its name in the returned `params` / `bufs`.\n    Raises an error if module.forward:\n        - has in-place ops\n        - or, has data-dependent control flow\n        - or, has other graph-breaking statements (e.g. 
`print()`) that\n          prevents the program from being captured as a single graph\n          (only in \"dist\" mode)\n    \"\"\"\n\n    # This param/buffer name map is used for mapping from FQN in original\n    # PyTorch model to FQN in PyTorch FX IR.\n    tensor_to_name_map = {}\n\n    all_tensors_pt_orig = dict(named_parameters(module))\n    all_tensors_pt_orig.update(dict(named_buffers(module)))\n\n    for k, v in all_tensors_pt_orig.items():\n        assert v not in tensor_to_name_map\n        tensor_to_name_map[v] = {\"orig_name\": k}\n\n    def add_transformed_name(tensor_to_name_map, k, v):\n        assert v in tensor_to_name_map\n        assert \"transformed_name\" not in tensor_to_name_map[v]\n        tensor_to_name_map[v][\"transformed_name\"] = k\n\n    if atorch.mode() == \"dist\":\n        # In dist mode, use TorchDynamo to enforce:\n        # 1) no data-dependent control flow\n        # 2) no graph break points\n        # 3) no in-place ops\n\n        def convert_pt_module_to_alpa_func(module):\n            fx_ir = torch.fx.symbolic_trace(module)\n\n            fx_ir = normalize_ir_no_run(fx_ir)\n\n            # NOTE: due to some unknown reason, only the second normalize pass\n            # can convert tensor method to torch function\n            # (e.g. 
`.t()` to `torch.t()`)\n            fx_ir = normalize_ir_no_run(fx_ir)\n\n            m_func_name = \"_alpa_forward_func\"\n            m_func_code = fx_ir_to_alpa_func_code(fx_ir, m_func_name)\n\n            if atorch.debug:\n                print(\"JAX function code: \")\n                print(m_func_code)\n\n            # pylint: disable=exec-used\n            exec(m_func_code)\n            module_func = locals()[m_func_name]\n\n            return fx_ir, module_func\n\n        # NOTE: torch.fx.symbolic_trace doesn't hardcode the batch size\n        # for `.view()` and `.reshape()` ops, so we DON'T need to trace\n        # two graphs (one full-batch, one micro-batch).\n        fx_ir, module_func = convert_pt_module_to_alpa_func(module)\n\n        params_pt = dict(named_parameters(fx_ir))\n        bufs_pt = dict(named_buffers(fx_ir))\n\n        for k, v in params_pt.items():\n            add_transformed_name(tensor_to_name_map, k, v)\n\n        for k, v in bufs_pt.items():\n            add_transformed_name(tensor_to_name_map, k, v)\n\n        for k, v in tensor_to_name_map.items():\n            if \"transformed_name\" not in v:\n                print(v[\"orig_name\"])\n\n        params_alpa = {\n            k: make_shaped_array_from_pt_tensor(v)\n            for k, v in params_pt.items()\n        }\n        bufs_alpa = {\n            k: make_shaped_array_from_pt_tensor(v) for k, v in bufs_pt.items()\n        }\n\n        if atorch.mode() == \"local\":\n            params = params_pt\n            bufs = bufs_pt\n        elif atorch.mode() == \"dist\":\n            params = params_alpa\n            bufs = bufs_alpa\n\n        name_map = {}\n        for elem in tensor_to_name_map.values():\n            try:\n                name_map[elem[\"orig_name\"]] = elem[\"transformed_name\"]\n            except KeyError as e:\n                print(f'elem[\"orig_name\"]: {elem[\"orig_name\"]}')\n                raise e\n    elif atorch.mode() == \"local\":\n        # In local 
mode, use functionalization pass adapted from functorch\n        # TODO: add more rigorous unit tests for this branch\n        module_func, params, bufs = \\\n            FunctionalModuleWithBuffersInInputAndOutput.create_from(module)\n        name_map = {}\n        for elem in tensor_to_name_map.values():\n            name_map[elem[\"orig_name\"]] = elem[\"orig_name\"]\n\n    return module_func, params, bufs, name_map\n\n\ndef meta_init(module_fn: Callable[..., torch.nn.Module], *args, **kwargs):\n    pt_module = torchdistx_deferred_init.deferred_init(module_fn, *args,\n                                                       **kwargs)\n    # pylint: disable=protected-access\n    return pt_module._apply(meta_like)\n"
  },
  {
    "path": "alpa/torch/nn/utils.py",
    "content": "# pylint: skip-file\n\n# All code in this file are extracted from torchdynamo and functorch.\n# Skipping pylint for this file so that it's easy to find out the difference\n# when we need to pull in new changes again.\n\nimport builtins\nimport dataclasses\nimport functools\nimport itertools\nimport math\nimport operator\nfrom typing import List\n\nimport torch\nfrom torch import nn\nfrom torch import Tensor\nfrom torch.fx import Transformer\nfrom torch.fx.experimental.normalize import NormalizeOperators\nfrom torch.fx.operator_schemas import get_signature_for_torch_op\n\n# Copied from torchdynamo/torchdynamo/optimizations/normalize.py\nVIEW_OPS = {\n    # list taken from https://pytorch.org/docs/stable/tensor_view.html\n    \"getitem\",\n    \"as_strided\",\n    \"detach\",\n    \"diagonal\",\n    \"expand\",\n    \"expand_as\",\n    \"movedim\",\n    \"narrow\",\n    \"permute\",\n    \"select\",\n    \"squeeze\",\n    \"transpose\",\n    \"t\",\n    \"T\",\n    \"real\",\n    \"imag\",\n    \"view_as_real\",\n    \"view_as_imag\",\n    \"unflatten\",\n    \"unfold\",\n    \"unsqueeze\",\n    \"view\",\n    \"view_as\",\n    \"unbind\",\n    \"split\",\n    \"split_with_sizes\",\n    \"swapaxes\",\n    \"swapdims\",\n    \"chunk\",\n    \"indices\",\n    \"values\",\n}\nMAYBE_VIEW_OPS = {\"contiguous\", \"reshape\"}\n\n# convert x.foo(...) 
to torch.foo(x, ...)\nNORMALIZE_METHODS = {\n    # These ones aren't normalized:\n    # ('view', 342)\n    # ('reshape', 285)\n    # ('expand', 87)\n    # ('permute', 78)\n    # ('to', 66)\n    # ('contiguous', 62)\n    # ('reshape_as', 57)\n    # ('masked_fill', 30)\n    # ('float', 22) -- could rewrite\n    # ('expand_as', 14) -- could rewrite\n    # ('detach', 4)\n    # ('repeat', 2)\n    # TODO(jansel): debug why this causes issues in detectron2_maskrcnn\n    # \"div\": torch.div,\n    \"add_\": operator.iadd,\n    \"all\": torch.all,\n    \"any\": torch.any,\n    \"ceil\": torch.ceil,\n    \"chunk\": torch.chunk,\n    \"clamp\": torch.clamp,\n    \"clone\": torch.clone,\n    \"exp\": torch.exp,\n    \"flatten\": torch.flatten,\n    \"flip\": torch.flip,\n    \"floor\": torch.floor,\n    \"index_select\": torch.index_select,\n    \"log2\": torch.log2,\n    \"log_softmax\": torch.nn.functional.log_softmax,\n    \"max\": torch.max,\n    \"mean\": torch.mean,\n    \"min\": torch.min,\n    \"mul_\": operator.imul,\n    \"narrow\": torch.narrow,\n    \"ne\": torch.ne,\n    \"nonzero\": torch.nonzero,\n    \"numel\": torch.numel,\n    \"pow\": torch.pow,\n    \"round\": torch.round,\n    \"rsqrt\": torch.rsqrt,\n    \"sigmoid\": torch.sigmoid,\n    \"softmax\": torch.nn.functional.softmax,\n    \"sort\": torch.sort,\n    \"split\": torch.split,\n    \"squeeze\": torch.squeeze,\n    \"std\": torch.std,\n    \"sum\": torch.sum,\n    \"topk\": torch.topk,\n    \"transpose\": torch.transpose,\n    \"tril\": torch.tril,\n    \"t\": torch.t,\n    \"unbind\": torch.unbind,\n    \"unsqueeze\": torch.unsqueeze,\n}\nDONT_EXPAND_MODULES = {\n    # These have internal control flow\n    \"ConvTranspose1d\",\n    \"ConvTranspose2d\",\n    \"Conv2d\",\n    \"ConvReLU2d\",\n    \"ConvBn2d\",\n    \"ConvBnReLU2d\",\n    \"EmbeddingBag\",\n    \"InstanceNorm2d\",\n    \"LSTM\",\n}\n\nF = torch.nn.functional\nINPLACE_KEYWORD_OPS = {\n    F.mish,\n    F.silu,\n    F.hardsigmoid,\n    
F.rrelu,\n    F.leaky_relu,\n    F.celu,\n    F.selu,\n    F.elu,\n    F.relu6,\n    F.hardswish,\n    F.hardtanh,\n    F.relu,\n    F.threshold,\n}\nIOPERATOR_REPLACEMENTS = {\n    \"masked_fill_\": \"masked_fill\",\n    \"scatter_\": \"scatter\",\n    \"unsqueeze_\": \"unsqueeze\",\n    torch.relu_: torch.relu,\n    torch.sigmoid_: torch.sigmoid,\n    operator.iadd: torch.add,\n    operator.iand: torch.bitwise_and,\n    operator.ifloordiv: functools.partial(torch.div, rounding_mode=\"floor\"),\n    operator.itruediv: torch.div,\n    operator.imul: torch.mul,\n    operator.imatmul: torch.matmul,\n    operator.ior: torch.bitwise_or,\n    operator.ipow: torch.pow,\n    operator.isub: torch.sub,\n    operator.ixor: torch.bitwise_xor,\n}\nOPERATOR_REPLACEMENTS = {\n    operator.lt:\n        torch.lt,\n    operator.le:\n        torch.le,\n    operator.eq:\n        torch.eq,\n    operator.ne:\n        torch.ne,\n    operator.ge:\n        torch.ge,\n    operator.gt:\n        torch.gt,\n    operator.abs:\n        torch.abs,\n    operator.add:\n        torch.add,\n    operator.and_:\n        torch.bitwise_and,\n    operator.floordiv:\n        functools.partial(torch.div, rounding_mode=\"floor\"),\n    # operator.truediv: torch.div,  # TODO(jansel): debug issue in vision_maskrcnn\n    operator.inv:\n        torch.bitwise_not,\n    operator.invert:\n        torch.bitwise_not,\n    operator.mod:\n        torch.remainder,\n    operator.mul:\n        torch.mul,\n    operator.matmul:\n        torch.matmul,\n    operator.neg:\n        torch.neg,\n    operator.or_:\n        torch.bitwise_or,\n    operator.pos:\n        torch.positive,\n    operator.pow:\n        torch.pow,\n    operator.sub:\n        torch.sub,\n    operator.xor:\n        torch.bitwise_xor,\n    torch.nn.functional.sigmoid:\n        torch.sigmoid,\n    torch.nn.functional.tanh:\n        torch.tanh,\n    torch.nn.functional.relu:\n        torch.relu,\n}\n\nSKIP_INPLACE = {\n    v for v in 
itertools.chain(math.__dict__.values(), builtins.__dict__.values(\n    ), operator.__dict__.values()) if callable(v)\n}\n\n\ndef always_true(*args, **kwargs):\n    return True\n\n\nclass InliningTracer(torch.fx.Tracer):\n\n    def is_leaf_module(self, m: torch.nn.Module,\n                       module_qualified_name: str) -> bool:\n        return False\n\n\ndef expand_module_call(prefix, graph: torch.fx.Graph, module, args, kwargs):\n    # this patch is needed to make BatchNorm2D FX trace\n    module.__dict__[\"_check_input_dim\"] = always_true\n    try:\n        assert not kwargs\n        arg_index = itertools.count()\n        vars = dict()\n        for node in InliningTracer().trace(module).nodes:\n            if node.op == \"placeholder\":\n                vars[node] = args[next(arg_index)]\n            elif node.op == \"output\":\n                assert len(node.args) == 1\n                return vars[node.args[0]]\n            elif node.op == \"get_attr\":\n                vars[node] = graph.get_attr(f\"{prefix}{node.target}\")\n            else:\n                vars[node] = graph.node_copy(node, vars.__getitem__)\n        assert False\n    except Exception:\n        print(f\"Error while expanding {module.__class__.__name__}\")\n        raise\n    finally:\n        del module.__dict__[\"_check_input_dim\"]\n\n\n@dataclasses.dataclass\nclass NodeCounts:\n    usages: int = 0\n\n\ndef short_name(gm, node: torch.fx.Node):\n    if node.op == \"call_function\":\n        return node.target.__name__\n    elif node.op == \"call_method\":\n        return node.target\n    elif node.op == \"call_module\":\n        return gm.get_submodule(node.target).__class__.__name__\n    elif node.op == \"get_attr\":\n        return node.target\n    elif node.op == \"output\":\n        return \"output\"\n    assert False, node.op\n\n\ndef long_name(gm, node: torch.fx.Node):\n    name = short_name(gm, node)\n    target = node.target\n    if node.op == \"call_function\":\n        return 
torch_get_name(node.target,\n                              f\"{getattr(target, '__module__', '')}.{name}\")\n    elif node.op == \"call_method\":\n        return name\n    elif node.op == \"call_module\":\n        target = gm.get_submodule(target).__class__\n        return f\"{getattr(target, '__module__', '')}.{getattr(target, '__name__', '')}\"\n    elif node.op == \"get_attr\":\n        return name\n    elif node.op == \"output\":\n        return \"output\"\n    assert False\n\n\nclass Inplacifier:\n\n    def __init__(self, gm: torch.fx.GraphModule):\n        self.gm = gm\n\n    def can_be_view(self, node):\n        name = short_name(self.gm, node)\n        return name in VIEW_OPS or name in MAYBE_VIEW_OPS\n\n    def inplacify(self):\n        counts = dict()\n\n        def record_usage(node):\n            counts[node].usages += 1\n            return node\n\n        for node in self.gm.graph.nodes:\n            if node.op in (\"call_function\", \"call_method\", \"call_module\"):\n                if self.can_be_view(node):\n                    # Aliasing\n                    counts[node] = counts[node.args[0]]\n                elif \"out\" in node.kwargs:\n                    counts[node] = counts[node.kwargs[\"out\"]]\n                else:\n                    counts[node] = NodeCounts(0)\n            else:\n                counts[node] = NodeCounts(float(\"inf\"))\n\n        for node in reversed(list(self.gm.graph.nodes)):\n            kwargs = dict(node.kwargs)\n            if \"inplace\" in kwargs:\n                kwargs.pop(\"inplace\")\n            if node.op == \"call_function\" and len(node.args) + len(kwargs) == 1:\n                arg = node.args[0] if node.args else next(kwargs.values())\n                if isinstance(arg, torch.fx.Node) and counts[arg].usages == 0:\n                    if node.target in SKIP_INPLACE:\n                        continue\n                    elif node.target in INPLACE_KEYWORD_OPS:\n                        
kwargs[\"inplace\"] = True\n                        counters[\"optimizations\"][\"inplace\"] += 1\n                    elif \" out: torch.Tensor\" in repr(\n                            get_signature_for_torch_op(node.target)):\n                        kwargs[\"out\"] = arg\n                        counters[\"optimizations\"][\"out\"] += 1\n                    else:\n                        continue\n                    with self.gm.graph.inserting_before(node):\n                        node.replace_all_uses_with(\n                            self.gm.graph.call_function(node.target, node.args,\n                                                        kwargs))\n                    self.gm.graph.erase_node(node)\n\n            torch.fx.map_arg((node.args, node.kwargs), record_usage)\n\n\nclass Functionalization(Transformer):\n    \"\"\"Remove most cases of mutation from a given fx Graph.\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(Functionalization, self).__init__(*args, **kwargs)\n        self.tracer.tensor_attrs = dict()  # TODO(jansel): upstream this fix\n\n    def run_node(self, n: torch.fx.Node):\n\n        patches = []\n        target = n.target\n        args, kwargs = self.fetch_args_kwargs_from_env(n)\n        kwargs = dict(kwargs)\n\n        if (not n.meta[\"is_input_mutation\"] and\n                not n.meta[\"partial_mutation\"] and\n                issubclass(n.meta[\"type\"], torch.Tensor)):\n            if \"inplace\" in n.kwargs:\n                if kwargs[\"inplace\"]:\n                    patches.append(n.args[0])\n                kwargs.pop(\"inplace\")\n            elif \"out\" in n.kwargs:\n                kwargs.pop(\"out\")\n                patches.append(n.kwargs[\"out\"])\n            elif n.target in IOPERATOR_REPLACEMENTS:\n                target = IOPERATOR_REPLACEMENTS[n.target]\n                patches.append(n.args[0])\n            elif n.meta[\"is_mutation\"]:\n                
counters[\"mutation\"][long_name(self.module, n)] += 1\n\n            if target in OPERATOR_REPLACEMENTS and not kwargs:\n                target = OPERATOR_REPLACEMENTS[target]\n\n        if target is builtins.getattr:\n            if args[1] == \"dtype\":\n                return n.args[0].meta[\"dtype\"]\n            elif args[1] == \"device\":\n                return n.args[0].meta[\"device\"]\n            else:\n                counters[\"getattr\"][args[1]] += 1\n\n        if isinstance(target, functools.partial):\n            assert not target.args\n            kwargs.update(target.keywords)\n            target = target.func\n\n        if not issubclass(n.meta[\"type\"], torch.Tensor):\n            counters[\"nontensor\"][long_name(self.module, n)] += 1\n\n        result = getattr(self, n.op)(target, args, kwargs)\n\n        for patch in patches:\n            assert isinstance(\n                patch, torch.fx.Node), f\"{patch} {n.target} {n.args} {n.kwargs}\"\n            if patch in self.env:\n                self.env[patch] = result\n\n        return result\n\n\ndef swap_node(graph, old_node, new_node):\n    old_node.replace_all_uses_with(new_node)\n    graph.erase_node(old_node)\n\n\ndef normalize(gm: torch.fx.GraphModule):\n    # gm.graph.print_tabular()\n    graph: torch.fx.Graph = gm.graph\n\n    for node in list(graph.nodes):\n        with graph.inserting_before(node):\n            if node.op == \"call_method\" and node.target in NORMALIZE_METHODS:\n                swap_node(\n                    graph,\n                    node,\n                    graph.call_function(NORMALIZE_METHODS[node.target],\n                                        node.args, node.kwargs),\n                )\n            elif node.op == \"call_module\":\n                submod = gm.get_submodule(node.target)\n                if submod.__class__.__name__ not in DONT_EXPAND_MODULES:\n                    swap_node(\n                        graph,\n                        node,\n 
                       expand_module_call(f\"{node.target}.\", graph, submod,\n                                           node.args, node.kwargs),\n                    )\n\n    # gm.graph.print_tabular()\n\n\ndef create_names_map(named_params, tied_named_params):\n    \"\"\"named_params is a dictionary of tensors: {'A': A, 'B': B}\n    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}\n    with potentially tied (or 'duplicated') tensors\n\n    This function creates a mapping from the names in named_params to the\n    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.\n    \"\"\"\n    named_params = {k: v for k, v in named_params}\n    tied_named_params = {k: v for k, v in tied_named_params}\n\n    tensors_dict_keys = set(named_params.keys())\n    tied_tensors_dict_keys = set(tied_named_params.keys())\n    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)\n\n    tensor_to_mapping = {}\n    for key, tensor in named_params.items():\n        tensor_to_mapping[tensor] = (key, [])\n    for key, tensor in tied_named_params.items():\n        assert tensor in tensor_to_mapping\n        tensor_to_mapping[tensor][1].append(key.split(\".\"))\n    result = {key: value for key, value in tensor_to_mapping.values()}\n    return result\n\n\ndef _set_nested_attr(obj: nn.Module, names: List[str], value: Tensor) -> None:\n    \"\"\"Set the attribute specified by the given list of names to value.\n    For example, to set the attribute obj.conv.weight,\n    use _set_nested_attr(obj, ['conv', 'weight'], value)\n    \"\"\"\n    if len(names) == 1:\n        setattr(obj, names[0], value)\n    else:\n        _set_nested_attr(getattr(obj, names[0]), names[1:], value)\n\n\ndef _extract_members(mod: nn.Module, _named_members, named_members, subclass):\n    all_named_members = tuple(_named_members(mod, remove_duplicate=False))\n    named_members = tuple(named_members())\n    names_map = create_names_map(named_members, all_named_members)\n\n    
# Remove all the members in the model\n    memo = {}\n    for name, p in all_named_members:\n        if p not in memo:\n            memo[p] = subclass(torch.empty_like(p, device=\"meta\"))\n        replacement = memo[p]\n        _set_nested_attr(mod, name.split(\".\"), replacement)\n\n    if len(named_members) == 0:\n        names, params = (), ()\n    else:\n        names, params = zip(*named_members)\n    return params, names, names_map\n\n\ndef extract_weights(mod: nn.Module):\n    \"\"\"This function removes all the Parameters from the model and\n    returns them as a tuple as well as their original attribute names.\n    The weights must be re-loaded with `load_weights` before the model\n    can be used again.\n    Note that this function modifies the model in place and after this\n    call, mod.parameters() will be empty.\n    \"\"\"\n    return _extract_members(mod, named_parameters, mod.named_parameters,\n                            nn.Parameter)\n\n\ndef extract_buffers(mod: nn.Module):\n    return _extract_members(mod, named_buffers, mod.named_buffers, lambda x: x)\n\n\n# Copied from functorch/functorch/_src/named_members_polyfill.py\ndef named_members(mod,\n                  get_members_fn,\n                  prefix='',\n                  recurse=True,\n                  remove_duplicate=True):\n    \"\"\"Helper method for yielding various names + members of modules.\n    \"\"\"\n    memo = set()\n    modules = mod.named_modules(\n        prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [\n            (prefix, mod)\n        ]\n    for module_prefix, module in modules:\n        members = get_members_fn(module)\n        for k, v in members:\n            if v is None or v in memo:\n                continue\n            if remove_duplicate:\n                memo.add(v)\n            name = module_prefix + ('.' 
if module_prefix else '') + k\n            yield name, v\n\n\ndef named_parameters(mod,\n                     prefix: str = '',\n                     recurse: bool = True,\n                     remove_duplicate: bool = True):\n    return named_members(mod,\n                         lambda module: module._parameters.items(),\n                         prefix=prefix,\n                         recurse=recurse,\n                         remove_duplicate=remove_duplicate)\n\n\ndef named_buffers(mod,\n                  prefix: str = '',\n                  recurse: bool = True,\n                  remove_duplicate: bool = True):\n    return named_members(mod,\n                         lambda module: module._buffers.items(),\n                         prefix=prefix,\n                         recurse=recurse,\n                         remove_duplicate=remove_duplicate)\n"
  },
  {
    "path": "alpa/torch/ops/__init__.py",
    "content": ""
  },
  {
    "path": "alpa/torch/ops/mapping.py",
    "content": "# pylint: disable=line-too-long, unused-argument\n\"\"\"Maps PyTorch ops to JAX ops\"\"\"\nimport contextlib\nimport math\nfrom typing import Any, Optional, Sequence, Callable\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom jax import lax\nimport torch\nfrom alpa.torch.tensor_utils import numpy_to_torch_dtype_dict\n\n\n# Adapted from aten/src/ATen/InferSize.h infer_size_impl()\ndef infer_size(shape, numel):\n    newsize = 1\n    infer_dim = None\n    len(shape)\n    res = list(shape)\n    for dim in range(len(shape)):\n        if shape[dim] == -1:\n            if infer_dim is not None:\n                raise ValueError(\"only one dimension can be inferred\")\n            infer_dim = dim\n        elif shape[dim] >= 0:\n            newsize *= shape[dim]\n        else:\n            raise Exception(f\"invalid shape dimension {shape[dim]}\")\n\n    if (numel == newsize) or (infer_dim is not None and newsize > 0 and\n                              numel % newsize == 0):\n        if infer_dim is not None:\n            # We have a degree of freedom here to select the dimension size;\n            # follow NumPy semantics and just bail.  
However, a nice error\n            # message is needed because users often use `view` as a way to\n            # flatten & unflatten dimensions and will otherwise be confused\n            # why\n            #   empty_tensor.view( 0, 0)\n            # works yet\n            #   empty_tensor.view(-1, 0)\n            # doesn't.\n            assert newsize != 0, (\n                \"cannot reshape tensor of 0 elements into shape \" + str(shape) +\n                \" because the unspecified dimension size -1 can be any \" +\n                \"value and is ambiguous\")\n            res[infer_dim] = numel // newsize\n        return res\n\n    raise Exception(f\"shape {shape} is invalid for input of size {numel}\")\n\n\ndef init_buffer(\n    init_func,\n    init_func_kwargs,\n    local_rng_seed,\n    worker,\n    device_id: int,\n    shape: Sequence[int],\n    dtype: np.dtype,\n):\n\n    torch_local_rng = torch.Generator()\n    torch_local_rng.manual_seed(local_rng_seed)\n    init_func_kwargs[\"rng\"] = torch_local_rng\n    init_func_kwargs[\"shape\"] = shape\n    init_func_kwargs[\"dtype\"] = numpy_to_torch_dtype_dict[dtype]\n\n    return worker.backend.buffer_from_pyval(init_func(**init_func_kwargs),\n                                            worker.local_devices[device_id])\n\n\ndef torch_abs(x):\n    return jnp.absolute(x)\n\n\ndef torch_add(x, other):\n    return jnp.add(x, other)\n\n\ndef torch_addmm(x, mat1, mat2, beta=1, alpha=1):\n    out = alpha * torch.matmul(mat1, mat2)\n    if beta == 0:\n        return out\n    return beta * x + out\n\n\ndef torch_bmm(x, mat2):\n    return lax.batch_matmul(x, mat2)\n\n\ndef torch_cat(tensors, dim=0):\n    return lax.concatenate(tensors, dim)\n\n\ndef torch_clone(x, memory_format=torch.preserve_format):\n    return jnp.array(x, dtype=x.dtype, copy=True, order=\"K\")\n\n\ndef torch_conv2d(x,\n                 weight,\n                 bias=None,\n                 stride=1,\n                 padding=0,\n                 
dilation=1,\n                 groups=1):\n    # References:\n    # - torch-xla impl and haiku / flax impl\n    # - https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb\n    conv_out = lax.conv_general_dilated(\n        x,\n        weight,\n        stride,\n        [(x, x) for x in padding],\n        lhs_dilation=None,\n        rhs_dilation=None,\n        dimension_numbers=lax.conv_dimension_numbers(\n            x.shape,\n            weight.shape,\n            (\"NCHW\", \"OIHW\",\n             \"NCHW\"),  # TODO: parameterize this! don't assume NCHW format.\n        ),\n        feature_group_count=groups,\n        batch_group_count=1,\n    )\n    if bias is not None:\n        bias_reshaped = bias.reshape(1, bias.shape[0], 1, 1)\n        bias_reshaped = jnp.broadcast_to(bias_reshaped, [\n            conv_out.shape[0], bias.shape[0], conv_out.shape[2],\n            conv_out.shape[3]\n        ])\n        return conv_out + bias_reshaped\n    else:\n        return conv_out\n\n\ndef torch_div(x, other, rounding_mode=None):\n    ret = None\n    if rounding_mode is None:\n        ret = jnp.true_divide(x, other)\n    elif rounding_mode == \"trunc\":\n        ret = jnp.trunc(jnp.true_divide(x, other))\n    elif rounding_mode == \"floor\":\n        ret = jnp.floor_divide(x, other)\n    if ret is not None:\n        return ret\n    else:\n        raise NotImplementedError(f\"{rounding_mode} is not supported\")\n\n\ndef torch_dropout(x, p=0.5, training=True, inplace=False):\n    assert not inplace, \"Inplace dropout is not supported\"\n    if p == 0.0:\n        return x\n    if training:\n        # Copied from flax.linen.Dropout impl\n        keep_prob = 1.0 - p\n        # NOTE: pass None for rng, since Alpa ignores it anyway.\n        mask = jax.random.bernoulli(None, p=keep_prob, shape=x.shape)\n        return lax.select(mask, x, jnp.zeros_like(x))\n    else:\n        return x\n\n\ndef torch_exp(x):\n    return jnp.exp(x)\n\n\ndef 
torch_expand(x, sizes):\n    computed_sizes = list(sizes)\n    for dim, size in enumerate(sizes):\n        if size == -1:\n            computed_sizes[dim] = x.shape[dim]\n    return lax.broadcast_in_dim(x, computed_sizes, list(range(len(x.shape))))\n\n\ndef maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):\n    if dim_post_expr <= 0:\n        assert wrap_scalar\n        dim_post_expr = 1\n    min_dim = -dim_post_expr\n    max_dim = dim_post_expr - 1\n    assert not (dim < min_dim or dim > max_dim)\n    if dim < 0:\n        dim += dim_post_expr\n    return dim\n\n\ndef torch_flatten(x, start_dim=0, end_dim=-1):\n    input_shape = x.shape\n    start_dim = maybe_wrap_dim(start_dim, len(input_shape))\n    end_dim = maybe_wrap_dim(end_dim, len(input_shape))\n    assert start_dim <= end_dim\n    if start_dim == end_dim:\n        return x\n    slice_numel = 1\n    for i in range(start_dim, end_dim + 1):\n        slice_numel *= input_shape[i]\n    shape = []\n    for i in range(start_dim):\n        shape.append(input_shape[i])\n    shape.append(slice_numel)\n    for i in range(end_dim + 1, len(input_shape)):\n        shape.append(input_shape[i])\n    return torch_view(x, shape)\n\n\ndef torch_full_like(x,\n                    fill_value,\n                    dtype=None,\n                    layout=torch.strided,\n                    device=None,\n                    requires_grad=False,\n                    memory_format=torch.preserve_format):\n    return jnp.full_like(x, fill_value, dtype=dtype)\n\n\ndef torch_gelu(x, approximate=False):\n    # TODO: use approximate=True or not?\n    return jax.nn.gelu(x)\n\n\ndef torch_layer_norm(x,\n                     normalized_shape,\n                     weight=None,\n                     bias=None,\n                     eps=1e-05,\n                     cudnn_enable=True):\n    # TODO: this formula might be wrong\n    axis = len(x.shape) - len(normalized_shape)\n    mean_val = jnp.mean(x, axis=axis, 
keepdims=True)\n    var = jnp.mean((x - mean_val)**2, axis=axis, keepdims=True)\n    out = (x - mean_val) / jnp.sqrt(var + eps)\n    if weight is not None:\n        out = out * weight\n    if bias is not None:\n        out = out + bias\n    return out\n\n\ndef torch_matmul(x, other):\n    return jnp.matmul(x, other)\n\n\ndef torch_max(x, dim=None, keepdim=False):\n    return jnp.max(x, axis=dim, keepdims=keepdim)\n\n\ndef torch_mean(x, dim=None, keepdim=False):\n    return jnp.mean(x, axis=dim, keepdims=keepdim)\n\n\ndef torch_mm(x, mat2):\n    return jnp.matmul(x, mat2)\n\n\ndef torch_mul(x1, x2):\n    return jnp.multiply(x1, x2)\n\n\ndef torch_permute(x, dims):\n    return jnp.transpose(x, dims)\n\n\ndef torch_pow(x, exponent):\n    return jnp.power(x, exponent)\n\n\ndef torch_relu(x):\n    return jax.nn.relu(x)\n\n\ndef torch_select(x, dim, index):\n    # TODO: likely inefficient. What's the better way?\n    return lax.slice_in_dim(x, index, index + 1, stride=1, axis=dim)[0]\n\n\ndef torch_slice(x, dim, start, end, step=1):\n    if end > x.shape[dim]:\n        end = x.shape[dim]\n    return lax.slice_in_dim(x, start, end, stride=step, axis=dim)\n\n\ndef torch_softmax(x, dim):\n    x_max = jnp.max(x, axis=dim, keepdims=True)\n    unnormalized = jnp.exp(x - x_max)\n    return unnormalized / jnp.sum(unnormalized, axis=dim, keepdims=True)\n\n\ndef torch_split(x, split_size_or_sections, dim=0):\n    if isinstance(split_size_or_sections, int):\n        split_size = split_size_or_sections\n        sections = list(range(split_size, x.shape[dim], split_size))\n    else:\n        assert isinstance(split_size_or_sections, list)\n        sections = split_size_or_sections\n    return jnp.split(x, sections, axis=dim)\n\n\ndef torch_sqrt(x):\n    return jnp.sqrt(x)\n\n\ndef torch_sub(x, other, alpha=1):\n    return x - alpha * other\n\n\ndef torch_sum(x, dim, keepdim=False):\n    return jnp.sum(x, axis=dim, keepdims=keepdim)\n\n\ndef torch_t(x):\n    return 
jnp.transpose(x)\n\n\ndef torch_transpose(x, dim0, dim1):\n    return jnp.swapaxes(x, dim0, dim1)\n\n\ndef torch_unbind(x, dim=0):\n    return tuple(\n        jnp.squeeze(p, axis=dim) for p in jnp.split(x, x.shape[dim], axis=dim))\n\n\ndef torch_view(x, shape):\n    return lax.reshape(x, infer_size(shape, x.size))\n\n\ndef torch_zeros_like(x,\n                     *,\n                     dtype=None,\n                     layout=None,\n                     device=None,\n                     requires_grad=False,\n                     memory_format=torch.preserve_format):\n    return jnp.zeros_like(x, dtype=dtype)\n\n\ndef _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes, eps):\n    stats_shape = list(x.shape)\n    for axis in reduction_axes:\n        stats_shape[axis] = 1\n    mean = mean.reshape(stats_shape)\n    var = var.reshape(stats_shape)\n    feature_shape = [1] * x.ndim\n    for ax in feature_axes:\n        feature_shape[ax] = x.shape[ax]\n    y = x - mean\n    mul = lax.rsqrt(var + eps)\n    if weight is not None:\n        mul *= weight.reshape(feature_shape)\n    y *= mul\n    if bias is not None:\n        y += bias.reshape(feature_shape)\n    return jnp.asarray(y, x.dtype)\n\n\ndef torch_batch_norm(\n    x: torch.Tensor,\n    running_mean: Optional[torch.Tensor],\n    running_var: Optional[torch.Tensor],\n    weight: Optional[torch.Tensor] = None,\n    bias: Optional[torch.Tensor] = None,\n    training: bool = False,\n    momentum: float = 0.1,\n    eps: float = 1e-5,\n):\n    # Ref: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.BatchNorm.html\n    def _abs_sq(x):\n        \"\"\"Computes the elementwise square of the absolute value |x|^2.\"\"\"\n        if jnp.iscomplexobj(x):\n            return lax.square(lax.real(x)) + lax.square(lax.imag(x))\n        else:\n            return lax.square(x)\n\n    def _compute_stats(x,\n                       axes,\n                       axis_name: Optional[str] = None,\n             
          axis_index_groups: Any = None):\n        # promote x to at least float32, this avoids half precision computation\n        # but preserves double or complex floating points\n        x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))\n        mean = jnp.mean(x, axes)\n        mean2 = jnp.mean(_abs_sq(x), axes)\n        if axis_name is not None:\n            concatenated_mean = jnp.concatenate([mean, mean2])\n            mean, mean2 = jnp.split(\n                lax.pmean(concatenated_mean,\n                          axis_name=axis_name,\n                          axis_index_groups=axis_index_groups), 2)\n        # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due\n        # to floating point round-off errors.\n        var = jnp.maximum(0.0, mean2 - _abs_sq(mean))\n        return mean, var\n\n    feature_axes = [1]  # Expect (N, C, ...) shape\n    reduction_axes = tuple(i for i in range(x.ndim) if i not in feature_axes)\n    feature_shape = [x.shape[ax] for ax in feature_axes]\n\n    if not training:\n        mean, var = running_mean, running_var\n    else:\n        running_mean = jnp.zeros(feature_shape, jnp.float32)\n        running_var = jnp.ones(feature_shape, jnp.float32)\n        mean, var = _compute_stats(x, reduction_axes)\n\n        running_mean = momentum * running_mean + (1 - momentum) * mean\n        running_var = momentum * running_var + (1 - momentum) * var\n\n    out = _normalize(x, mean, var, weight, bias, reduction_axes, feature_axes,\n                     eps)\n\n    return out, running_mean, running_var\n\n\ndef torch_nn_functional_batch_norm(\n    x: torch.Tensor,\n    running_mean: Optional[torch.Tensor],\n    running_var: Optional[torch.Tensor],\n    weight: Optional[torch.Tensor] = None,\n    bias: Optional[torch.Tensor] = None,\n    training: bool = False,\n    momentum: float = 0.1,\n    eps: float = 1e-5,\n):\n    return torch_batch_norm(\n        x=x,\n        running_mean=running_mean,\n        
running_var=running_var,\n        weight=weight,\n        bias=bias,\n        training=training,\n        momentum=momentum,\n        eps=eps,\n    )\n\n\ndef torch_nn_functional_dropout(x, p=0.5, training=True, inplace=False):\n    return torch_dropout(x, p=p, training=training, inplace=inplace)\n\n\ndef torch_nn_functional_linear(x, weight, bias=None):\n    output = torch.matmul(x, torch.t(weight))\n    if bias is not None:\n        output = output + bias\n    return output\n\n\ndef torch_nn_functional_mse_loss(\n    x: torch.Tensor,\n    target: torch.Tensor,\n    size_average: Optional[bool] = None,\n    reduce: Optional[bool] = None,\n    reduction: str = \"mean\",\n):\n    # TODO: add handling for `size_average` / `reduce` / `reduction`\n    return jnp.mean((x - target)**2)\n\n\ndef torch_nn_functional_softmax(x, dim):\n    return torch_softmax(x=x, dim=dim)\n\n\ndef _calculate_fan_in_and_fan_out(tensor):\n    dimensions = len(tensor.shape)\n    if dimensions < 2:\n        raise ValueError(\"Fan in and fan out can not be computed \"\n                         \"for tensor with fewer than 2 dimensions\")\n\n    num_input_fmaps = tensor.shape[1]\n    num_output_fmaps = tensor.shape[0]\n    receptive_field_size = 1\n    if len(tensor.shape) > 2:\n        # math.prod is not always available, accumulate the product manually\n        # we could use functools.reduce but that is not supported by TorchScript\n        for s in tensor.shape[2:]:\n            receptive_field_size *= s\n    fan_in = num_input_fmaps * receptive_field_size\n    fan_out = num_output_fmaps * receptive_field_size\n\n    return fan_in, fan_out\n\n\ndef torch_nn_init_xavier_uniform(x, gain: float = 1.0):\n    fan_in, fan_out = _calculate_fan_in_and_fan_out(x)\n    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))\n    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation\n    useless_key = jax.random.PRNGKey(0)\n    return jax.random.uniform(useless_key, x.shape, 
x.dtype, -a, a)\n\n\ndef torch_nn_init_normal(x, mean: float = 0.0, std: float = 1.0):\n    useless_key = jax.random.PRNGKey(0)\n    return (jax.random.normal(useless_key, x.shape, x.dtype) + mean) * std\n\n\n# PyTorch .detach() is equivalent to JAX lax.stop_gradient():\n# - https://github.com/google/jax/issues/2025\n# PyTorch .view() is equivalent to JAX lax.reshape():\n# - https://jax.readthedocs.io/en/latest/_autosummary/lax.reshape.html\n\nop_orig_impl_dict = {}\nop_patch_list = [\n    (torch, \"abs\", torch_abs),\n    (torch, \"add\", torch_add),\n    (torch, \"addmm\", torch_addmm),\n    (torch, \"bmm\", torch_bmm),\n    (torch, \"cat\", torch_cat),\n    (torch, \"clone\", torch_clone),\n    (torch, \"conv2d\", torch_conv2d),\n    (torch, \"div\", torch_div),\n    (torch, \"dropout\", torch_dropout),\n    (torch, \"exp\", torch_exp),\n    (torch, \"expand\", torch_expand),\n    (torch, \"flatten\", torch_flatten),\n    (torch, \"full_like\", torch_full_like),\n    # (torch, \"gelu\", torch_gelu),\n    (torch, \"layer_norm\", torch_layer_norm),\n    (torch, \"matmul\", torch_matmul),\n    (torch, \"max\", torch_max),\n    (torch, \"mean\", torch_mean),\n    (torch, \"mm\", torch_mm),\n    (torch, \"mul\", torch_mul),\n    (torch, \"permute\", torch_permute),\n    (torch, \"pow\", torch_pow),\n    (torch, \"relu\", torch_relu),\n    (torch, \"select\", torch_select),\n    # (torch, \"slice\", torch_slice),\n    (torch, \"softmax\", torch_softmax),\n    (torch, \"split\", torch_split),\n    (torch, \"sqrt\", torch_sqrt),\n    (torch, \"sub\", torch_sub),\n    (torch, \"sum\", torch_sum),\n    (torch, \"t\", torch_t),\n    (torch, \"transpose\", torch_transpose),\n    (torch, \"unbind\", torch_unbind),\n    (torch, \"view\", torch_view),\n    (torch, \"zeros_like\", torch_zeros_like),\n    (torch.nn.functional, \"batch_norm\", torch_nn_functional_batch_norm),\n    (torch.nn.functional, \"dropout\", torch_nn_functional_dropout),\n    (torch.nn.functional, 
\"linear\", torch_nn_functional_linear),\n    (torch.nn.functional, \"mse_loss\", torch_nn_functional_mse_loss),\n    (torch.nn.functional, \"softmax\", torch_nn_functional_softmax),\n    (torch.nn.init, \"xavier_uniform\", torch_nn_init_xavier_uniform),\n    (torch.nn.init, \"normal\", torch_nn_init_normal),\n    # TODO: add hard error for in-place ops\n]\n\n\ndef patch_ops():\n    for python_module, op_name, new_impl in op_patch_list:\n        python_module_fqn = str(python_module).split(\"<module '\")[1].split(\n            \"'\")[0]\n        op_orig_impl_dict[f\"{python_module_fqn}.{op_name}\"] = getattr(\n            python_module, op_name, None)\n        setattr(python_module, op_name, new_impl)\n\n\ndef unpatch_ops():\n    for python_module, op_name, _ in op_patch_list:\n        python_module_fqn = str(python_module).split(\"<module '\")[1].split(\n            \"'\")[0]\n        op_orig_impl = op_orig_impl_dict.get(f\"{python_module_fqn}.{op_name}\",\n                                             None)\n        if op_orig_impl is not None:\n            setattr(python_module, op_name, op_orig_impl)\n        else:\n            delattr(python_module, op_name)\n\n\n@contextlib.contextmanager\ndef bind_ops(enabled=True):\n    \"\"\"Context manager within which many PyTorch ops are monkey-patched\n    to support distributed computation with Alpa.\n    \"\"\"\n    if enabled:\n        patch_ops()\n    try:\n        yield\n    finally:\n        if enabled:\n            unpatch_ops()\n\n\ndef enable_dist_for_func(func: Callable = None):\n    \"\"\"Returns a callable that executes `func` within `bind_ops` context.\n    \"\"\"\n\n    def wrapped_func(*args, **kwargs):\n        with bind_ops():\n            return func(*args, **kwargs)\n\n    return wrapped_func\n"
  },
  {
    "path": "alpa/torch/optim/__init__.py",
    "content": "\"\"\"Optimizers\n\"\"\"\nfrom .adam import adam\n"
  },
  {
    "path": "alpa/torch/optim/adam.py",
    "content": "\"\"\"Adam optimizer\"\"\"\nimport copy\n\nimport torch\n\n\ndef adam(lr=1e-4):\n    \"\"\"torchoptim.adam(**adam_config)(params)\n        Factory that generates functional version of Adam optimizer.\n        Implementation has no in-place op and no data-dependent control flow.\n\n        Returns:\n            - `optim_func`: a function that:\n                - takes (`params`, `optim_state`, `params_grad`) as input\n                - returns (`params`, `optim_state`)\n                  after applying Adam algorithm\n            - `optim_state_init_func`: a function that:\n                - takes `optim_state` as input\n                - returns `optim_state` which is Adam optimizer state\n            - `optim_state`: tracked state (shape-only) of Adam optimizer.\n    \"\"\"\n\n    # TODO FIXME: properly implement Adam optimizer\n\n    def optim_gen(params):\n\n        def optim_func(params, optim_state, params_grad):\n            for k in params:\n                params[k] = params[k] + params_grad[k] * lr\n                optim_state[k] = optim_state[k] + params_grad[k]\n            return params, optim_state\n\n        optim_state = copy.deepcopy(params)\n\n        def optim_state_init_func(optim_state):\n            new_state = {}\n            for k, v in optim_state.items():\n                new_state[k] = torch.full_like(v, 0.0)\n            return new_state\n\n        return optim_func, optim_state_init_func, optim_state\n\n    return optim_gen\n"
  },
  {
    "path": "alpa/torch/tensor_utils.py",
    "content": "\"\"\"Tensor-related utility functions.\n\"\"\"\nfrom typing import Any\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport torch\n\nimport alpa\nimport alpa.torch as atorch\n\n# Copied from torch/testing/_internal/common_utils.py#L349\n# Dict of NumPy dtype -> torch dtype (when the correspondence exists)\nnumpy_to_torch_dtype_dict = {\n    np.dtype(np.bool): torch.bool,\n    np.dtype(np.uint8): torch.uint8,\n    np.dtype(np.int8): torch.int8,\n    np.dtype(np.int16): torch.int16,\n    np.dtype(np.int32): torch.int32,\n    np.dtype(np.int64): torch.int64,\n    np.dtype(np.float16): torch.float16,\n    np.dtype(np.float32): torch.float32,\n    np.dtype(np.float64): torch.float64,\n    np.dtype(np.complex64): torch.complex64,\n    np.dtype(np.complex128): torch.complex128,\n}\n\n# Dict of torch dtype -> NumPy dtype\ntorch_to_numpy_dtype_dict = {\n    value: key for (key, value) in numpy_to_torch_dtype_dict.items()\n}\n\n\ndef make_shaped_array_from_pt_tensor(pt_tensors):\n\n    def transform(pt_tensor):\n        shape = list(pt_tensor.shape)\n        np_dtype = torch_to_numpy_dtype_dict[pt_tensor.dtype]\n        return jax.abstract_arrays.ShapedArray(shape, np_dtype)\n\n    return jax.tree_map(transform, pt_tensors)\n\n\ndef initialize_with_zeros(*args):\n    if atorch.mode() == \"local\":\n        return jax.tree_map(lambda x: torch.zeros(*x.shape, dtype=x.dtype),\n                            args)\n    else:\n        return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), args)\n\n\ndef to_format(target_format: str, inp: Any):\n    \"\"\"Converts inputs to the format specified by `target_format`.\n    Supported formats are \"local\" and \"dist\".\n    \"\"\"\n    assert target_format in [\"local\", \"dist\"]\n    ret = None\n    if isinstance(inp, tuple):\n        ret = tuple(to_format(target_format, x) for x in inp)\n    elif isinstance(inp, list):\n        ret = [to_format(target_format, x) for x in inp]\n    elif isinstance(inp, 
dict):\n        ret = dict(\n            zip(inp.keys(),\n                [to_format(target_format, x) for x in inp.values()]))\n    elif isinstance(inp, torch.Tensor):\n        if target_format == \"dist\":\n            if str(inp.device) == \"meta\":\n                ret = make_shaped_array_from_pt_tensor(inp)\n            elif str(inp.device) == \"cpu\":\n                ret = inp.numpy()\n            else:\n                # TODO: add support for CUDA input tensor\n                raise NotImplementedError(\n                    f\"PyTorch tensor of device {type(inp.device)} \"\n                    \"is not supported yet.\")\n        elif target_format == \"local\":\n            ret = inp\n    elif isinstance(inp, alpa.device_mesh.DistributedArray):\n        if target_format == \"local\":\n            ret = torch.from_numpy(np.array(inp))\n        elif target_format == \"dist\":\n            ret = inp\n    if ret is not None:\n        return ret\n    else:\n        raise NotImplementedError(\n            f\"Value of type {type(inp)} is not supported yet.\")\n\n\ndef assert_format(target_format: str, *inputs):\n    \"\"\"Asserts inputs are in the format specified by `target_format`.\n    Supported formats are \"local\" and \"dist\".\n    \"\"\"\n    assert target_format in [\"local\", \"dist\"]\n    for inp in inputs:\n        if isinstance(inp, (tuple, list)):\n            assert_format(target_format, *inp)\n        elif isinstance(inp, dict):\n            assert_format(target_format, *inp.values())\n        else:\n            assert (\n                isinstance(inp, torch.Tensor) and target_format == \"local\"\n            ) or (\n                isinstance(inp,\n                           (alpa.device_mesh.DistributedArray,\n                            alpa.device_mesh.ReplicatedDistributedArray)) and\n                target_format == \"dist\"\n            ), f\"This input is not of {target_format} format: {inp}, \" + \\\n            \"of type {type(inp)}\"\n"
  },
  {
    "path": "alpa/torch/trainer.py",
    "content": "# pylint: disable=line-too-long, pointless-string-statement, cell-var-from-loop\n\"\"\"Example trainer that runs an SGD training loop\"\"\"\nfrom collections import namedtuple\n\nimport alpa\nimport alpa.torch as atorch\n\"\"\"\nFAQ: When to use atorch vs. torch?\n\nAnswer:\n- All `atorch` usage is contained within the trainer code (i.e. this file),\nno `atorch` mentions in user code (e.g. test_torch_simple.py).\n- No `torch` usage in trainer code. e.g. PyTorch dataloader will be\nencapsulated in alpa.torch dataloader (TBD), where we will add features\nrelated to dist dataloading.\n\"\"\"\n\n# A tuple to wrap all training states.\nTrainState = namedtuple(\"TrainState\", [\"params\", \"bufs\", \"optim_state\"])\n\n\ndef train_torch_module(pt_module_gen, weight_init_func, dataloader, loss_func,\n                       optim_gen, parallel_method):\n    for mode in [\"local\", \"dist\"]:\n        # \"local\": pure PT eager mode on a single GPU,\n        #     allows print in middle of graph, no dist training\n        # \"dist\": graph mode by lowering PT program to JAX,\n        #     doesn't allow print, supports dist training\n        # NOTE: as we see below, the two modes can share most of the code.\n        atorch.set_mode(mode)\n\n        # Prints verbose log for debugging.\n        atorch.debug = True\n\n        if atorch.mode() == \"dist\":\n            alpa.init(cluster=\"ray\")\n\n        # Functionalize the PyTorch model and optimizer\n        pt_module = atorch.meta_init(pt_module_gen)\n        module_func, params_aval, bufs_aval, name_map = atorch.functionalize(\n            pt_module)\n        optim_func, optim_state_init_func, optim_state_aval = optim_gen(\n            params_aval)\n\n        # Define one gradient descent step\n        def train_step(state, batch):\n            inputs, targets = batch\n\n            # wrap forward pass + loss computation in a function\n            def compute_loss(params, bufs, inputs, targets):\n          
      # do forward pass\n                bufs, out = module_func(params, bufs, inputs)\n\n                # do loss computation\n                loss_value = loss_func(out, targets)\n                return loss_value, bufs\n\n            # do model forward + backward pass\n            (loss_value, bufs), params_grad = atorch.value_and_grad(\n                compute_loss, has_aux=True)(state.params, state.bufs, inputs,\n                                            targets)\n\n            # do optimizer step\n            params, optim_state = optim_func(state.params, state.optim_state,\n                                             params_grad)\n\n            return TrainState(params, bufs, optim_state), loss_value\n\n        # Define the state initialization function\n        def create_train_state():\n            params, bufs, optim_state = atorch.initialize_with_zeros(\n                params_aval, bufs_aval, optim_state_aval)\n            params, bufs = weight_init_func(pt_module, name_map, params, bufs)\n            optim_state = optim_state_init_func(optim_state)\n            return TrainState(params, bufs, optim_state)\n\n        # Parallelize train function and state initialization function\n        if atorch.mode() == \"dist\":\n            train_step = alpa.parallelize(\n                atorch.enable_dist_for_func(train_step),\n                method=parallel_method,\n                # NOTE: preserves mem addr and sharding spec for the first argument\n                donate_argnums=(0,),\n                # NOTE: the second argument is input batch\n                batch_argnums=(1,),\n                static_argnums=(),\n            )\n\n            # Assume we have a dataloader that supports `peek` function\n            # (i.e. 
look at next batch but don't advance the pointer)\n            pt_batch = dataloader[0]  # dataloader.peek()\n            pt_batch = atorch.make_shaped_array_from_pt_tensor(pt_batch)\n\n            create_train_state = alpa.parallelize(\n                atorch.enable_dist_for_func(create_train_state),\n                method=alpa.CreateStateParallel(train_step, pt_batch))\n\n        # Initialize weights and optimizer states\n        state = create_train_state()\n\n        # Run training loops\n        for i, pt_batch in enumerate(dataloader):\n            pt_batch = atorch.to_format(atorch.mode(), pt_batch)\n            state, loss_value = train_step(state, pt_batch)\n\n            # do whatever with the loss value, e.g. plot it on a graph\n            print(f\"Iter: {i}, Loss: {float(loss_value):.6f}\")\n\n        if atorch.mode() == \"dist\":\n            alpa.shutdown()\n"
  },
  {
    "path": "alpa/util.py",
    "content": "# pylint: disable=consider-using-enumerate\n\"\"\"Common utilities.\"\"\"\nimport functools\nimport itertools as it\nimport logging\nimport os\nimport subprocess\nimport re\nimport socket\nimport time\nfrom collections import OrderedDict\nfrom functools import partial, partialmethod\nimport threading\nfrom typing import Iterable, Dict, Sequence, Any, List\nfrom warnings import warn\n\nfrom flax.training import train_state\nfrom flax.training.common_utils import stack_forest\nimport jax\nfrom jax._src.source_info_util import SourceInfo\nimport jax.numpy as jnp\nfrom jax._src import dispatch, util\nfrom jax._src.api import FLAGS, ShapeDtypeStruct\nfrom jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe\nfrom jax.api_util import shaped_abstractify\nfrom jax import core\nfrom jax.core import (Atom, ClosedJaxpr, DropVar, Jaxpr, JaxprEqn, Literal,\n                      Primitive, ShapedArray, Var, AbstractValue, gensym)\nfrom jax.experimental.maps import FrozenDict\nfrom jax import linear_util as lu\nfrom jax.interpreters import partial_eval as pe\nfrom jax.interpreters import xla, pxla, mlir\nfrom jax.interpreters.xla import _DeviceArray\nfrom jax.tree_util import tree_map, tree_flatten, PyTreeDef\nimport numpy as np\nimport ray\nfrom ray.util.placement_group import get_current_placement_group,\\\n    PlacementGroup\nimport tqdm\n\nfrom alpa import device_mesh\nfrom alpa.global_env import global_config, is_worker\nfrom alpa.monkey_patch import (restore_random, monkey_patch_random,\n                               rng_primitives)\nfrom alpa.wrapped_hlo import HloStatus, WrappedHlo\n\nPLACEMENT_GROUP_TIMEOUT_S_ENV = \"ALPA_PLACEMENT_GROUP_TIMEOUT_S_ENV\"\n\n########################################\n##### Alpa API Utilities\n########################################\n\nlogger = logging.getLogger(__name__)\n\n\ndef freeze_dict(pytree: PyTreeDef):\n    \"\"\"Convert a pytree to a FrozenDict.\"\"\"\n\n    def is_leaf(x):\n        return 
isinstance(x, dict)\n\n    def freeze(x):\n        if isinstance(x, dict):\n            return FrozenDict(x)\n        return x\n\n    return tree_map(freeze, pytree, is_leaf)\n\n\ndef auto_static_argnums(args: Sequence[Any]):\n    \"\"\"Return the indices of static arguments according to heuristic rules.\"\"\"\n\n    def is_static_arg(arg):\n        if isinstance(arg, (bool, int, float, str)):\n            return True\n\n        if isinstance(arg, train_state.TrainState):\n            return False\n\n        xs, _ = tree_flatten(arg)\n        for x in xs:\n            try:\n                x = shaped_abstractify(x)\n            except TypeError:\n                return True\n        return False\n\n    return tuple(i for i in range(len(args)) if is_static_arg(args[i]))\n\n\ndef auto_donate_argnums(args: Sequence[Any]):\n    \"\"\"Return the indices of donated arguments according to heuristic rules.\"\"\"\n\n    def should_donate(x):\n        # Always donate optimizer\n        if isinstance(x, train_state.TrainState):\n            return True\n        return False\n\n    return tuple(i for i in range(len(args)) if should_donate(args[i]))\n\n\ndef abstractify_with_aval(x):\n    if isinstance(x, ShapedArray):\n        return x\n    elif isinstance(x, ShapeDtypeStruct):\n        return ShapedArray(x.shape, x.dtype, named_shape=x.named_shape)\n    else:\n        return xla.abstractify(x)\n\n\ndef update_jax_platform(platform):\n    \"\"\"Update the jax backend platform.\"\"\"\n    jax.config.update(\"jax_platform_name\", platform)\n    xb.get_backend.cache_clear()\n\n\nclass GradFuncTransformContext:\n    \"\"\"\n    A context to hold transformations applied to the forward function\n    before calling alpa.grad or alpa.value_and_grad.\n    \"\"\"\n    transforms = []\n\n    def __init__(self, transform):\n        self.transform = transform\n\n    def __enter__(self):\n        GradFuncTransformContext.transforms.append(self.transform)\n\n    def __exit__(self, exc_type, 
exc_value, exc_traceback):\n        GradFuncTransformContext.transforms.pop()\n\n\n########################################\n##### Data Structure Utilities\n########################################\n\n\ndef to_int_tuple(array: np.ndarray):\n    \"\"\"Convert a numpy array to int tuple.\"\"\"\n    if array is None:\n        return tuple()\n    return tuple(int(x) for x in array)\n\n\ndef check_arithmetic_sequence(array: np.ndarray):\n    \"\"\"Check the input 1-D array is an arithmetic sequence. Return\n    the delta if True and None otherwise.\"\"\"\n    if len(array) < 2:\n        return None\n    delta = array[1] - array[0]\n    for i in range(2, len(array)):\n        if array[i] - array[i - 1] != delta:\n            return None\n    return delta\n\n\nclass OrderedSet:\n    \"\"\"An ordered set implemented by using the built-in OrderedDict.\"\"\"\n\n    def __init__(self, iterable=()):\n        self.dict = OrderedDict()\n        self.dict.update({x: None for x in iterable})\n\n    def add(self, *args):\n        self.dict.update({x: None for x in args})\n\n    def update(self, other):\n        self.dict.update({x: None for x in other})\n\n    def union(self, other):\n        result = OrderedSet(self)\n        result.update(other)\n        return result\n\n    def intersection_update(self, other):\n        for x in [x for x in self.dict if x not in other]:\n            del self.dict[x]\n\n    def intersection(self, other):\n        return OrderedSet(x for x in self if x in other)\n\n    def discard(self, element):\n        if element in self:\n            del self.dict[element]\n\n    def remove(self, element):\n        if element not in self:\n            raise KeyError(element)\n        del self.dict[element]\n\n    def clear(self):\n        self.dict.clear()\n\n    def difference(self, other):\n        return OrderedSet([x for x in self if x not in other])\n\n    def difference_update(self, other):\n        for x in other:\n            self.discard(x)\n\n    def 
symmetric_difference(self, other):\n        result = OrderedSet()\n        for x in self:\n            if x not in other:\n                result.add(x)\n        for x in other:\n            if x not in self:\n                result.add(x)\n        return result\n\n    def __iter__(self):\n        return iter(self.dict)\n\n    def __len__(self):\n        return len(self.dict)\n\n    def __contains__(self, element):\n        return element in self.dict\n\n    def __repr__(self):\n        return \"OrderedSet([\" + \", \".join(repr(x) for x in self) + \"])\"\n\n    def __or__(self, other):\n        return self.union(other)\n\n    def __and__(self, other):\n        return self.intersection(other)\n\n    def __sub__(self, other):\n        return self.difference(other)\n\n    def __xor__(self, other):\n        return self.symmetric_difference(other)\n\n    def __ior__(self, other):\n        self.update(other)\n\n    def __iand__(self, other):\n        self.intersection_update(other)\n\n    def __isub__(self, other):\n        self.difference_update(other)\n\n    def __eq__(self, other):\n        if isinstance(other, OrderedSet):\n            return self.dict == other.dict\n        return False\n\n    @classmethod\n    def __class_getitem__(cls, item):\n        return f\"{cls.__name__}[{item.__name__}]\"\n\n\nclass DisjointDict:\n    \"\"\"A dictionary for recursive lookup.\n    Path compression is used to avoid exceeding the maximum recursion depth.\"\"\"\n\n    def __init__(self):\n        self.values = {}\n\n    def update(self, keys, values):\n        if not isinstance(keys, Iterable):\n            assert not isinstance(values, Iterable)\n            self.values[keys] = values\n            return\n        for key, value in zip(keys, values):\n            self.values[key] = value\n\n    def recursive_lookup(self, key):\n        lookup_queue = [key]\n        value = None\n        while len(lookup_queue) > 0:\n            k = lookup_queue.pop()\n            if value is not 
None:\n                self.values[k] = value\n                continue\n            if k not in self.values:\n                value = k\n                continue\n            lookup_queue.append(k)\n            lookup_queue.append(self.values[k])\n        return value\n\n    def keys(self):\n        return list(self.values.keys())\n\n\ndef cached_property(fn, *args, **kwargs):\n    \"\"\"\n    Decorator to make a function a \"cached property\".\n\n    This means that it is a property whose return value is cached after the\n    first time it is called.\n\n    Args:\n        fn: The function to be made a cached property\n        *args: Any args for the function\n        **kwargs: Any kwargs for the function\n    Returns:\n        function\n    \"\"\"\n    return property(functools.lru_cache()(fn, *args, **kwargs))\n\n\n########################################\n##### XLA API Utilities\n########################################\n\n\ndef get_compile_options(num_replicas: int,\n                        num_partitions: int,\n                        device_assignment: np.ndarray,\n                        use_spmd_partitioning: bool,\n                        parameter_is_tupled_arguments: int,\n                        build_random_seed: int,\n                        spmd_propagation_to_outputs: bool = False):\n    \"\"\"Return CompileOptions for XLA compilation.\"\"\"\n    compile_options = xb.get_compile_options(\n        num_replicas=num_replicas,\n        num_partitions=num_partitions,\n        device_assignment=device_assignment,\n        use_spmd_partitioning=use_spmd_partitioning,\n    )\n    compile_options.parameter_is_tupled_arguments = (\n        parameter_is_tupled_arguments)\n    build_options = compile_options.executable_build_options\n    build_options.seed = build_random_seed\n    build_options.allow_spmd_sharding_propagation_to_output =\\\n        spmd_propagation_to_outputs\n    return compile_options\n\n\ndef jaxpr_to_hlo(name: str,\n                 
closed_jaxpr: ClosedJaxpr,\n                 donated_invars: Sequence[bool],\n                 platform: str = \"cuda\"):\n    \"\"\"Convert a jaxpr to a wrapped XLA HloModule.\n\n    Reference code: jax/jax/_src/dispatch.py::lower_xla_callable\n    \"\"\"\n    consts = closed_jaxpr.consts\n    map(dispatch.prefetch,\n        it.chain(consts, dispatch.jaxpr_literals(closed_jaxpr.jaxpr)))\n\n    # Convert jaxpr to XLA HLO\n    tuple_args = False\n    axis_env = xla.AxisEnv(nreps=1, names=(), sizes=())\n    name_stack = util.new_name_stack(xla.wrap_name(name, \"parallelize\"))\n    closed_jaxpr = ClosedJaxpr(closed_jaxpr.jaxpr, consts)\n    unordered_effects = [\n        eff for eff in closed_jaxpr.effects if eff not in core.ordered_effects\n    ]\n    ordered_effects = [\n        eff for eff in closed_jaxpr.effects if eff in core.ordered_effects\n    ]\n    lowering_result = mlir.lower_jaxpr_to_module(\n        name, closed_jaxpr, unordered_effects, ordered_effects, None, platform,\n        mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)\n    xla_computation = xe.mlir.mlir_module_to_xla_computation(\n        mlir.module_to_string(lowering_result.module),\n        use_tuple_args=tuple_args,\n        return_tuple=True)\n    return WrappedHlo(xla_computation)\n\n\ndef setup_computation_alias(hlo: WrappedHlo, donated_invars: Sequence[bool]):\n    \"\"\"Set input/output alias in xla computation.\n\n    Assume the tensors in output tuple strictly match the donated parameters.\n    \"\"\"\n    program_shape = hlo.program_shape()\n    parameter_shapes = program_shape.parameter_shapes()\n    result_shapes = program_shape.result_shape().tuple_shapes()\n\n    assert len(parameter_shapes) == len(donated_invars), (\n        \"Zhuohan: This error might be caused by an error in \"\n        \"XLA stage slicing.\")\n\n    p_in = 0\n    p_out = 0\n    while p_in < len(parameter_shapes) and p_out < len(result_shapes):\n        if donated_invars[p_in]:\n            if 
parameter_shapes[p_in] == result_shapes[p_out]:\n                hlo.get_module().setup_alias((p_out,), p_in, ())\n                p_in += 1\n                p_out += 1\n            else:\n                p_out += 1\n        else:\n            p_in += 1\n\n    while p_in < len(parameter_shapes):\n        if donated_invars[p_in]:\n            warn(\"Some vars are not donated\")\n        p_in += 1\n\n\ndef count_communication_primitives(hlo_ir: str,\n                                   ignore_scalar_all_reduce: bool = False):\n    \"\"\"Count the communication primitives in a HLO IR.\"\"\"\n    total = hlo_ir.count(\"channel_id\")\n    all_reduce = hlo_ir.count(\"all-reduce(\") + hlo_ir.count(\"all-reduce-start(\")\n    all_gather = hlo_ir.count(\"all-gather(\") + hlo_ir.count(\"all-gather-start(\")\n    reduce_scatter = hlo_ir.count(\"reduce-scatter(\") + hlo_ir.count(\n        \"reduce-scatter-start(\")\n    all_to_all = hlo_ir.count(\"all-to-all(\") + hlo_ir.count(\"all-to-all-start(\")\n\n    if ignore_scalar_all_reduce:\n        # Ignore allreduce of scalar values\n        scalar_all_reduce = 0\n        scalar_all_reduce += hlo_ir.count(\"all-reduce(f32[]\")\n        scalar_all_reduce += hlo_ir.count(\"all-reduce-start(f32[]\")\n        scalar_all_reduce += hlo_ir.count(\"all-reduce(f16[]\")\n        scalar_all_reduce += hlo_ir.count(\"all-reduce-start(f16[]\")\n        total -= scalar_all_reduce\n        all_reduce -= scalar_all_reduce\n\n    return total, all_reduce, all_gather, reduce_scatter, all_to_all\n\n\ndef compile_dummy_zero_constant():\n    \"\"\"Compile an Hlo module that returns a constant zero.\"\"\"\n    c = xc.XlaBuilder(\"dummy_zero_constant\")\n    sharding = xc.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    c.set_sharding(sharding)\n    zero = xc.ops.Constant(c, np.array(0, dtype=np.dtype(np.int32)))\n    c.clear_sharding()\n    c = c.build(xc.ops.Tuple(c, [zero]))\n    return WrappedHlo(c, 
HloStatus.SHARDING_ANNOTATED)\n\n\ndef compile_allocate_zero_buffers(backend, num_devices: int,\n                                  shapes: Sequence[Sequence[int]],\n                                  dtypes: Sequence[jnp.dtype]):\n    \"\"\"Compile an XLA executable that returns zero buffers with given shape and\n    dtypes.\"\"\"\n    c = xc.XlaBuilder(\"allocate_zero_buffers\")\n    sharding = xc.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    c.set_sharding(sharding)\n    ret = []\n    for shape, dtype in zip(shapes, dtypes):\n        if dtype == \"V2\":\n            dtype = jnp.bfloat16\n\n        zero = xc.ops.Constant(c, jnp.array(0, dtype=dtype))\n        zero = xc.ops.Broadcast(zero, shape)\n        ret.append(zero)\n    c.clear_sharding()\n    c = c.build(xc.ops.Tuple(c, ret))\n\n    compile_options = xb.get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n    )\n    with XlaPassContext({\n            \"done-event::enable\": global_config.enable_overlapping,\n    }):\n        compiled = backend.compile(c, compile_options)\n    return compiled\n\n\ndef compile_concatenate(mesh_shape, sharding_spec, batch_size, batch_dim, aval):\n    \"\"\"\n    Compile an XLA executable that concatenates values over the batch dimension,\n    keeping the sharding spec unchanged.\n    \"\"\"\n    c = xc.XlaBuilder(\"concatenate buffers\")\n    sharding = pxla.sharding_spec_sharding_proto(sharding_spec)\n    c.set_sharding(sharding)\n    operands = []\n    for batch_idx in range(batch_size):\n        operands.append(\n            xc.ops.Parameter(\n                c, batch_idx,\n                xc.shape_from_pyval(np.ones(aval.shape, aval.dtype))))\n    concated = xc.ops.ConcatInDim(c, operands, batch_dim)\n    hlo_module = c.build(concated).as_hlo_module()\n\n    num_devices = np.prod(mesh_shape)\n    build_random_seed = 
global_config.compile_random_seed\n    compile_options = get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n        parameter_is_tupled_arguments=False,\n        build_random_seed=build_random_seed)\n    xe.run_spmd_partitioner(hlo_module, compile_options)\n    return WrappedHlo(hlo_module, HloStatus.SPMD_PARTITIONED)\n\n\ndef compile_allgather(shape, dtype, src_spec, dst_spec, num_devices):\n    \"\"\"\n    Compile an XLA executable that runs allgather to reshard the tensor from src\n    sharding spec to dst sharding spec.\n    \"\"\"\n    c = xc.XlaBuilder(\"allgather\")\n    src_sharding = pxla.sharding_spec_sharding_proto(src_spec)\n    c.set_sharding(src_sharding)\n    operand = xc.ops.Parameter(c, 0, xc.shape_from_pyval(np.ones(shape, dtype)))\n    c.clear_sharding()\n\n    dst_sharding = xc.OpSharding()\n    dst_sharding.type = dst_sharding.type.TUPLE\n    dst_sharding.tuple_shardings = [pxla.sharding_spec_sharding_proto(dst_spec)]\n\n    c.set_sharding(dst_sharding)\n    hlo_module = c.build(xc.ops.Tuple(c, [operand])).as_hlo_module()\n\n    build_random_seed = global_config.compile_random_seed\n    compile_options = get_compile_options(\n        num_replicas=1,\n        num_partitions=num_devices,\n        device_assignment=np.arange(num_devices).reshape((1, -1)),\n        use_spmd_partitioning=True,\n        parameter_is_tupled_arguments=False,\n        build_random_seed=build_random_seed)\n    xe.run_spmd_partitioner(hlo_module, compile_options)\n    return WrappedHlo(hlo_module, HloStatus.SPMD_PARTITIONED)\n\n\ndef get_index_select_computation(sharding_specs, dim, avals, index_shape):\n    \"\"\"Compile an XLA executable that runs index select for each tensor.\"\"\"\n    c = xc.XlaBuilder(\"index_select\")\n    shardings = []\n    selected = []\n    index = xc.ops.Parameter(c, len(avals), index_shape)\n    for i, aval 
in enumerate(avals):\n        sharding_spec = sharding_specs[i]\n        sharding = pxla.sharding_spec_sharding_proto(sharding_spec)\n        c.set_sharding(sharding)\n        operand = xc.ops.Parameter(\n            c, i, xc.shape_from_pyval(np.ones(aval.shape, aval.dtype)))\n        c.clear_sharding()\n        index_selected = xc.ops.IndexSelect(operand, index, dim)\n        shardings.append(sharding)\n        selected.append(index_selected)\n    sharding2 = xc.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = shardings\n    c.set_sharding(sharding2)\n    c = c.build(xc.ops.Tuple(c, selected))\n    return WrappedHlo(c, HloStatus.SHARDING_ANNOTATED)\n\n\ndef get_shard_shape(aval: ShapedArray, sharding_spec: pxla.ShardingSpec):\n    \"\"\"Return the shape of a shard.\"\"\"\n    shape = []\n    for dim, spec_dim in zip(aval.shape, sharding_spec.sharding):\n        if isinstance(spec_dim, pxla.NoSharding):\n            shape.append(dim)\n        elif isinstance(spec_dim, pxla.Chunked):\n            shape.append(dim // np.prod(spec_dim.chunks))\n        elif isinstance(spec_dim, pxla.Unstacked):\n            shape.append(spec_dim.size)\n    return tuple(shape)\n\n\ndef get_microbatch_sharding_spec(spec: pxla.ShardingSpec, batch_dim,\n                                 num_micro_batch):\n    batch_dim_chunks = [num_micro_batch]\n    if isinstance(spec.sharding[batch_dim], pxla.Chunked):\n        batch_dim_chunks.extend(spec.sharding[batch_dim].chunks)\n    batch_dim_axis = 0\n    for sharding in spec.sharding[:batch_dim]:\n        if isinstance(sharding, pxla.Chunked):\n            batch_dim_axis += 1\n\n    new_sharding = list(spec.sharding)\n    new_sharding[batch_dim] = pxla.Chunked(batch_dim_chunks)\n\n    new_mapping = []\n    for mapping in spec.mesh_mapping:\n        if isinstance(mapping, pxla.Replicated):\n            new_mapping.append(mapping)\n            continue\n        assert isinstance(mapping, pxla.ShardedAxis)\n    
    new_axis = mapping.axis\n        if mapping.axis >= batch_dim_axis:\n            new_axis += 1\n        new_mapping.append(pxla.ShardedAxis(new_axis))\n    new_mapping.append(pxla.ShardedAxis(batch_dim_axis))\n\n    return pxla.ShardingSpec(sharding=tuple(new_sharding),\n                             mesh_mapping=tuple(new_mapping))\n\n\nclass XlaPassContext:\n    \"\"\"A global context for passing arguments from python to XLA c++ passes.\"\"\"\n\n    current = None\n\n    def __init__(self, value_dict):\n        self.value_dict = value_dict\n\n    def __enter__(self):\n        assert XlaPassContext.current is None, (\"Do not support nested context\")\n        XlaPassContext.current = self\n        xe.set_pass_context(self.value_dict)\n\n    def __exit__(self, exc_type, exc_value, exc_traceback):\n        XlaPassContext.current = None\n        xe.clear_pass_context()\n\n\ndef undefined_sharding_spec_proto():\n    \"\"\"Return a proto of ShardingSpec which represents an undefined spec.\"\"\"\n    # We reuse \"Manual\" to represent \"Undefined\"\n    proto = xc.OpSharding()\n    proto.type = xc.OpSharding.Type.MANUAL\n    return proto\n\n\ndef replicated_sharding_spec_proto():\n    \"\"\"Return a proto of ShardingSpec which represents a replicated spec.\"\"\"\n    proto = xc.OpSharding()\n    proto.type = xc.OpSharding.Type.REPLICATED\n    return proto\n\n\n########################################\n##### Jaxpr Utilities\n########################################\ndef clone_jaxpr(closed_jaxpr: ClosedJaxpr,\n                invars: Sequence[Atom] = None,\n                outvars: Sequence[Var] = None,\n                eqns: Sequence[JaxprEqn] = None,\n                constvars: Sequence[Var] = None,\n                consts: Sequence = None):\n    \"\"\"Clone a jaxpr and replace members if they are provided.\"\"\"\n    constvars = closed_jaxpr.jaxpr.constvars if constvars is None else constvars\n    invars = closed_jaxpr.jaxpr.invars if invars is None else invars\n    
outvars = closed_jaxpr.jaxpr.outvars if outvars is None else outvars\n    eqns = closed_jaxpr.jaxpr.eqns if eqns is None else eqns\n    consts = closed_jaxpr.consts if consts is None else consts\n    jaxpr = Jaxpr(constvars, invars, outvars, eqns)\n    return ClosedJaxpr(jaxpr, consts)\n\n\ndef new_jaxpr_eqn(invars,\n                  outvars,\n                  primitive,\n                  params,\n                  effects=None,\n                  source_info=None):\n    \"\"\"Create a new jaxpr equation.\"\"\"\n    effects = effects or core.no_effects\n    return core.new_jaxpr_eqn(invars, outvars, primitive, params, effects,\n                              source_info)\n\n\ndef clone_jaxpr_eqn(eqn: JaxprEqn,\n                    invars: Sequence[Atom] = None,\n                    outvars: Sequence[Var] = None,\n                    primitive: Primitive = None,\n                    params: Dict[str, Any] = None,\n                    effects: Any = None,\n                    source_info: SourceInfo = None):\n    invars = list(invars or eqn.invars)\n    outvars = list(outvars or eqn.outvars)\n    primitive = primitive or eqn.primitive\n    params = dict(params or eqn.params)\n    source_info = source_info or eqn.source_info\n    effects = effects or eqn.effects\n    return new_jaxpr_eqn(invars, outvars, primitive, params, effects,\n                         source_info)\n\n\ndef process_remat(closed_jaxpr: ClosedJaxpr):\n    \"\"\"Offload remat call from forward to backward.\n\n    remat in Jax generates some remat_call in the forward part, but these\n    remat_call only outputs constant and does not rely on inputs.\n    Hence, offloading them into the backward part does not enlong any liveness\n    interval, while helps reduce forward output size.\n\n    As Alpa monkey patches random number generation to stateful version,\n    this function also gets the generated rng state and set it an input\n    of the offloaded remat part.\n\n    Args:\n        closed_jaxpr: 
the original jaxpr.\n\n    Returns:\n        new_jaxpr: the processed jaxpr\n    \"\"\"\n    # pylint: disable=import-outside-toplevel\n    from alpa.pipeline_parallel.primitive_def import pipeline_p\n\n    def only_create_consts(jaxpr: Jaxpr):\n        const_vars = OrderedSet()\n        for eqn in jaxpr.eqns:\n            for var in eqn.invars:\n                if isinstance(var, Var) and var not in const_vars:\n                    return False\n            const_vars.update(\n                [v for v in eqn.outvars if not isinstance(v, DropVar)])\n        return True\n\n    def only_input_consts(eqn: JaxprEqn):\n        in_bytes = 0\n        for var in eqn.invars:\n            if not isinstance(var, Var):\n                continue\n            if isinstance(var, DropVar):\n                continue\n            in_bytes += np.prod(var.aval.shape) * np.dtype(\n                var.aval.dtype).itemsize\n        return in_bytes == 0\n\n    def is_meaningful(inv: Atom):\n        return isinstance(inv, Var) and not isinstance(inv, DropVar)\n\n    def _offload_remat_process_pipeline(eqn: JaxprEqn,\n                                        discard_invars: Sequence[Var]):\n        discard_invars = set(discard_invars)\n        new_invars = []\n        new_outvars = []\n        for inv, outv in zip(eqn.invars, eqn.outvars):\n            if not (is_meaningful(inv) and inv in discard_invars):\n                new_invars.append(inv)\n                new_outvars.append(outv)\n        return clone_jaxpr_eqn(eqn, new_invars, new_outvars)\n\n    def difference_cross_marker(eqns, base, dif):\n        base = set(base)\n        dif = set(v for v in dif if is_meaningful(v))\n        pipeline_mapping = {}\n        for eqn in eqns:\n            if eqn.primitive is pipeline_p:\n                for inv, outv in zip(eqn.invars, eqn.outvars):\n                    if is_meaningful(inv) and is_meaningful(outv):\n                        pipeline_mapping[outv] = inv\n        for var in dif:\n     
       base.discard(var)\n            while var in pipeline_mapping:\n                var = pipeline_mapping[var]\n                base.discard(var)\n        return base\n\n    rng_primitives_set = set(rng_primitives)\n\n    def add_rng_as_output(jaxpr: Jaxpr):\n        rng_outvars = []\n        for eqn in jaxpr.eqns:\n            if eqn.primitive in rng_primitives_set:\n                assert not eqn.primitive.multiple_results\n                rng_outvars.append(eqn.outvars[0])\n        new_outvars = jaxpr.outvars + rng_outvars\n        return Jaxpr(jaxpr.constvars, jaxpr.invars, new_outvars,\n                     jaxpr.eqns), rng_outvars\n\n    def get_rng_from_input(jaxpr: Jaxpr):\n        new_invars = list(jaxpr.invars)\n        new_eqns = []\n        for eqn in jaxpr.eqns:\n            if eqn.primitive in rng_primitives_set:\n                new_invars.append(eqn.outvars[0])\n            else:\n                new_eqns.append(eqn)\n        return Jaxpr(jaxpr.constvars, new_invars, jaxpr.outvars, new_eqns)\n\n    def clone_outvars(outvars):\n        new_outvars = []\n        var_mapping = {}\n        for v in outvars:\n            if isinstance(v, DropVar):\n                new_outvars.append(v)\n            else:\n                new_v = gensym_fn(v.aval)\n                new_outvars.append(new_v)\n                var_mapping[v] = new_v\n                while v in var_pipeline_mapping:\n                    v = var_pipeline_mapping[v]\n                    var_mapping[v] = new_v\n        return new_outvars, var_mapping\n\n    # Find offloaded eqns\n    offloaded_eqns = set()\n    gensym_fn = gensym([closed_jaxpr.jaxpr])\n\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        if (eqn.primitive == pe.remat_call_p and only_input_consts(eqn) and\n                only_create_consts(eqn.params[\"call_jaxpr\"])):\n            offloaded_eqns.add(eqn_idx)\n    # Find where each eqn is offloaded\n    # A faster way is to rewrite remat to set each call's name 
unique, but users\n    # may use 'from jax import remat' instead of 'jax.remat()' which disables\n    # monkey patch to remat.\n    # Dict[fwd_outvar -> fwd_remat_call_idx]\n    offloaded_vars_from = {}\n    # Dict[var -> var]\n    var_pipeline_mapping = {}\n    # Dict[bwd_remat_call_idx -> fwd_remat_call_idx]\n    offload_to = {}\n    for eqn_idx in offloaded_eqns:\n        for var in closed_jaxpr.eqns[eqn_idx].outvars:\n            if is_meaningful(var):\n                offloaded_vars_from[var] = eqn_idx\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        if (eqn.primitive == pe.remat_call_p and eqn.params[\"differentiated\"]):\n            for inv in eqn.invars:\n                if is_meaningful(inv) and inv in offloaded_vars_from:\n                    fwd_eqn_idx = offloaded_vars_from[inv]\n                    assert (eqn_idx not in offload_to or\n                            offload_to[eqn_idx] == fwd_eqn_idx\n                           ), \"A backward matches multiple forward.\"\n                    offload_to[eqn_idx] = fwd_eqn_idx\n        elif eqn.primitive == pipeline_p:\n            for inv, outv in zip(eqn.invars, eqn.outvars):\n                if is_meaningful(inv) and inv in offloaded_vars_from:\n                    offloaded_vars_from[outv] = eqn\n                    var_pipeline_mapping[inv] = outv\n    # Insert the fwd remat call and rewrite corresponding bwd remat call\n    new_eqns = []\n    discarded = difference_cross_marker(closed_jaxpr.eqns,\n                                        offloaded_vars_from.keys(),\n                                        closed_jaxpr.jaxpr.outvars)\n    # Dict[fwd_eqn_idx -> Sequence[fwd_rng_outvars]]\n    rng_vars = {}\n    for eqn_idx, eqn in enumerate(closed_jaxpr.eqns):\n        if eqn.primitive is pipeline_p:\n            # Rewrite pipeline_markers\n            new_eqns.append(_offload_remat_process_pipeline(eqn, discarded))\n        elif eqn_idx in offloaded_eqns:\n            # add rng result as 
an output\n            new_params = dict(eqn.params)\n            new_called, rng_outvars = add_rng_as_output(\n                new_params[\"call_jaxpr\"])\n            new_params[\"call_jaxpr\"] = new_called\n            rng_outvars = [gensym_fn(v.aval) for v in rng_outvars]\n            new_outvars = list(eqn.outvars) + rng_outvars\n            rng_vars[eqn_idx] = rng_outvars\n            cloned_eqn = clone_jaxpr_eqn(eqn,\n                                         outvars=new_outvars,\n                                         params=new_params)\n            new_eqns.append(cloned_eqn)\n        elif eqn_idx not in offload_to:\n            new_eqns.append(eqn)\n        else:\n            inserted_idx = offload_to[eqn_idx]\n            # clone the forward remat call\n            # rewrite the inserted. Remove its rng, add invars from the cloned\n            inserted = closed_jaxpr.eqns[inserted_idx]\n            cloned_invars = list(inserted.invars)\n            cloned_invars.extend(rng_vars[inserted_idx])\n            cloned_params = dict(inserted.params)\n            cloned_params[\"call_jaxpr\"] = get_rng_from_input(\n                inserted.params[\"call_jaxpr\"])\n            cloned_outvars, var_mapping = clone_outvars(inserted.outvars)\n            cloned_fwd = clone_jaxpr_eqn(inserted,\n                                         cloned_invars,\n                                         cloned_outvars,\n                                         params=cloned_params)\n            # rewrite invars for bwd remat call\n            new_invars = [get_var_mapping(var_mapping, v) for v in eqn.invars]\n            new_eqn = clone_jaxpr_eqn(eqn, invars=new_invars)\n            new_eqns.extend([cloned_fwd, new_eqn])\n    return clone_jaxpr(closed_jaxpr, eqns=new_eqns)\n\n\ndef trace_jaxpr_with_micro_batch(fun: lu.WrappedFun,\n                                 batch_invars: Sequence[bool],\n                                 num_micro_batches: int,\n                              
   raw_avals: Sequence[AbstractValue],\n                                 batch_dim: int = 0):\n    \"\"\"Trace the jaxpr of the computation of a micro batch.\"\"\"\n    assert batch_dim == 0, \"Only support batch_dim == 0\"\n    # Monkey patch jax.random to fast stateful version\n    monkey_patch_random()\n    monkey_patch_jaxarray()\n\n    avals = []\n    batch_size = None\n    for aval, is_batch_var in zip(raw_avals, batch_invars):\n        if is_batch_var:\n            assert aval.shape[0] % num_micro_batches == 0, (\n                f\"The batch size must be divisable by num_micro_batches. \"\n                f\"batch_size = {aval.shape[0]}, \"\n                f\"num_micro_batches = {num_micro_batches}\")\n            if batch_size is None:\n                batch_size = aval.shape[0] // num_micro_batches\n            else:\n                assert batch_size == aval.shape[0] // num_micro_batches, (\n                    \"The batch dimension must be the same for all batch vars.\")\n            shape = (batch_size,) + aval.shape[1:]\n            avals.append(aval.update(shape=shape))\n        else:\n            avals.append(aval)\n    with jax.disable_jit():\n        jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, avals)\n    closed_jaxpr = ClosedJaxpr(jaxpr, consts)\n\n    # Restore jax.random to original stateless version\n    restore_random()\n    restore_jaxarray()\n    return closed_jaxpr, batch_size\n\n\nbackup_jnp_array = jnp.array\n\n\ndef monkey_patch_jaxarray():\n    \"\"\"Monkey patch jnp.array as jnp.asarray to avoid unnecessary copy.\"\"\"\n    jnp.array = jnp.asarray\n    setattr(Literal, \"__hash__\", lambda self: self.hash)\n\n\ndef restore_jaxarray():\n    \"\"\"Monkey patch jnp.array as jnp.asarray to avoid unnecessary copy.\"\"\"\n    jnp.array = backup_jnp_array\n    setattr(Literal, \"__hash__\", None)\n\n\ndef slices_to_jaxpr(\n        closed_jaxpr: ClosedJaxpr,\n        sliced_eqns: Sequence[Sequence[JaxprEqn]]) -> Sequence[ClosedJaxpr]:\n  
  \"\"\"Wrap sliced equations to a list of ClosedJaxpr.\"\"\"\n    n_eqns = len(sliced_eqns)\n    global_invars = OrderedSet(closed_jaxpr.jaxpr.invars)\n    global_outvars = OrderedSet(\n        var for var in closed_jaxpr.jaxpr.outvars if isinstance(var, Var))\n    global_consts = dict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))\n\n    layer_invars = [OrderedSet() for _ in range(n_eqns)]\n    layer_outvars = [OrderedSet() for _ in range(n_eqns)]\n    layer_consts = [{} for _ in range(n_eqns)]\n\n    var_layer_dict = {}  # Dict[var -> layer_idx]\n    for i, eqns in enumerate(sliced_eqns):\n        for eqn in eqns:\n            for var in eqn.invars:\n                if isinstance(var, Literal):\n                    continue\n                if var in global_consts:\n                    layer_consts[i][var] = global_consts[var]\n                elif var in global_invars:\n                    layer_invars[i].add(var)\n                elif var_layer_dict[var] != i:\n                    layer_invars[i].add(var)\n                    layer_outvars[var_layer_dict[var]].add(var)\n                else:\n                    assert var_layer_dict[var] == i\n            for var in eqn.outvars:\n                if not isinstance(var, DropVar):\n                    var_layer_dict[var] = i\n                if var in global_outvars:\n                    layer_outvars[i].add(var)\n\n    result = []\n    for i, eqns in enumerate(sliced_eqns):\n        new_jaxpr = Jaxpr(list(layer_consts[i].keys()), list(layer_invars[i]),\n                          list(layer_outvars[i]), eqns)\n        new_closed_jaxpr = ClosedJaxpr(new_jaxpr,\n                                       list(layer_consts[i].values()))\n        result.append(new_closed_jaxpr)\n    return result\n\n\ndef get_var_mapping(mapping, var):\n    \"\"\"map the var to a new value if var is Var and in the mapping.\"\"\"\n    if isinstance(var, Var) and var in mapping:\n        return mapping[var]\n    else:\n        
return var\n\n\ndef log_jaxpr(jaxpr: ClosedJaxpr, filename: str):\n    \"\"\"Print jaxpr int a temporary file for debugging purposes.\"\"\"\n    path = \"/tmp/\" + filename\n    with open(path, \"w\", encoding=\"utf-8\") as f:\n        f.write(str(jaxpr))\n\n\n########################################\n##### Flax Utilities\n########################################\n\n\ndef get_metrics(device_metrics):\n    \"\"\"\n    This function is similar to flax/training/common_utils.py, but works for\n    DistributedArray in alpa.\n    \"\"\"\n    # pylint: disable=import-outside-toplevel\n    from alpa.device_mesh import prefetch\n\n    prefetch(device_metrics)\n    return stack_forest(device_metrics)\n\n\n########################################\n##### Profiling Utilities\n########################################\n\n\ndef profile_xla_executable(compiled, backend, local_devices):\n    \"\"\"Measure the time costs of a xla executable with dummy inputs.\"\"\"\n    hlo_module = compiled.hlo_modules()[0]\n    cost_failed = [np.inf] * 3\n\n    # Allocate dummy buffers\n    input_shapes = hlo_module.parameter_shapes()\n\n    # prune OOM cases, not exact because third party lib not considered:\n    free_mem = local_devices[0].available_memory()\n    input_bytes = 0\n    for shape in input_shapes:\n        input_bytes += np.prod(\n            shape.dimensions()) * shape.numpy_dtype().itemsize\n    if free_mem < compiled.total_allocation_size() and free_mem != -1:\n        return cost_failed\n\n    device_inputs = []\n    try:\n        for shape in input_shapes:\n            device_inputs.append([\n                backend.buffer_from_pyval(\n                    np.empty(shape.dimensions(), shape.numpy_dtype()), device)\n                for device in local_devices\n            ])\n        local_devices[0].synchronize_all_activity()\n    except RuntimeError:\n        return cost_failed\n\n    # Run benchmark\n    def run_func():\n        device_outputs = 
compiled.execute_sharded_on_local_devices(\n            device_inputs)\n\n        # Reset the value for donate buffers\n        ct = 0\n        for j in range(len(device_inputs)):\n            if device_inputs[j][0].is_deleted():\n                device_inputs[j] = device_outputs[ct]\n                ct += 1\n\n        local_devices[0].synchronize_all_activity()\n\n    try:\n        costs = benchmark_func(run_func, repeat=3, number=3)\n    except RuntimeError:\n        costs = cost_failed\n    return costs\n\n\ndef benchmark_func(run_func,\n                   sync_func=None,\n                   warmup=1,\n                   repeat=3,\n                   number=5,\n                   min_repeat_second=None):\n    \"\"\"\n    Benchmark the execution time of a function.\n\n    The function is executed for (warmup + number * repeat) times.\n    The return value is a list of `repeat` elements and each elements is\n    the average execution time of `number` executions.\n\n    If `min_repeat_second` is set, the function automatically picks a `number`\n    so that one `repeat` lasts for at least `min_repeat_second` seconds.\n    \"\"\"\n    costs = []\n\n    # Warmup\n    for _ in range(warmup):\n        run_func()\n\n    # Choose a \"number\" according to \"min_repeat_second\"\n    if min_repeat_second:\n        if sync_func:\n            sync_func()\n        tic = time.time()\n        run_func()\n        if sync_func:\n            sync_func()\n        toc = time.time()\n        cost = toc - tic\n        number = max(int(min_repeat_second / cost), 1)\n\n    # Benchmark\n    for _ in range(repeat):\n        if sync_func:\n            sync_func()\n        tic = time.time()\n        for _ in range(number):\n            run_func()\n        if sync_func:\n            sync_func()\n        costs.append(time.time() - tic)\n\n    return np.array(costs) / number\n\n\ndef run_with_timeout(func, args=(), kwargs=None, timeout=None):\n    \"\"\"Run a function with timeout.\"\"\"\n    
ret_value = []\n\n    def _target_func():\n        ret_value.append(func(*args, **(kwargs or {})))\n\n    t = threading.Thread(target=_target_func)\n    t.start()\n    t.join(timeout=timeout)\n    if t.is_alive():\n        raise TimeoutError\n\n    if not ret_value:\n        raise RuntimeError\n\n    return ret_value[0]\n\n\n########################################\n##### Array Conversion\n########################################\n\n\ndef is_continuous_subset(tensor_slice, tensor_shape, row_major=True):\n    \"\"\"\n    Figure out whether a slice is a continuous subset of the tensor.\n\n    Args:\n        slice_shape (Sequence(slice)): the shape of the slice.\n        tensor_shape (Sequence(int)): the shape of the tensor.\n        row_major (bool): whether the tensor layout is row-majored.\n\n    Returns:\n        is_continuous (bool)\n    \"\"\"\n    if not row_major:\n        raise NotImplementedError(\"Do not support column major.\")\n    ndim = len(tensor_shape)\n    if len(tensor_slice) != ndim:\n        raise RuntimeError(\"ndims mismatch.\")\n    slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n    for dim, dim_shape in enumerate(slice_shape):\n        if dim + 1 > ndim:\n            return True\n        if dim_shape == 1:\n            continue\n        return slice_shape[dim + 1:] == tensor_shape[dim + 1:]\n\n\ndef infer_start_pos_and_n_elements(tensor_shape, tensor_slice):\n    start_pos = 0\n    n_elements = 1\n    for dim_len, dim_slice in zip(tensor_shape, tensor_slice):\n        start_pos = start_pos * dim_len + dim_slice.start\n        n_elements = n_elements * (dim_slice.stop - dim_slice.start)\n    return start_pos, n_elements\n\n\ndef infer_offset_and_n_elements(tensor_slice):\n    \"\"\"Calculate the offset and #elements before making NCCL calls.\n\n    This function assumes the slice is a continuous subset of the original\n    tensor.\n    \"\"\"\n    slice_shape = tuple(ind.stop - ind.start for ind in tensor_slice)\n    offset 
= tuple()\n    n_elements = np.prod(slice_shape)\n    for dim, dim_shape in enumerate(slice_shape):\n        offset = offset + (tensor_slice[dim].start,)\n        if dim_shape > 1:\n            break\n    return offset, n_elements\n\n\ndef xla_buffer_to_jax_tensor(xla_buf):\n    \"\"\"\n    Convert an xla buffer to a JAX DeviceArray.\n\n    So we can index over the data buffer.\n    \"\"\"\n    aval = ShapedArray(xla_buf.shape, xla_buf.dtype)\n    return _DeviceArray(aval, xla_buf.device(), xla_buf)\n\n\ndef jax_tensor_to_xla_buffer(jax_buf):\n    \"\"\"Convert a JAX Device array back to XLA buffer.\"\"\"\n    return jax_buf.device_buffer\n\n\n# Note: use Python jit instead of CPP jit,\n# because CPP jit has bugs on _DeviceArray.\nif is_worker:\n    FLAGS.experimental_cpp_jit = False\n\n\n# Note(Hao): this function will be jit-ed into as many versions as the possible\n# length of start_indices\n@partial(jax.jit, donate_argnums=0, static_argnums=2)\ndef jax_tensor_set(src_buf, update, start_indices):\n    \"\"\"\n    In-place write on a JAX buffer.\n\n    Args:\n        src_buf: JAX device array.\n        update: JAX device array.\n        start_indices (tuple[int]): tuple of integers indicating the starting\n        indices.\n    \"\"\"\n    # src_buf = src_buf.at[indices].set(update)\n    src_buf = jax.lax.dynamic_update_slice(src_buf, update, start_indices)\n    return src_buf\n\n\n@partial(jax.jit, static_argnums=(1, 2))\ndef jax_tensor_index(src_tensor, indices, size):\n    dst_tensor = jax.lax.dynamic_slice(src_tensor, indices, size)\n    return dst_tensor\n\n\n########################################\n##### OS / IO Utilities\n########################################\n\n\ndef run_cmd(cmd: str):\n    \"\"\"Run a bash command.\"\"\"\n    print(cmd)\n    ret = os.system(cmd)\n    return ret\n\n\ndef list_gpu_info():\n    \"\"\"List all gpu information by calling nvidia-smi.\"\"\"\n    ret = subprocess.getoutput(\"nvidia-smi -L\")\n    visible_devices = 
os.environ.get(\"CUDA_VISIBLE_DEVICES\", None)\n    if visible_devices:\n        ids = [int(x) for x in visible_devices.split(\",\")]\n        lines = ret.split(\"\\n\")\n        lines = [lines[i] for i in ids]\n        ret = \"\\n\".join(lines)\n    return ret\n\n\ndef disable_tqdm_globally():\n    \"\"\"Disable tqdm globally.\"\"\"\n    tqdm.tqdm.__init__ = partialmethod(tqdm.tqdm.__init__, disable=True)\n\n\ndef get_num_hosts_and_num_devices(args):\n    \"\"\"Get the number of hosts and the number of devices per host for benchmark\n    scripts.\"\"\"\n    if args.num_hosts is not None or args.num_devices_per_host is not None:\n        assert (args.num_hosts is not None and\n                args.num_devices_per_host is not None)\n        num_hosts, num_devices_per_host = (args.num_hosts,\n                                           args.num_devices_per_host)\n    else:\n        if hasattr(args, \"local\") and args.local:\n            num_hosts = 1\n            if global_config.backend == \"gpu\":\n                num_devices_per_host = list_gpu_info().count(\"UUID\")\n            elif global_config.backend == \"tpu\":\n                num_devices_per_host = len(jax.devices(\"tpu\"))\n            else:\n                raise ValueError(\n                    f\"Unsupported backend: {global_config.backend}\")\n        else:\n            ray.init(address=\"auto\")\n            num_hosts = len(ray.nodes())\n            num_devices_per_host = int(\n                ray.cluster_resources()[\"GPU\"]) // num_hosts\n    return num_hosts, num_devices_per_host\n\n\ndef write_tsv(heads: Sequence[str],\n              values: Sequence[Any],\n              filename: str,\n              print_line: bool = True):\n    \"\"\"Write tsv data to a file.\"\"\"\n    assert len(heads) == len(values)\n\n    values = [str(x) for x in values]\n\n    with open(filename, \"a\", encoding=\"utf-8\") as fout:\n        fout.write(\"\\t\".join(values) + \"\\n\")\n\n    if print_line:\n        line = 
\"\"\n        for i in range(len(heads)):\n            line += heads[i] + \": \" + values[i] + \"  \"\n        print(line)\n\n\ndef to_str_round(x: Any, decimal: int = 6):\n    \"\"\"Print a python object but round all floating point numbers.\"\"\"\n    if isinstance(x, str):\n        return x\n    if isinstance(x, (list, tuple, np.ndarray)):\n        tmp_str = \", \".join([to_str_round(y, decimal=decimal) for y in x])\n        return \"[\" + tmp_str + \"]\"\n    if isinstance(x, dict):\n        return str({k: to_str_round(v, decimal=decimal) for k, v in x.items()})\n    if isinstance(x, (int, np.int32, np.int64)):\n        return str(x)\n    if isinstance(x, (float, np.float32, np.float64)):\n        format_str = f\"%.{decimal}f\"\n        return format_str % x\n    if x is None:\n        return str(x)\n    raise ValueError(\"Invalid value: \" + str(x))\n\n\ndef check_server_port(address, port):\n    \"\"\"Checking Port Opening Status \"\"\"\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        try:\n            s.connect((address, port))\n            return True\n        except socket.error:\n            return False\n\n\n_tic = None\n\n\ndef print_used_time(message: str):\n    \"\"\"Print a message and the elapsed time from the last call.\"\"\"\n    global _tic\n    if message:\n        print(f\" - {message}: {time.time() - _tic:.2f} s\")\n    _tic = time.time()\n\n\n########################################\n##### Ray Compatibility API Utilities\n########################################\n\n\ndef try_import_ray_worker(error: bool = False):\n    \"\"\"Tries importing `ray.worker` and returns the module (or None).\n\n    Args:\n        error: Whether to raise an error if ray.worker cannot be imported.\n\n    Returns:\n        The `ray.worker` modules.\n\n    Raises:\n        ImportError: If error=True and ray's version >= 2.0.\n    \"\"\"\n    # In the ray-nightly version,\n    # worker = _DeprecationWrapper(\"worker\", ray._private.worker)\n   
 # `_DeprecationWrapper` has attributes of `_real_worker`\n    try:\n        if hasattr(ray.worker, \"_real_worker\"):\n            if error:\n                raise ImportError(\"Could not import `ray.worker`!\"\n                                  \"You might use the ray-nightly \"\n                                  \"and `ray.worker` is deprecated there\"\n                                  \"`pip install ray==1.13.0`.\")\n            return ray.worker._real_worker  # pylint: disable=protected-access\n        else:\n            return ray.worker\n    except ModuleNotFoundError:\n        return ray._private.worker  # pylint: disable=protected-access\n\n\ndef try_import_ray_state(error: bool = False):\n    \"\"\"Tries importing `ray.state` and returns the module (or None).\n\n    Args:\n        error: Whether to raise an error if ray.state cannot be imported.\n\n    Returns:\n        The `ray.state` modules.\n\n    Raises:\n        ImportError: If error=True and ray's version >= 2.0.\n    \"\"\"\n    # In the ray-nightly version,\n    # state = _DeprecationWrapper(\"state\", ray._private.state)\n    # `_DeprecationWrapper` has attributes of `_real_worker`\n    try:\n        if hasattr(ray.state, \"_real_worker\"):\n            if error:\n                raise ImportError(\"Could not import `ray.state`!\"\n                                  \"You might use the ray-nightly \"\n                                  \"and `ray.state` is deprecated there\"\n                                  \"`pip install ray>=1.13.0`.\")\n            return ray.state._real_worker  # pylint: disable=protected-access\n        else:\n            return ray.state\n    except ModuleNotFoundError:\n        return ray._private.state  # pylint: disable=protected-access\n\n\n########################################\n##### Ray Placement Group API Utilities\n########################################\n\n\ndef is_ray_node_resource(resource_key):\n    \"\"\"Check if the current resource is the host 
ip.\"\"\"\n    ishost_regex = re.compile(r\"^node:\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$\")\n    return ishost_regex.match(resource_key)\n\n\ndef get_bundle2ip(pg: PlacementGroup = None):\n    \"\"\"Get the IP address list from a placement group.\n\n    The IP addresses are ordered by bundle index.\n    \"\"\"\n\n    if pg:\n        pg_id = pg.id.hex()\n    # dictionary: bundle_group to node_ip\n    dict_bg2ip = {}\n\n    ray_state = try_import_ray_state()\n    resources_list = ray_state.state._available_resources_per_node(  # pylint: disable=protected-access\n    ).values()\n\n    for resource in resources_list:\n        resource_name_list = resource.keys()\n\n        node_ip = None\n        bundle_index_list = []\n        for resource_name in resource_name_list:\n            # when bundles are created, pg resources are\n            # specified as [resource]_[bundle_index]_[pg_id]\n            if pg:\n                try_bundle_index = re.findall(rf\"bundle_group_(\\d+)_{pg_id}\",\n                                              resource_name)\n            else:\n                try_bundle_index = re.findall(r\"bundle_group_(\\d+)_.*\",\n                                              resource_name)\n\n            try_node_ip = re.findall(\n                r\"^node:(\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}\\.\\d{1,3}$)\", resource_name)\n\n            if try_node_ip:\n                node_ip = try_node_ip[0]\n\n            if try_bundle_index:\n                bundle_index_list.append(try_bundle_index[0])\n\n        dict_bg2ip.update(\n            **dict(zip(bundle_index_list, [node_ip] * len(bundle_index_list))))\n\n    ip_list = []\n    for i in range(len(dict_bg2ip)):\n        ip_list.append(dict_bg2ip[str(i)])\n\n    return ip_list\n\n\ndef env_integer(key, default):\n    if key in os.environ:\n        value = os.environ[key]\n        if value.isdigit():\n            return int(os.environ[key])\n\n        logger.debug(f\"Found {key} in environment, but 
value must \"\n                     f\"be an integer. Got: {value}. Returning \"\n                     f\"provided default {default}.\")\n        return default\n    return default\n\n\ndef create_placement_group(num_hosts,\n                           host_num_devices,\n                           name,\n                           additional_resources_per_host=None):\n    \"\"\"Creates a placement group if it does not exist.\n\n    If a placement group is already detected (in Tune integration),\n    this will be a no-op.\n\n    By default the placement group will be created with `SPREAD` strategy.\n    This is optimized for colocating GPUs on different nodes.\n\n    Args:\n        num_hosts: the number of hosts to create the placement group for\n        host_num_devices: the number of devices on each host\n        additional_resources_per_host: additional resources per host\n\n    Returns:\n        The placement group\n    \"\"\"\n    current_placement_group = get_current_placement_group()\n    ray_worker = try_import_ray_worker()\n    worker = ray_worker.global_worker  # pylint: disable=protected-access\n    should_capture_child_tasks_in_placement_group = (\n        worker.should_capture_child_tasks_in_placement_group)\n    should_create_placement_group = (\n        current_placement_group is None or\n        not should_capture_child_tasks_in_placement_group)\n\n    if should_create_placement_group:\n        # `should_create_placement_group` is always True when using alpa alone.\n        # `should_create_placement_group` can be false when integrated with Tune\n        additional_resources_per_host = (additional_resources_per_host or {})\n        bundles = [{\n            \"CPU\": 1,\n            \"GPU\": host_num_devices[i],\n            **additional_resources_per_host\n        } for i in range(num_hosts)]\n\n        # Alpa Placement Group: `SPREAD` strategy is required\n        # https://docs.ray.io/en/latest/ray-core/placement-group.html#strategy-types\n        # 
Each bundle must be scheduled in a separate node.\n        strategy = \"SPREAD\"\n\n        placement_group = ray.util.placement_group(bundles,\n                                                   strategy=strategy,\n                                                   name=name or \"\")\n        logger.debug(\"Waiting for placement group to start.\")\n        timeout = env_integer(PLACEMENT_GROUP_TIMEOUT_S_ENV, 100)\n        ready, _ = ray.wait([placement_group.ready()], timeout=timeout)\n        if ready:\n            logger.debug(\"Placement group has started.\")\n        else:\n            raise TimeoutError(\n                \"Placement group creation timed out. Make sure your \"\n                \"cluster either has enough resources or use an \"\n                \"autoscaling cluster. If you are running on a cluster, \"\n                \"make sure you specify an address in `ray.init()`, for example,\"\n                ' `ray.init(\"auto\")`. You can also increase the timeout by '\n                \"setting the ALPA_PLACEMENT_GROUP_TIMEOUT_S environment \"\n                \"variable. 
Current resources available: \"\n                f\"{ray.available_resources()}, resources requested by \"\n                f\"the placement group: {placement_group.bundle_specs}\")\n        return placement_group\n    else:\n        return current_placement_group\n\n\ndef get_bundle_idx(placement_group: PlacementGroup, node_ips: List[str]):\n    \"\"\"Get the bundle index for the placement group.\n\n    The placement group is a list of resource bundles.\n    Each bundle will be assigned to **one** node.\n\n    First, we need to find the bundle index with GPU resources.\n    Then, we can find the node IP for the bundle index.\n    Lastly, we sort bundle index according to the node IP list given.\n\n    Args:\n        placement_group: The placement group.\n        node_ips: The list of node IP addresses.\n\n    Returns:\n        list: The sorted bundle index list.\n    \"\"\"\n    # get the node IP for the bundle index\n    bundle_ips = get_bundle2ip(placement_group)\n    bundle_specs = placement_group.bundle_specs\n\n    # filter out the bundle index with node (GPUs)\n    node_bundle_idx_list = [\n        i for i, bundle_spec in enumerate(bundle_specs)\n        if bundle_spec.get(\"GPU\", 0) > 0\n    ]\n\n    if len(node_bundle_idx_list) < len(node_ips):\n        raise ValueError(\"The number of bundles with GPU resources \"\n                         \"is less than the number of node IPs.\")\n\n    # node IP -> bundle index\n    bundle_ip2idx = {bundle_ips[i]: i for i in node_bundle_idx_list}\n\n    # sorted bundle index according to the node IP list given\n    sorted_bundle_idx = [bundle_ip2idx[ip] for ip in node_ips]\n\n    return sorted_bundle_idx\n\n\ndef retrieve_placement_group():\n    \"\"\"retrieve the placement group to support node affinity scheduling\n\n    If already inside the placement group, retrieve the current placement\n    group (case I). 
Then, if the placement group is detected globally in\n    alpa, retrieve the global placement group (case II).\n\n    \"\"\"\n    # case 1:\n    # Get the current placement group which a task or actor is using\n    current_placement_group = get_current_placement_group()\n    if current_placement_group:\n        return current_placement_group\n\n    # case 2:\n    # Get the placement group created when alpa.init('ray')\n    global_cluster = device_mesh.global_cluster\n    if global_cluster and global_cluster.placement_group:\n        alpa_placement_group = global_cluster.placement_group\n        return alpa_placement_group\n\n    raise ValueError(\n        \"The alpa training is not inside the ray tasks or actor or \"\n        \"the placement group is not created yet. One reason is that \"\n        \"Alpa is not connected to Ray cluster, and use `alpa.init('ray')`\"\n        \" at the beginning. Do you have override the placement group? \"\n        \"If not, please help file an issue on Github.\")\n\n\ndef get_num_available_gpus(pg: PlacementGroup):\n    res = ray.available_resources()\n    pg_id = pg.id.hex()\n    return res[f\"GPU_group_{pg_id}\"]\n\n\n########################################\n##### Other Utilities\n########################################\n\nGB = 1 << 30  # Gigabyte\nMB = 1 << 20  # Megabyte\n\n\ndef map_to_shape(array_pytree: PyTreeDef):\n    \"\"\"Map a PyTree of jax arrays to their shapes.\"\"\"\n    return tree_map(lambda x: getattr(x, \"shape\", None), array_pytree)\n\n\ndef map_to_nparray(tree: PyTreeDef):\n    \"\"\"Map a PyTree to a PyTree of numpy array.\"\"\"\n\n    def convert_to_nparray(x):\n        if hasattr(x, \"__array__\"):\n            return np.asarray(x)\n        return x\n\n    return jax.tree_map(convert_to_nparray, tree)\n\n\ndef compute_bytes(pytree: PyTreeDef):\n    \"\"\"Compute the total bytes of arrays in a pytree.\"\"\"\n    flatten_args, _ = tree_flatten(pytree)\n    ret = 0\n    for x in flatten_args:\n        if 
hasattr(x, \"shape\"):\n            ret += np.prod(x.shape) * x.dtype.itemsize\n    return ret\n\n\ndef compute_param_number(pytree: PyTreeDef):\n    \"\"\"Compute the total number of elements in a pytree.\"\"\"\n    flatten_args, _ = tree_flatten(pytree)\n    ret = 0\n    for x in flatten_args:\n        if hasattr(x, \"shape\"):\n            ret += np.prod(x.shape)\n    return ret\n\n\ndef compute_gpt_tflops(batch_size,\n                       seq_len,\n                       num_layers,\n                       hidden_size,\n                       vocab_size,\n                       num_gpus,\n                       latency,\n                       backward=True,\n                       checkpoint_activations=False):\n    \"\"\"\n    Compute the Tera Flop Operations (TFLOP) per second per GPU\n    for GPT-like models.\n    \"\"\"\n    factor = 24\n    if backward:\n        factor += 48\n    if checkpoint_activations:\n        factor += 24\n\n    total_flop = (factor * batch_size * seq_len *\n                  (hidden_size**2) * num_layers * (1 + seq_len /\n                                                   (6 * hidden_size)) +\n                  6 * batch_size * seq_len * hidden_size * vocab_size)\n    # Note: The above formula does not count the first embedding table lookup\n    # because it is a sparse operation.\n    # If we use dense dot to compute the first embedding table lookup,\n    # then the last term in total_flops should be\n    # \"+ 10 * batch_size * seq_len * hidden_size * vocab_size\".\n    tflops = total_flop / latency / num_gpus / 1e12\n    return tflops\n\n\n_DISABLE_NUMBA = False\n\n\ndef maybe_numba_jit(func):\n    \"\"\"Decorator to mark a function as numba jitted if numba is available.\"\"\"\n    try:\n        from numba import jit  # pylint: disable=import-outside-toplevel\n        jitted_func = jit(nopython=True)(func)\n\n        def wrapper(*args, **kwargs):\n            if _DISABLE_NUMBA:\n                return func(*args, **kwargs)\n   
         return jitted_func(*args, **kwargs)\n\n        return wrapper\n    except ImportError:\n        logger.warning(\"Install numba to jit and accelerate the function.\")\n        return func\n\n\ndef mesh_ids_hash(mesh_ids):\n    ret = b\"\"\n    for i in sorted(mesh_ids):\n        ret += bytes(f\"{i}\", \"utf-8\") + b\"$\"\n    return ret\n"
  },
  {
    "path": "alpa/version.py",
    "content": "# pylint: disable=pointless-string-statement, line-too-long\n\"\"\"Version information.\"\"\"\nfrom jax._src.lib import xla_extension as xe\n\n__version__ = \"1.0.0.dev0\"\n\nminimal_alpa_jaxlib_version = (0, 2, 2)\n\n\ndef check_alpa_jaxlib_version():\n    \"\"\"Check the minimal requirement of alpa's jaxlib.\"\"\"\n    try:\n        alpa_jaxlib_version_str = xe.get_alpa_jaxlib_version()\n        alpa_jaxlib_version = tuple(\n            int(x) for x in alpa_jaxlib_version_str.split(\".\"))\n    except AttributeError:\n        alpa_jaxlib_version = (0, 0, 0)\n\n    if alpa_jaxlib_version < minimal_alpa_jaxlib_version:\n        minimal_alpa_jaxlib_version_str = \".\".join(\n            str(x) for x in minimal_alpa_jaxlib_version)\n        alpa_jaxlib_version_str = \".\".join(str(x) for x in alpa_jaxlib_version)\n        raise RuntimeError(\n            f\"The alpa-jaxlib's internal version is v{alpa_jaxlib_version_str}, \"\n            f\"but the minimal requirement is v{minimal_alpa_jaxlib_version_str}. \"\n            f\"Please install the latest alpa-jaxlib. If you build alpa from source,\"\n            f\" please update your tensorflow-alpa submodule and re-compile jaxlib (\"\n            f\"help : https://alpa-projects.github.io/developer/developer_guide.html\"\n            f\"#updating-submodule-tensorflow-alpa)\")\n\n\n##### Attach all licenses of used open-source code below #####\n\n# For some huggingface model implementations\n\"\"\"\nCopyright 2018- The Hugging Face team. All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. 
Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\"\"\"\n\n# For model utils in flax\n\"\"\"\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\"\"\"\n\n# For OPT serving examples\n\"\"\"\nMIT License\n\nCopyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\"\"\"\n\n# For ray serve\n\"\"\"\nCopyright 2022- The Ray team. All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\"\"\"\n"
  },
  {
    "path": "alpa/wrapped_hlo.py",
    "content": "\"\"\"A class that wraps HloModule and records whether the module runs AutoSharding\nand SPMD Partitioner or not.\n\"\"\"\nfrom enum import Enum, auto\nfrom typing import Union\n\nfrom jax._src.lib import xla_extension as xe\nfrom jax.interpreters import mlir\n\n\nclass HloStatus(Enum):\n    \"\"\"\n    The status of an HloModule.\n    See also the docstring at the beginning of shard_parallel/auto_sharding.py.\n    \"\"\"\n    UNOPTIMIZED = auto()\n    SHARDING_ANNOTATED = auto()\n    SPMD_PARTITIONED = auto()\n    FULLY_OPTIMIZED = auto()\n\n\nclass WrappedHlo:\n    \"\"\"Wrapped HloModule with HloStatus.\"\"\"\n\n    def __init__(self,\n                 module: Union[xe.HloModule, xe.XlaComputation, bytes],\n                 status: HloStatus = HloStatus.UNOPTIMIZED):\n        if isinstance(module, xe.HloModule):\n            self.module = module\n        elif isinstance(module, xe.XlaComputation):\n            self.module = module.get_hlo_module()\n        else:\n            assert isinstance(module, bytes)\n            self.module = xe.XlaComputation(module).get_hlo_module()\n        self.name = self.module.name\n        self.status = status\n        self.is_manually_annotated = False\n\n    def get_computation(self) -> xe.XlaComputation:\n        return xe.XlaComputation(self.module.as_serialized_hlo_module_proto())\n\n    def get_mhlo(self):\n        xla_computation = self.get_computation()\n        module_str = xe.mlir.xla_computation_to_mlir_module(xla_computation)\n        with mlir.make_ir_context():\n            mhlo = mlir.ir.Module.parse(module_str)\n        return mhlo\n\n    def get_module(self) -> xe.HloModule:\n        return self.module\n\n    def get_hlo_proto(self):\n        return self.module.as_serialized_hlo_module_proto()\n\n    def program_shape(self):\n        return self.module.program_shape()\n\n    def set_input_shardings(self, sharding_protos):\n        assert self.is_sharding_annotated() or self.is_unoptimized()\n      
  xe.set_hlo_module_input_shardings(self.module, sharding_protos)\n\n    def set_output_shardings(self, sharding_protos):\n        assert self.is_sharding_annotated() or self.is_unoptimized()\n        xe.set_hlo_module_output_shardings(self.module, sharding_protos)\n\n    def is_unoptimized(self):\n        return self.status == HloStatus.UNOPTIMIZED\n\n    def is_sharding_annotated(self):\n        return self.status == HloStatus.SHARDING_ANNOTATED\n\n    def is_spmd_partitioned(self):\n        return self.status == HloStatus.SPMD_PARTITIONED\n\n    def to_string(self):\n        return self.module.to_string()\n\n    def __getstate__(self):\n        return (self.get_hlo_proto(), self.status)\n\n    def __setstate__(self, bytes_and_status):\n        b, s = bytes_and_status\n        self.__init__(b, s)\n"
  },
  {
    "path": "benchmark/alpa/README.md",
    "content": "# Benchmark\nTo achieve the best performance with Alpa, one needs to run a full auto-parallelization search for the target model on a target cluster.\nThe search procedure can take a significant amount of time.\nTo make the benchmark feasible in a short amount of time, this documentation provides:\n- Instructions for benchmarking the solutions found on an AWS p3.16xlarge cluster.  \n  You can use these to quickly run Alpa, see how Alpa works, and get an estimation of the performance.\n  The performance may not be the best if your cluster is not an AWS p3.16xlarge cluster.\n- Instructions for running the full search.  \n  You can use these to fully benchmark the auto-parallelization ability of Alpa.\n\n## Benchmark Pre-found Solutions\n\n### Start a Ray Cluster\nAlpa uses a distributed framework Ray to manage the cluster and distributed workers.\nHere, we provide instructions for manually launching a ray cluster.\nYou can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for more methods on launching and managing ray clusters. \n\n1. Pick one node as the head node and run the command below on it\n    ```\n    ray start --head\n    ```\n2. For all other nodes, connect them to the head node following the instructions printed by the previous command. Skip this step if you only have one node.\n    ```\n    # The command should look like this, but with the ip address and password printed by the previous command. 
\n    ray start --address='172.31.31.37:6379' --redis-password='5241590000000000'\n    ```\n\nYou can check the cluster status by \n```\nray status\n```\nYou should be able to see the number of CPUs and GPUs available on your cluster.\nAll nodes should have alpa installed.\n\n### GPT-3\nRun the benchmark with all GPUs in your cluster.\n```\npython3 benchmark.py --suite gpt.perf_test_auto\n```\n\nYou can also specify the number of hosts and the number of devices per host.\n```\npython3 benchmark.py --suite gpt.perf_test_auto --num-hosts 2 --num-devices-per-host 8\n```\n\n### Mixture-of-Expert Transformer\nSimilar to the previous subsection.\n```\npython3 benchmark.py --suite moe.perf_test_auto\n```\n\n### Wide-ResNet\nSimilar to the previous subsection.\n```\npython3 benchmark.py --suite wresnet.perf_test_auto\n```\n\n## Run Full Search\n\n### Generate Profiling Database\nAlpa requires a cost model to estimate the performance of different parallelization strategies.\nThis cost model is based on profiling results on the target cluster.\nWe can generate a profiling database with the following commands, which profiles the time costs of various computation and communication patterns.\nNote that this procedure is very slow and can take hours, but you only need to do it once for your cluster.\n\n1. Start a Ray cluster\n2. 
Generate the profiling database\n  ```\n  # for AWS p3.16:\n  python3 gen_prof_database.py --max-comm-size-intra-node 32 --max-comm-size-inter-node 29\n  \n  # for AWS p4.24 with EFA:\n  python3 gen_prof_database.py --efa --max-comm-size-intra-node 33 --max-comm-size-inter-node 30 --max-fail-retry 8\n  ```\n\n### Run Search\n```\npython3 benchmark.py --suite gpt.grid_search_auto\n```\n\n## A Quick Performance Test\nThis is a quick test for checking performance regressions.\nDevelopers should at least run this test to make sure their modifications do not introduce performance regressions.\n\n```\npython3 benchmark.py --suite gpt.perf_test_manual\n```\n\nExpected output on AWS p3.16 (10/17/2022)\n```\nubuntu@ip-172-31-34-216:~/efs/alpa/benchmark/alpa$ python3 benchmark.py --suite gpt.perf_test_manual\nWorking on case: BenchmarkCase(batch_size=32, model_config=GPTModelConfig(seq_len=1024, hidden_size=2560, num_layers=32, num_heads=32, vocab_size=51200), num_micro_batches=4, parallel_mode='uniform', parallel_args=UniformParallelArgs(prefer_reduce_scatter=True, use_remat=True, dp=2, op=2, pp=2, force_batch_dim_mapping=True))\n - Prepare input: 0.05 s\n - Create train state: 8.37 s\n - Compile (driver): 67.38 s\n - Compile (worker): 21.99 s\nIteration 0 ...\nIteration 1 ...\nIteration 2 ...\n - Benchmark: 18.83 s\nType: gpt  Model Config: GPTModelConfig(seq_len=1024, hidden_size=2560, num_layers=32, num_heads=32, vocab_size=51200)  #Microbatch: 4  #GPU: 8  Parallel Config: UniformParallelArgs(prefer_reduce_scatter=True, use_remat=True, dp=2, op=2, pp=2, force_batch_dim_mapping=True)  Mean Time (s): 2.464  Std Time (s): 0.000  #Params (Billion): 2.649B  TFLOPs: 37.01  Peak Mem (GB): 8.745  Metadata: {'compilation_times': 'None', 'compute_cost_file_name': 'None', 'forward_stage_layer_ids': 'None', 'submesh_shapes': 'None', 'logical_mesh_shapes': 'None', 'autosharding_option_dicts': 'None'}\n```\n\n## Advanced Usage\nBenchmark pipeshard parallel case:\n```\npython 
benchmark.py --suite gpt.perf_test_auto\n```\n\nBenchmark shard parallel case (i.e. only intra-operator parallelism, no pipeline parallelism). Add `--local` at the end to run the benchmark with the local cluster without ray.\n```\npython benchmark.py --suite gpt.perf_test_fast_2d --shard-only [--local]\n```\n\nSome benchmarks are inference benchmarks:\n```\npython benchmark.py --suite gpt_inference.profile\n```\n\nAdd `--profile-driver-time` to derive the latency from the driver. This flag will also turn off the synchronization barrier after each benchmarking step. Specifically, for the inference case, this turns streaming inference on and the model will pipeline different input batches (in addition to pipelining different micro-batches).\n```\npython benchmark.py --suite gpt_inference.profile --profile-driver-time\n```\n\nAdd `--profile-stage-execution-time` to derive the stage execution timeline for each request and dump it into chrome tracing files in the folder `$PWD/chrome_trace/`.\n```\npython benchmark.py --suite gpt_inference.profile --profile-stage-execution-time\n```\n\nWe also include a convenient script `run_exp.py` to run multiple benchmarks with different cluster configurations. For example, to run all gpt search cases:\n```\npython run_exp.py gpt\n```\n"
  },
  {
    "path": "benchmark/alpa/benchmark.py",
    "content": "\"\"\"The entry point of intra-op + inter-op parallelism benchmark.\"\"\"\nimport os\nimport argparse\nfrom datetime import datetime\nimport time\n\nimport numpy as np\n\nfrom alpa.util import (write_tsv, get_num_hosts_and_num_devices, to_str_round,\n                       GB)\n\nfrom benchmark_one_case import benchmark_one_case\nimport suite_auto_gpt\nimport suite_auto_moe\nimport suite_manual_gpt\nimport suite_manual_moe\nimport suite_unet\nimport suite_wresnet\nimport suite_inference_gpt\nimport suite_inference_moe\n\nbenchmark_suites = {\n    \"gpt.tmp\": suite_manual_gpt.tmp_suite,\n    \"gpt.tmp_auto\": suite_auto_gpt.tmp_suite,\n    \"gpt.perf_test_fast_2d\": suite_manual_gpt.perf_test_fast_2d_suite,\n    \"gpt.perf_test_manual\": suite_manual_gpt.perf_test_suite,\n    \"gpt.perf_test_auto\": suite_auto_gpt.perf_test_suite,\n    \"gpt.grid_search_auto\": suite_auto_gpt.grid_search_suite,\n    \"gpt.correctness_test_auto\": suite_auto_gpt.correctness_test_suite,\n    \"gpt_inference.profile\": suite_inference_gpt.profile_suite,\n    \"gpt_no_embedding_inference.profile\": suite_inference_gpt.profile_suite,\n    \"moe.tmp\": suite_manual_moe.tmp_suite,\n    \"moe.tmp_auto\": suite_auto_moe.tmp_suite,\n    \"moe.perf_test_fast_2d\": suite_manual_moe.perf_test_fast_2d_suite,\n    \"moe.perf_test_auto\": suite_auto_moe.perf_test_suite,\n    \"moe.grid_search_auto\": suite_auto_moe.grid_search_suite,\n    \"moe_inference.profile\": suite_inference_moe.profile_suite,\n    \"unet.perf_test_auto\": suite_unet.perf_test_auto_suite,\n    \"unet.grid_search_auto\": suite_unet.grid_search_auto_suite,\n    \"wresnet.perf_test_2d\": suite_wresnet.perf_test_2d_suite,\n    \"wresnet.perf_test_auto\": suite_wresnet.perf_test_auto_suite,\n    \"wresnet.grid_search_auto\": suite_wresnet.grid_search_auto_suite,\n}\n\n\ndef benchmark_suite(suite_name,\n                    num_hosts,\n                    num_devices_per_host,\n                    
exp_name=\"default\",\n                    niter=3,\n                    shard_only=False,\n                    local=False,\n                    profile_driver_time=False,\n                    profile_stage_execution_time=False,\n                    disable_tqdm=False,\n                    use_separate_process=True):\n    num_gpus = num_hosts * num_devices_per_host\n\n    if local:\n        assert shard_only, (\"Only shard-only mode is supported for execution \"\n                            \"on local GPUs.\")\n\n    if num_gpus not in benchmark_suites[suite_name]:\n        print(f\"No benchmark suite for #gpu={num_gpus}\")\n        return\n    suite = benchmark_suites[suite_name][num_gpus]\n\n    os.makedirs(\"tmp\", exist_ok=True)\n\n    model_type = suite_name.split(\".\")[0]\n    output_name = f\"{exp_name}.tsv\"\n\n    # Run all cases\n    for benchmark_case in suite:\n        model_config = benchmark_case.model_config\n        num_micro_batches = benchmark_case.num_micro_batches\n        parallel_args = benchmark_case.parallel_args\n\n        # Run one case\n        print(\"Working on case: {}\".format(str(benchmark_case)))\n        result = benchmark_one_case(\n            model_type,\n            benchmark_case,\n            niter,\n            num_hosts,\n            num_devices_per_host,\n            shard_only=shard_only,\n            local=local,\n            profile_driver_time=profile_driver_time,\n            profile_stage_execution_time=profile_stage_execution_time,\n            disable_tqdm=disable_tqdm,\n            use_separate_process=use_separate_process)\n\n        (parameter_count, peak_mem, latencies, tflops, metadata) = result\n\n        heads = [\n            \"Type\", \"Model Config\", \"#Microbatch\", \"#GPU\", \"Parallel Config\",\n            \"Mean Time (s)\", \"Std Time (s)\", \"#Params (Billion)\", \"TFLOPs\",\n            \"Peak Mem (GB)\", \"Metadata\"\n        ]\n        values = [\n            model_type, model_config, 
num_micro_batches, num_gpus,\n            parallel_args, f\"{np.mean(latencies):.3f}\",\n            f\"{np.std(latencies):.3f}\", f\"{parameter_count/1e9:.3f}B\",\n            f\"{tflops:.2f}\", f\"{peak_mem/GB:.3f}\",\n            to_str_round(metadata, 2)\n        ]\n        write_tsv(heads, values, output_name)\n\n        time.sleep(0.1)  # for ctrl+c to work\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--suite\",\n                        choices=list(benchmark_suites.keys()),\n                        type=str,\n                        required=True)\n    parser.add_argument(\"--niter\",\n                        type=int,\n                        default=3,\n                        help=\"The number of benchmark iterations\")\n    parser.add_argument(\"--num-hosts\", type=int, default=None)\n    parser.add_argument(\"--num-devices-per-host\", type=int, default=None)\n    parser.add_argument(\"--shard-only\",\n                        action=\"store_true\",\n                        help=\"Only profile the 2D case. No pipeline \"\n                        \"parallelism.\")\n    parser.add_argument(\"--local\",\n                        action=\"store_true\",\n                        help=\"Run on local GPUs. Do not use ray actors.\")\n    parser.add_argument(\"--profile-driver-time\",\n                        action=\"store_true\",\n                        help=\"Profile the execution time on the driver instead \"\n                        \"of the workers.\")\n    parser.add_argument(\n        \"--profile-stage-execution-time\",\n        action=\"store_true\",\n        help=\"Profile the execution timestamps of each pipeline \"\n        \"stage\")\n    parser.add_argument(\"--no-separate-process\",\n                        action=\"store_false\",\n                        help=\"Do not launch separate processes for benchmark. 
\"\n                        \"Errors in a single case will terminate this \"\n                        \"script.\",\n                        dest=\"use_separate_process\")\n    parser.add_argument(\"--exp-name\", type=str, default=\"default\")\n    parser.add_argument(\"--disable-tqdm\", action=\"store_true\")\n    args = parser.parse_args()\n\n    num_hosts, num_devices_per_host = get_num_hosts_and_num_devices(args)\n\n    benchmark_suite(args.suite, num_hosts, num_devices_per_host, args.exp_name,\n                    args.niter, args.shard_only, args.local,\n                    args.profile_driver_time, args.profile_stage_execution_time,\n                    args.disable_tqdm, args.use_separate_process)\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport os\nimport argparse\nimport multiprocessing as mp\n\nimport jax\n\nfrom alpa import (init, global_config, get_global_cluster,\n                  LocalPhysicalDeviceMesh)\nfrom alpa.util import disable_tqdm_globally\n\nfrom benchmark_one_case_gpt_bert import (benchmark_gpt_bert_3d_internal,\n                                         benchmark_gpt_bert_2d_internal)\nfrom benchmark_one_case_moe import (benchmark_moe_3d_internal,\n                                    benchmark_moe_2d_internal)\nfrom benchmark_one_case_unet import benchmark_unet_3d_internal\nfrom benchmark_one_case_wresnet import (benchmark_wresnet_3d_internal,\n                                        benchmark_wresnet_2d_internal)\nfrom benchmark_one_case_gpt_bert_inference import (\n    benchmark_gpt_inference_internal)\nfrom benchmark_one_case_moe_inference import (benchmark_moe_inference_internal)\n\n\ndef benchmark_one_case_internal(model,\n                                case,\n                                niter,\n                                num_hosts,\n                                num_devices_per_host,\n                                profile_driver_time=False,\n                                profile_stage_execution_time=False,\n                                shard_only=False,\n                                local=False,\n                                disable_tqdm=False):\n    if disable_tqdm:\n        disable_tqdm_globally()\n\n    # local mode does not support dummy value\n    global_config.use_dummy_value_for_benchmarking = not local\n\n    if shard_only:\n        global_config.shard_parallel_sync_for_timer = True\n        if local:\n            assert num_hosts == 1\n            physical_mesh = LocalPhysicalDeviceMesh(\n                jax.local_devices()[:num_devices_per_host])\n        else:\n            init(cluster=\"ray\")\n            physical_mesh = 
get_global_cluster().get_physical_mesh(\n                list(range(num_hosts)), num_devices_per_host)\n\n        # Run benchmark\n        if model in [\"gpt\", \"bert\"]:\n            result = benchmark_gpt_bert_2d_internal(\n                physical_mesh,\n                model,\n                case,\n                niter,\n                profile_driver_time=profile_driver_time)\n        elif model == \"moe\":\n            result = benchmark_moe_2d_internal(\n                physical_mesh,\n                case,\n                niter,\n                profile_driver_time=profile_driver_time)\n        elif model == \"wresnet\":\n            global_config.xla_client_mem_fraction = 0.88\n            # Due to legacy issues, we turn off auto-tuning. Although the\n            # performance will be much better if we turn it on\n            global_config.xla_gpu_autotune_level = 0\n            result = benchmark_wresnet_2d_internal(\n                physical_mesh,\n                case,\n                niter,\n                profile_driver_time=profile_driver_time)\n        else:\n            raise ValueError(f\"Invalid model: {model}\")\n\n    else:\n        global_config.pipeline_sync_for_timer = True\n        if profile_stage_execution_time:\n            global_config.collect_trace = True\n        init(cluster=\"ray\")\n\n        # Run benchmark\n        if model in [\"gpt\", \"bert\"]:\n            result = benchmark_gpt_bert_3d_internal(\n                model,\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time)\n        elif model == \"moe\":\n            result = benchmark_moe_3d_internal(\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time)\n        elif model == \"wresnet\":\n            
global_config.xla_client_mem_fraction = 0.88\n            # Due to legacy issues, we turn off auto-tuning. Although the\n            # performance will be much better if we turn it on\n            global_config.xla_gpu_autotune_level = 0\n            result = benchmark_wresnet_3d_internal(\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time)\n        elif model == \"unet\":\n            global_config.xla_client_mem_fraction = 0.88\n            global_config.xla_gpu_autotune_level = 0\n            result = benchmark_unet_3d_internal(\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time)\n        elif model in [\"gpt_inference\", \"gpt_no_embedding_inference\"]:\n            result = benchmark_gpt_inference_internal(\n                model,\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time,\n                profile_stage_execution_time=profile_stage_execution_time)\n        elif model in [\"moe_inference\"]:\n            result = benchmark_moe_inference_internal(\n                case,\n                niter,\n                num_hosts,\n                num_devices_per_host,\n                profile_driver_time=profile_driver_time,\n                profile_stage_execution_time=profile_stage_execution_time)\n        else:\n            raise ValueError(f\"Invalid model: {model}\")\n\n    return result\n\n\ndef benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):\n    result = benchmark_one_case_internal(*args, **kwargs)\n    result_namespace.result = result\n\n\ndef benchmark_one_case(*args, use_separate_process=False, **kwargs):\n    if not use_separate_process:\n        return 
benchmark_one_case_internal(*args, **kwargs)\n    ctx = mp.get_context(\"spawn\")\n    manager = ctx.Manager()\n    result_namespace = manager.Namespace()\n    p = ctx.Process(target=benchmark_and_write_to_namespace,\n                    args=(result_namespace, *args),\n                    kwargs=kwargs)\n    p.start()\n    p.join()\n    if p.exitcode != 0:\n        return -1, -1, [-1], -1, None\n    return result_namespace.result\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str)\n    parser.add_argument(\"--niter\", type=int)\n    parser.add_argument(\"--case\", type=str, required=True)\n    parser.add_argument(\"--num-hosts\", type=int)\n    parser.add_argument(\"--num-devices-per-host\", type=int)\n    parser.add_argument(\"--shard-only\",\n                        action=\"store_true\",\n                        help=\"Only profile the 2D case. No pipeline \"\n                        \"parallelism.\")\n    parser.add_argument(\"--local\",\n                        action=\"store_true\",\n                        help=\"Run on local GPUs. 
Do not use ray actors.\")\n    parser.add_argument(\"--profile-driver-time\",\n                        action=\"store_true\",\n                        help=\"Profile the execution time on the driver instead \"\n                        \"of the workers.\")\n    parser.add_argument(\"--disable-tqdm\", action=\"store_true\")\n    args = parser.parse_args()\n\n    os.makedirs(\"tmp\", exist_ok=True)\n\n    # Make eval work smoothly\n    from benchmark_parallel_utils import *\n    from suite_manual_gpt import GPTModelConfig\n    from suite_manual_moe import MoEModelConfig\n    from suite_wresnet import WResNetModelConfig\n    from suite_unet import UNetModelConfig\n    case = eval(args.case)\n\n    result = benchmark_one_case(args.model,\n                                case,\n                                args.niter,\n                                args.num_hosts,\n                                args.num_devices_per_host,\n                                shard_only=args.shard_only,\n                                local=args.local,\n                                profile_driver_time=args.profile_driver_time,\n                                disable_tqdm=args.disable_tqdm)\n\n    print(result)\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_gpt_bert.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nimport alpa\nfrom alpa import (parallelize, get_global_cluster,\n                  set_global_virtual_physical_mesh, automatic_remat,\n                  global_config)\nfrom alpa.model.bert_model import BertConfig, FlaxBertForMaskedLMModule\nfrom alpa.model.model_util import TrainState\nfrom alpa.model.gpt_model import FlaxGPTForLMModule\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.util import print_used_time\n\nfrom util import compute_gpt_parameter_count, compute_gpt_tflops\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method, get_shard_parallel_method,\n    compile_and_benchmark_pipeshard_training_executable,\n    compile_and_benchmark_shard_training_executable)\n\n\ndef report_pipeline_breakdown(executable, timer_names, niter):\n    overall_costs = executable.get_execution_time_costs(timer_name=\"overall\")\n\n    print(\">>> overall: {}...\".format(overall_costs))\n    other_percentage = [100.0] * niter\n    other = overall_costs\n    for timer_name in timer_names:\n        costs = executable.get_execution_time_costs(timer_name=timer_name)\n        if len(costs) == 0:\n            costs = [0.0] * niter\n        percentage = [\n            cost / overall_costs[i] * 100 for i, cost in enumerate(costs)\n        ]\n        other = [remain - costs[i] for i, remain in enumerate(other)]\n        other_percentage = [\n            remain - percentage[i] for i, remain in enumerate(other_percentage)\n        ]\n        strs = []\n        for i, cost in enumerate(costs):\n            strs.append(str(cost) + f\" ({percentage[i]:.1f}) \")\n        print_string = \",\".join(strs)\n        print(\">>> {}: {}\".format(timer_name, print_string))\n\n    # print unknown overhead\n    strs = []\n    for i, remain in enumerate(other):\n        strs.append(\" \" + 
str(remain) + f\" ({other_percentage[i]:.1f})\")\n    print_string = \",\".join(strs)\n    print(\">>> {}: {}\".format(\"Others: \", print_string))\n\n\ndef create_train_state(rngkey, model, batch, dtype):\n    params = model.init_dummy(rngkey, batch[\"input_ids\"],\n                              batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                              batch[\"position_ids\"])\n\n    def weight_decay_mask(pytree):\n        # do not use weight decay on layer norm and bias.\n        return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n    tx = optax.chain(\n        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this\n        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))\n    use_master_copy = (dtype == jnp.float16)\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              use_master_copy=use_master_copy,\n                              dynamic_scale=None)\n    return state\n\n\ndef create_train_state_aval(rngkey, model, batch, dtype):\n    params = jax.eval_shape(model.init, rngkey, batch[\"input_ids\"],\n                            batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                            batch[\"position_ids\"])\n\n    def weight_decay_mask(pytree):\n        # do not use weight decay on layer norm and bias.\n        return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n    tx = optax.chain(\n        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this\n        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))\n    use_master_copy = (dtype == jnp.float16)\n    state = TrainState.create_aval(apply_fn=model.apply,\n                                   params=params,\n                                   tx=tx,\n                                   use_master_copy=use_master_copy,\n                                   dynamic_scale=None)\n    return 
state\n\n\ndef get_train_step(parallel_method, grad_func=None):\n\n    if grad_func is None:\n        grad_func = alpa.grad\n\n    @parallelize(method=parallel_method)\n    def train_step(state, batch, rng_key):\n\n        def loss_func(params):\n            rngs = {\"dropout\": rng_key}\n            logits = state.apply_fn(params,\n                                    batch[\"input_ids\"],\n                                    batch[\"attention_mask\"],\n                                    batch[\"token_type_ids\"],\n                                    batch[\"position_ids\"],\n                                    deterministic=True,\n                                    rngs=rngs)[0]\n            label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n            labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n            loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                            axis=-1)\n            loss = (label_mask * loss).sum() / label_mask.sum()\n            return loss\n\n        grads = grad_func(loss_func)(state.params)\n        new_state = state.apply_gradients(grads=grads)\n        # TODO(lmzheng): add dynamic scaling for mixed-precision training\n        return new_state\n\n    return train_step\n\n\ndef prepare_gpt_bert_input_and_model(model_type,\n                                     benchmark_case,\n                                     add_manual_remat=None,\n                                     add_manual_layer_marker=None,\n                                     num_manual_pipeline_stages=None,\n                                     aval_train_state=True,\n                                     tie_word_embeddings=False):\n    print_used_time(None)\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads,\n     vocab_size) = benchmark_case.model_config\n    dtype = jnp.float16\n    # Prepare input batch\n    batch = {\n        \"input_ids\": jnp.ones((batch_size, 
seq_len), dtype=jnp.int32),\n        \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"token_type_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"position_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"labels\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n    }\n    print_used_time(\"Prepare input\")\n\n    bert_config = BertConfig(\n        vocab_size=vocab_size,\n        hidden_size=hidden_size,\n        num_attention_heads=num_heads,\n        intermediate_size=hidden_size * 4,\n        num_hidden_layers=num_layers,\n        type_vocab_size=0,\n        tie_word_embeddings=tie_word_embeddings,\n        gradient_checkpointing=add_manual_remat,\n        add_manual_pipeline_markers=add_manual_layer_marker,\n        pipeline_mp_size=num_manual_pipeline_stages,\n    )\n\n    # Init train state\n    if model_type == \"bert\":\n        model = FlaxBertForMaskedLMModule(bert_config, dtype=dtype)\n    elif model_type == \"gpt\":\n        model = FlaxGPTForLMModule(bert_config, dtype=dtype)\n    else:\n        raise ValueError(f\"Invalid model {model_type}\")\n\n    rngkey = jax.random.PRNGKey(0)\n    if aval_train_state:\n        state = create_train_state_aval(rngkey, model, batch, dtype)\n    else:\n        state = create_train_state(rngkey, model, batch, dtype)\n    print_used_time(\"Create train state\")\n    return state, batch, rngkey\n\n\ndef compute_gpt_bert_statistics(benchmark_case, latencies, num_devices):\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads,\n     vocab_size) = benchmark_case.model_config\n    use_remat = benchmark_case.parallel_args.use_remat\n\n    tflops = compute_gpt_tflops(batch_size,\n                                seq_len,\n                                num_layers,\n                                hidden_size,\n                                vocab_size,\n                                num_devices,\n                 
               np.mean(latencies),\n                                checkpoint_activations=use_remat)\n    parameter_count = compute_gpt_parameter_count(num_layers, hidden_size,\n                                                  vocab_size)\n    return tflops, parameter_count\n\n\ndef benchmark_gpt_bert_3d_internal(model_type,\n                                   benchmark_case,\n                                   niter,\n                                   num_hosts,\n                                   num_devices_per_host,\n                                   aval_train_state=True,\n                                   profile_driver_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    set_global_virtual_physical_mesh(virtual_mesh)\n\n    # Parallel configs\n    pipeline_schedule = (\"1f1b_overlap_friendly\"\n                         if global_config.enable_overlapping else \"1f1b\")\n    (method, add_manual_remat, add_manual_layer_marker,\n     num_manual_pipeline_stages) = get_pipeshard_parallel_method(\n         benchmark_case,\n         virtual_mesh.num_devices_per_host,\n         use_fine_grained_remat=True,\n         pipeline_schedule=pipeline_schedule)\n\n    state, batch, rngkey = prepare_gpt_bert_input_and_model(\n        model_type,\n        benchmark_case,\n        add_manual_remat=add_manual_remat,\n        add_manual_layer_marker=add_manual_layer_marker,\n        num_manual_pipeline_stages=num_manual_pipeline_stages,\n        aval_train_state=aval_train_state)\n\n    train_step = get_train_step(method)\n\n    (latencies, max_mem_allocated, compilation_times,\n     executable) = compile_and_benchmark_pipeshard_training_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         train_step,\n         state, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    
tflops, parameter_count = compute_gpt_bert_statistics(\n        benchmark_case, latencies, virtual_mesh.num_devices)\n\n    # report_pipeline_breakdown(executable,\n    #                           [\"resharding_send\", \"resharding_recv\",\n    #                            \"compute\"],\n    #                           niter)\n\n    (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,\n     logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result()\n    metadata = {\n        \"compilation_times\": compilation_times,\n        \"compute_cost_file_name\": compute_cost_file_name,\n        \"forward_stage_layer_ids\": forward_stage_layer_ids,\n        \"submesh_shapes\": submesh_shapes,\n        \"logical_mesh_shapes\": logical_mesh_shapes,\n        \"autosharding_option_dicts\": autosharding_option_dicts,\n    }\n\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n\n\ndef benchmark_gpt_bert_2d_internal(physical_mesh,\n                                   model_type,\n                                   benchmark_case,\n                                   niter,\n                                   profile_driver_time=False):\n    method, grad_func = get_shard_parallel_method(benchmark_case, physical_mesh)\n\n    state, batch, rngkey = prepare_gpt_bert_input_and_model(\n        model_type,\n        benchmark_case,\n        add_manual_remat=benchmark_case.parallel_args.use_remat,\n        aval_train_state=global_config.use_dummy_value_for_benchmarking)\n\n    train_step = get_train_step(method, grad_func=grad_func)\n\n    (latencies, ilp_objective, peak_mem,\n     executable) = compile_and_benchmark_shard_training_executable(\n         physical_mesh,\n         niter,\n         train_step,\n         state, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    tflops, parameter_count = compute_gpt_bert_statistics(\n        benchmark_case, latencies, physical_mesh.num_devices)\n    metadata = {\n        
\"ilp_objective\": ilp_objective,\n    }\n    return parameter_count, peak_mem, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_gpt_bert_inference.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import (parallelize, get_global_cluster,\n                  set_global_virtual_physical_mesh)\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayerCollection\nfrom alpa.model.gpt_model import FlaxGPTForLMModule\nfrom alpa.util import print_used_time, GB, write_tsv\n\nfrom util import compute_gpt_parameter_count, compute_gpt_tflops\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method,\n    compile_and_benchmark_pipeshard_inference_executable,\n    compute_avg_stage_latencies)\n\n\ndef create_infer_params_aval(rngkey, model, batch, model_type):\n    if model_type == \"gpt_no_embedding_inference\":\n        params = jax.eval_shape(model.init, rngkey, batch[\"x\"],\n                                batch[\"attention_mask\"])\n    elif model_type == \"gpt_inference\":\n        params = jax.eval_shape(model.init, rngkey, batch[\"input_ids\"],\n                                batch[\"attention_mask\"],\n                                batch[\"token_type_ids\"], batch[\"position_ids\"])\n    else:\n        raise ValueError(f\"Invalid model type: {model_type}\")\n    params = jax.eval_shape(\n        lambda p: jax.tree_util.tree_map(\n            lambda x: jnp.asarray(x, dtype=jnp.float16), p), params)\n    return params\n\n\ndef get_infer_step(parallel_method, model, model_type):\n\n    def infer_step_with_embedding(params, batch, rng_key):\n        rngs = {\"dropout\": rng_key}\n        logits = model.apply(params,\n                             batch[\"input_ids\"],\n                             batch[\"attention_mask\"],\n                             batch[\"token_type_ids\"],\n                             batch[\"position_ids\"],\n                             deterministic=True,\n                             rngs=rngs)[0]\n        label_mask = jnp.where(batch[\"labels\"] > 0, 
1.0, 0.0)\n        labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n        loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)\n        loss = (label_mask * loss).sum() / label_mask.sum()\n        return loss\n\n    def infer_step_without_embedding(params, batch, rng_key):\n        out = model.apply(params,\n                          batch[\"x\"],\n                          batch[\"attention_mask\"],\n                          output_attentions=True,\n                          output_hidden_states=True)\n        loss = jnp.mean((out.last_hidden_state - batch[\"y\"])**2)\n        return loss\n\n    if model_type == \"gpt_no_embedding_inference\":\n        infer_step = infer_step_without_embedding\n    elif model_type == \"gpt_inference\":\n        infer_step = infer_step_with_embedding\n    else:\n        raise ValueError(f\"Invalid model type: {model_type}\")\n    return parallelize(infer_step, method=parallel_method, donate_argnums=())\n\n\ndef prepare_gpt_inference_input_and_model(model_type,\n                                          benchmark_case,\n                                          add_manual_layer_marker=None,\n                                          num_manual_pipeline_stages=None,\n                                          tie_word_embeddings=False):\n    print_used_time(None)\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads,\n     vocab_size) = benchmark_case.model_config\n    dtype = jnp.float16\n\n    bert_config = BertConfig(\n        vocab_size=vocab_size,\n        hidden_size=hidden_size,\n        num_attention_heads=num_heads,\n        intermediate_size=hidden_size * 4,\n        num_hidden_layers=num_layers,\n        type_vocab_size=0,\n        tie_word_embeddings=tie_word_embeddings,\n        add_manual_pipeline_markers=add_manual_layer_marker,\n        pipeline_mp_size=num_manual_pipeline_stages,\n    )\n\n    # Init train state\n    if model_type == 
\"gpt_no_embedding_inference\":\n        batch = {\n            \"x\": jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype),\n            \"y\": jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype),\n            \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        }\n        model = FlaxBertLayerCollection(bert_config, dtype=dtype)\n    elif model_type == \"gpt_inference\":\n        batch = {\n            \"input_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n            \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n            \"token_type_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n            \"position_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n            \"labels\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        }\n\n        model = FlaxGPTForLMModule(bert_config, dtype=dtype)\n    else:\n        raise ValueError(f\"Invalid model {model_type}\")\n\n    rngkey = jax.random.PRNGKey(0)\n    params = create_infer_params_aval(rngkey, model, batch, model_type)\n    print_used_time(\"Create infer state\")\n    return model, params, batch, rngkey\n\n\ndef compute_gpt_inference_statistics(benchmark_case, latencies, num_devices):\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads,\n     vocab_size) = benchmark_case.model_config\n    use_remat = benchmark_case.parallel_args.use_remat\n\n    tflops = compute_gpt_tflops(batch_size,\n                                seq_len,\n                                num_layers,\n                                hidden_size,\n                                vocab_size,\n                                num_devices,\n                                np.mean(latencies),\n                                backward=False)\n    parameter_count = compute_gpt_parameter_count(num_layers, hidden_size,\n                                                  vocab_size)\n    return tflops, 
parameter_count\n\n\ndef benchmark_gpt_inference_internal(model_type,\n                                     benchmark_case,\n                                     niter,\n                                     num_hosts,\n                                     num_devices_per_host,\n                                     profile_driver_time=False,\n                                     profile_stage_execution_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    set_global_virtual_physical_mesh(virtual_mesh)\n\n    (method, _, add_manual_layer_marker,\n     num_manual_pipeline_stages) = get_pipeshard_parallel_method(\n         benchmark_case,\n         virtual_mesh.num_devices_per_host,\n         pipeline_schedule=\"inference\")\n\n    model, params, batch, rngkey = prepare_gpt_inference_input_and_model(\n        model_type, benchmark_case, add_manual_layer_marker,\n        num_manual_pipeline_stages)\n\n    infer_step = get_infer_step(method, model, model_type)\n\n    (latencies, max_mem_allocated, compilation_times, executable,\n     per_stage_weight_mem,\n     per_stage_peak_mem) = compile_and_benchmark_pipeshard_inference_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         infer_step,\n         params, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    # Compute statistics\n    tflops, parameter_count = compute_gpt_inference_statistics(\n        benchmark_case, latencies, virtual_mesh.num_devices_per_host)\n\n    # Log per-stage execution information if needed\n    if profile_stage_execution_time:\n        model_name = f\"bert-{parameter_count/1e9:.1f}b\"\n        # dump chrome trace\n        executable.dump_stage_execution_trace(\n            
f\"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json\"\n        )\n        # compute and log per-stage latency/memory statistics\n        exec_info = executable.get_stage_execution_info()\n        timelines = list(zip(*exec_info))\n        # drop warmup case\n        timelines = timelines[3:]\n        avg_stage_latencies = compute_avg_stage_latencies(timelines)\n        assert len(avg_stage_latencies) == num_manual_pipeline_stages\n        parallel_args = benchmark_case.parallel_args\n        dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp\n        heads = [\n            \"ModelName\", \"BS\", \"#Microbatch\", \"DP\", \"OP\", \"PP\", \"#GPU\",\n            \"MeanTime(s)\", \"StdTime(s)\", \"TFLOPs\", \"StageWeights(B)\",\n            \"StagePeakMem(B)\", \"StageLatencies(s)\"\n        ]\n        values = [\n            model_name, benchmark_case.batch_size,\n            benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,\n            f\"{np.mean(latencies):.3f}\", f\"{np.std(latencies):.3f}\",\n            f\"{tflops:.2f}\", f\"{per_stage_weight_mem}\", f\"{per_stage_peak_mem}\",\n            list(avg_stage_latencies)\n        ]\n        write_tsv(heads, values, f\"inference_prof_res.tsv\")\n\n    metadata = {\n        \"compilation_times\": compilation_times,\n    }\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_moe.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import get_global_cluster, set_global_virtual_physical_mesh\nfrom alpa.model.moe import FlaxMoEForLMModule, MoEConfig, TrainState\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.util import print_used_time\nimport optax\n\nfrom benchmark_one_case_gpt_bert import get_train_step\nfrom util import compute_moe_parameter_count, compute_moe_tflops\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method, get_shard_parallel_method,\n    compile_and_benchmark_pipeshard_training_executable,\n    compile_and_benchmark_shard_training_executable)\n\n\ndef create_train_state(rngkey, model, dtype, batch):\n    params = model.init_dummy(rngkey, batch[\"input_ids\"],\n                              batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                              batch[\"position_ids\"])\n\n    def weight_decay_mask(pytree):\n        # do not use weight decay on layer norm and bias.\n        return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n    tx = optax.adafactor(learning_rate=1e-2,\n                         weight_decay_mask=weight_decay_mask)\n\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              use_master_copy=(dtype == jnp.float16),\n                              dynamic_scale=None)\n    return state\n\n\ndef prepare_moe_input_and_model(benchmark_case,\n                                add_manual_remat=None,\n                                add_manual_layer_marker=None,\n                                num_manual_pipeline_stages=None,\n                                correct_expert_group_size=True):\n    print_used_time(None)\n    (batch_size, model_config, num_micro_batches, parallel_mode,\n     parallel_args) = benchmark_case\n    
(seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts,\n     expert_group_size) = model_config\n    dtype = jnp.float16\n    tie_word_embeddings = False\n\n    if correct_expert_group_size:\n        rang_factor = 1\n        expected_expert_group_size = min(\n            expert_group_size,\n            batch_size * seq_len // num_micro_batches // 1 // rang_factor)\n        if expected_expert_group_size != expert_group_size:\n            print(\"- Expected expert group size should be {}, \"\n                  \"but got {}. Will reset it\".format(expected_expert_group_size,\n                                                     expert_group_size))\n            expert_group_size = expected_expert_group_size\n\n    # Prepare input batch\n    batch = {\n        \"input_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"token_type_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"position_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"labels\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n    }\n    print_used_time(\"Prepare input\")\n\n    # Init train state\n    model = FlaxMoEForLMModule(\n        MoEConfig(\n            num_hidden_layers=num_layers,\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 8,  # this is specific to gspmd.\n            num_attention_heads=num_heads,\n            max_position_embeddings=seq_len,\n            vocab_size=vocab_size,\n            expert_group_size=expert_group_size,\n            expert_number=num_experts,\n            tie_word_embeddings=tie_word_embeddings,\n            gradient_checkpointing=add_manual_remat,\n            add_manual_pipeline_markers=add_manual_layer_marker,\n            pipeline_mp_size=num_manual_pipeline_stages,\n        ),\n        dtype=dtype)\n\n    rngkey = jax.random.PRNGKey(0)\n    state = create_train_state(rngkey, 
model, dtype, batch)\n    print_used_time(\"Create train state\")\n    return state, batch, rngkey\n\n\ndef compute_moe_statistics(benchmark_case, latencies, num_devices):\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts,\n     expert_group_size) = benchmark_case.model_config\n    use_remat = benchmark_case.parallel_args.use_remat\n\n    tflops = compute_moe_tflops(batch_size,\n                                seq_len,\n                                num_layers,\n                                hidden_size,\n                                expert_group_size,\n                                vocab_size,\n                                num_experts,\n                                num_devices,\n                                np.mean(latencies),\n                                checkpoint_activations=use_remat)\n    parameter_count = compute_moe_parameter_count(num_layers,\n                                                  hidden_size,\n                                                  vocab_size,\n                                                  num_experts,\n                                                  mlp_factor=8)\n    return tflops, parameter_count\n\n\ndef benchmark_moe_3d_internal(benchmark_case,\n                              niter,\n                              num_hosts,\n                              num_devices_per_host,\n                              profile_driver_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    set_global_virtual_physical_mesh(virtual_mesh)\n\n    # Parallel configs\n    (method, add_manual_remat, add_manual_layer_marker,\n     num_manual_pipeline_stages) = get_pipeshard_parallel_method(\n         benchmark_case,\n         virtual_mesh.num_devices_per_host,\n         use_fine_grained_remat=True,\n     
    allow_mixed_mesh_shape=True)\n\n    state, batch, rngkey = prepare_moe_input_and_model(\n        benchmark_case,\n        add_manual_remat=add_manual_remat,\n        add_manual_layer_marker=add_manual_layer_marker,\n        num_manual_pipeline_stages=num_manual_pipeline_stages)\n\n    train_step = get_train_step(method)\n\n    (latencies, max_mem_allocated, compilation_times,\n     executable) = compile_and_benchmark_pipeshard_training_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         train_step,\n         state, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies,\n                                                     virtual_mesh.num_devices)\n\n    (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,\n     logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result()\n    metadata = {\n        \"compilation_times\": compilation_times,\n        \"compute_cost_file_name\": compute_cost_file_name,\n        \"forward_stage_layer_ids\": forward_stage_layer_ids,\n        \"submesh_shapes\": submesh_shapes,\n        \"logical_mesh_shapes\": logical_mesh_shapes,\n        \"autosharding_option_dicts\": autosharding_option_dicts,\n    }\n\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n\n\ndef benchmark_moe_2d_internal(physical_mesh,\n                              benchmark_case,\n                              niter,\n                              profile_driver_time=False):\n    # Model configs\n    method, grad_func = get_shard_parallel_method(benchmark_case, physical_mesh)\n\n    state, batch, rngkey = prepare_moe_input_and_model(\n        benchmark_case,\n        add_manual_remat=benchmark_case.parallel_args.use_remat,\n        correct_expert_group_size=False)\n\n    # Compile executable\n    train_step = get_train_step(method, grad_func=grad_func)\n\n    (latencies, ilp_objective, 
peak_mem,\n     executable) = compile_and_benchmark_shard_training_executable(\n         physical_mesh,\n         niter,\n         train_step,\n         state, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    # Compute statistics\n    tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies,\n                                                     physical_mesh.num_devices)\n    metadata = {\n        \"ilp_objective\": ilp_objective,\n    }\n    return parameter_count, peak_mem, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_moe_inference.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import parallelize, get_global_cluster, set_global_virtual_physical_mesh\nfrom alpa.model.moe import FlaxMoEForLMModule, MoEConfig, TrainState\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.util import print_used_time, GB, write_tsv\n\nfrom benchmark_one_case_gpt_bert import get_train_step\nfrom util import compute_moe_parameter_count, compute_moe_tflops\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method, get_shard_parallel_method,\n    compile_and_benchmark_pipeshard_inference_executable,\n    compute_avg_stage_latencies)\n\n\ndef create_infer_params_aval(rngkey, model, batch):\n    params = jax.eval_shape(model.init, rngkey, batch[\"input_ids\"],\n                            batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                            batch[\"position_ids\"])\n    params = jax.eval_shape(\n        lambda p: jax.tree_util.tree_map(\n            lambda x: jnp.asarray(x, dtype=jnp.float16), p), params)\n    return params\n\n\ndef get_infer_step(parallel_method, model):\n\n    def infer_step(params, batch, rng_key):\n        rngs = {\"dropout\": rng_key}\n        logits = model.apply(params,\n                             batch[\"input_ids\"],\n                             batch[\"attention_mask\"],\n                             batch[\"token_type_ids\"],\n                             batch[\"position_ids\"],\n                             deterministic=True,\n                             rngs=rngs)[0]\n        label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n        labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n        loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)\n        loss = (label_mask * loss).sum() / label_mask.sum()\n        return loss\n\n    return parallelize(infer_step, 
method=parallel_method, donate_argnums=())\n\n\ndef prepare_moe_inference_input_and_model(benchmark_case,\n                                          add_manual_remat=None,\n                                          add_manual_layer_marker=None,\n                                          num_manual_pipeline_stages=None,\n                                          correct_expert_group_size=True):\n    print_used_time(None)\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts,\n     expert_group_size) = benchmark_case.model_config\n    dtype = jnp.float16\n    tie_word_embeddings = False\n\n    if correct_expert_group_size:\n        rang_factor = 1\n        expected_expert_group_size = min(\n            expert_group_size, batch_size * seq_len //\n            benchmark_case.num_micro_batches // 1 // rang_factor)\n        if expected_expert_group_size != expert_group_size:\n            print(\"- Expected expert group size should be {}, \"\n                  \"but got {}. 
Will reset it\".format(expected_expert_group_size,\n                                                     expert_group_size))\n            expert_group_size = expected_expert_group_size\n\n    # Prepare input batch\n    batch = {\n        \"input_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"token_type_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"position_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"labels\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n    }\n    print_used_time(\"Prepare input\")\n\n    # Init train state\n    model = FlaxMoEForLMModule(\n        MoEConfig(\n            num_hidden_layers=num_layers,\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 8,  # this is specific to gspmd.\n            num_attention_heads=num_heads,\n            max_position_embeddings=seq_len,\n            vocab_size=vocab_size,\n            expert_group_size=expert_group_size,\n            expert_number=num_experts,\n            tie_word_embeddings=tie_word_embeddings,\n            gradient_checkpointing=add_manual_remat,\n            add_manual_pipeline_markers=add_manual_layer_marker,\n            pipeline_mp_size=num_manual_pipeline_stages,\n        ),\n        dtype=dtype)\n\n    rngkey = jax.random.PRNGKey(0)\n    params = create_infer_params_aval(rngkey, model, batch)\n    print_used_time(\"Create train state\")\n    return model, params, batch, rngkey\n\n\ndef compute_moe_statistics(benchmark_case, latencies, num_devices):\n    batch_size = benchmark_case.batch_size\n    (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_experts,\n     expert_group_size) = benchmark_case.model_config\n    use_remat = benchmark_case.parallel_args.use_remat\n\n    tflops = compute_moe_tflops(batch_size,\n                                seq_len,\n                                
num_layers,\n                                hidden_size,\n                                expert_group_size,\n                                vocab_size,\n                                num_experts,\n                                num_devices,\n                                np.mean(latencies),\n                                checkpoint_activations=use_remat)\n    parameter_count = compute_moe_parameter_count(num_layers,\n                                                  hidden_size,\n                                                  vocab_size,\n                                                  num_experts,\n                                                  mlp_factor=8)\n    return tflops, parameter_count\n\n\ndef benchmark_moe_inference_internal(benchmark_case,\n                                     niter,\n                                     num_hosts,\n                                     num_devices_per_host,\n                                     profile_driver_time=False,\n                                     profile_stage_execution_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    set_global_virtual_physical_mesh(virtual_mesh)\n\n    # Parallel configs\n    (method, _, add_manual_layer_marker,\n     num_manual_pipeline_stages) = get_pipeshard_parallel_method(\n         benchmark_case,\n         virtual_mesh.num_devices_per_host,\n         pipeline_schedule=\"inference\")\n\n    model, params, batch, rngkey = prepare_moe_inference_input_and_model(\n        benchmark_case,\n        add_manual_layer_marker=add_manual_layer_marker,\n        num_manual_pipeline_stages=num_manual_pipeline_stages)\n\n    infer_step = get_infer_step(method, model)\n\n    (latencies, max_mem_allocated, compilation_times, executable,\n     per_stage_weight_mem,\n     per_stage_peak_mem) = 
compile_and_benchmark_pipeshard_inference_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         infer_step,\n         params, (batch, rngkey),\n         profile_driver_time=profile_driver_time)\n\n    # compute statistics\n    tflops, parameter_count = compute_moe_statistics(benchmark_case, latencies,\n                                                     virtual_mesh.num_devices)\n\n    # Log per-stage execution information if needed\n    if profile_stage_execution_time:\n        model_name = f\"moe-{parameter_count/1e9:.1f}b\"\n        # dump chrome trace\n        executable.dump_stage_execution_trace(\n            f\"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json\"\n        )\n        # compute and log per-stage latency/memory statistics\n        exec_info = executable.get_stage_execution_info()\n        timelines = list(zip(*exec_info))\n        # drop warmup case\n        timelines = timelines[1:]\n        avg_stage_latencies = compute_avg_stage_latencies(timelines)\n        assert len(avg_stage_latencies) == num_manual_pipeline_stages\n        parallel_args = benchmark_case.parallel_args\n        dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp\n        heads = [\n            \"ModelName\", \"BS\", \"#Microbatch\", \"DP\", \"OP\", \"PP\", \"#GPU\",\n            \"MeanTime(s)\", \"StdTime(s)\", \"TFLOPs\", \"StageWeights(B)\",\n            \"StagePeakMem(B)\", \"StageLatencies(s)\"\n        ]\n        values = [\n            model_name, benchmark_case.batch_size,\n            benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,\n            f\"{np.mean(latencies):.3f}\", f\"{np.std(latencies):.3f}\",\n            f\"{tflops:.2f}\", f\"{per_stage_weight_mem}\", f\"{per_stage_peak_mem}\",\n            avg_stage_latencies\n        ]\n        write_tsv(heads, values, f\"benchmark_results.tsv\")\n\n    metadata = {\n        
\"compilation_times\": compilation_times,\n    }\n\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_unet.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nfrom alpa.pipeline_parallel.layer_construction import ManualLayerOption\nfrom flax.training import common_utils\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nimport alpa\nfrom alpa import (parallelize, get_global_cluster,\n                  set_global_virtual_physical_mesh, ShardParallel,\n                  automatic_remat, global_config)\nfrom alpa.model.unet_2d import get_unet_2d\nfrom alpa.model.model_util import TrainState\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.util import print_used_time, compute_param_number\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method,\n    compile_and_benchmark_pipeshard_training_executable)\n\n\ndef create_learning_rate_fn():\n    \"\"\"Create learning rate schedule.\"\"\"\n    base_learning_rate = 0.1\n    warmup_epochs = 5.0\n    steps_per_epoch = 10000\n    num_epochs = 100.0\n\n    warmup_fn = optax.linear_schedule(init_value=0.,\n                                      end_value=base_learning_rate,\n                                      transition_steps=warmup_epochs *\n                                      steps_per_epoch)\n    cosine_epochs = max(num_epochs - warmup_epochs, 1)\n    cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate,\n                                            decay_steps=cosine_epochs *\n                                            steps_per_epoch)\n    schedule_fn = optax.join_schedules(\n        schedules=[warmup_fn, cosine_fn],\n        boundaries=[warmup_epochs * steps_per_epoch])\n    return schedule_fn\n\n\ndef create_train_state(rngkey, model, batch, learning_rate_fn):\n    params = model.init_dummy(rngkey, *batch)\n\n    # dynamic_scale = optim.DynamicScale()\n    dynamic_scale = None\n\n    tx = optax.sgd(\n        learning_rate=learning_rate_fn,\n        momentum=0.9,\n        nesterov=True,\n    )\n    
state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              dynamic_scale=None)\n    return state\n\n\ndef get_train_step(learning_rate_fn,\n                   use_remat,\n                   num_remat_layers,\n                   method,\n                   grad_func=None):\n\n    if grad_func is None:\n        grad_func = alpa.grad\n\n    @parallelize(method=method)\n    def train_step(state, batch):\n\n        def loss_fn(params):\n            outs = state.apply_fn(params, batch[\"images\"], batch[\"timesteps\"],\n                                  batch[\"encoder_hidden_states\"])\n            sample = outs.sample\n            loss = jnp.mean(\n                optax.l2_loss(predictions=sample, targets=batch[\"targets\"]))\n\n            metrics = {\"loss\": loss, \"lr\": learning_rate_fn(step)}\n            return loss, metrics\n\n        if isinstance(method, ShardParallel) and use_remat:\n            loss_fn = automatic_remat(loss_fn, layer_num=num_remat_layers)\n\n        step = state.step\n\n        grad_fn = grad_func(loss_fn, has_aux=True)\n        grads, aux = grad_fn(state.params)\n        metrics = aux\n\n        new_state = state.apply_gradients(grads=grads)\n\n        return new_state, metrics\n\n    return train_step\n\n\ndef prepare_unet_input_and_model(benchmark_case):\n    print_used_time(None)\n    # Model configs\n    (batch_size, model_config, _, _, _) = benchmark_case\n    (image_size, channel_size, block_cnt, dtype, _) = model_config\n    in_channels = 3\n    out_channels = 4\n\n    # Prepare input batch\n    encoder_factor = 2**(block_cnt - 1)\n    # Unlike wide-resnet, we have a transpose of input image in unet 2d model.\n    batch = {\n        \"images\":\n            jnp.ones((batch_size, in_channels, image_size, image_size),\n                     dtype=dtype),\n        \"targets\":\n            jnp.ones((batch_size, out_channels, 
image_size, image_size),\n                     dtype=dtype),\n        \"timesteps\":\n            1,\n        \"encoder_hidden_states\":\n            jnp.ones((batch_size, (image_size // encoder_factor)**2,\n                      channel_size * encoder_factor // 2))\n    }\n    print_used_time(\"Prepare input\")\n\n    # Init train state\n\n    down_block_types = ((\"CrossAttnDownBlock2D\",) * (block_cnt - 1) +\n                        (\"DownBlock2D\",))\n    up_block_types = (\"UpBlock2D\",) + (\"CrossAttnUpBlock2D\",) * (block_cnt - 1)\n    # Each downsampling, the num channels grows twice\n    block_out_channels = [channel_size * (2**i) for i in range(block_cnt - 1)]\n    block_out_channels.append(block_out_channels[-1])\n    model = get_unet_2d(image_size,\n                        down_block_types=down_block_types,\n                        up_block_types=up_block_types,\n                        block_out_channels=block_out_channels,\n                        in_channels=in_channels,\n                        out_channels=out_channels,\n                        layers_per_block=1,\n                        dtype=dtype)\n\n    rngkey = jax.random.PRNGKey(0)\n    learning_rate_fn = create_learning_rate_fn()\n    input_batch = (batch[\"images\"], batch[\"timesteps\"],\n                   batch[\"encoder_hidden_states\"])\n    state = create_train_state(rngkey, model, input_batch, learning_rate_fn)\n    print_used_time(\"Create train state\")\n    return state, batch, learning_rate_fn\n\n\ndef benchmark_unet_3d_internal(benchmark_case,\n                               niter,\n                               num_hosts,\n                               num_devices_per_host,\n                               profile_driver_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    
set_global_virtual_physical_mesh(virtual_mesh)\n\n    # Parallel configs\n    allow_mixed_mesh_shape = True\n    pipeline_schedule = (\"1f1b_overlap_friendly\"\n                         if global_config.enable_overlapping else \"1f1b\")\n    (method, _, _, _) = get_pipeshard_parallel_method(\n        benchmark_case,\n        virtual_mesh.num_devices_per_host,\n        allow_mixed_mesh_shape=allow_mixed_mesh_shape,\n        pipeline_schedule=pipeline_schedule)\n    method: alpa.parallel_method.PipeshardParallel\n    # The operator clustering for unet is not sufficient\n    method.layer_option = ManualLayerOption(remat_layer=True)\n\n    use_grad_acc = benchmark_case.num_micro_batches > 1\n    grad_func = alpa.grad if use_grad_acc else jax.grad\n    state, batch, learning_rate_fn = prepare_unet_input_and_model(\n        benchmark_case)\n    train_step = get_train_step(learning_rate_fn,\n                                False,\n                                None,\n                                method,\n                                grad_func=grad_func)\n\n    (latencies, max_mem_allocated, compilation_times,\n     executable) = compile_and_benchmark_pipeshard_training_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         train_step,\n         state, (batch,),\n         profile_driver_time=profile_driver_time)\n\n    # Profile submesh executables\n    # del state\n    # del metrics\n    # for i, profiled in enumerate(executable.profile_all_executables()):\n    #     pstr = f\"Mesh {i}: \"\n    #     for k in profiled:\n    #         pstr += f\"Exec {k}: {profiled[k][0]}s; \"\n    #     print(pstr)\n    executable.dump_debug_info(\"tmp\")\n\n    # Compute statistics\n    num_gpus = virtual_mesh.num_devices\n    tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12\n    parameter_count = compute_param_number(state.params)\n\n    (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,\n     logical_mesh_shapes, 
autosharding_option_dicts) = get_last_dp_result()\n    metadata = {\n        \"compilation_times\": compilation_times,\n        \"compute_cost_file_name\": compute_cost_file_name,\n        \"forward_stage_layer_ids\": forward_stage_layer_ids,\n        \"submesh_shapes\": submesh_shapes,\n        \"logical_mesh_shapes\": logical_mesh_shapes,\n        \"autosharding_option_dicts\": autosharding_option_dicts,\n    }\n\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_one_case_wresnet.py",
    "content": "\"\"\"Benchmark one case of inter-op + intra-op parallelism.\"\"\"\nfrom functools import partial\n\nfrom flax.training import common_utils\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nimport alpa\nfrom alpa import (parallelize, get_global_cluster,\n                  set_global_virtual_physical_mesh, ShardParallel,\n                  automatic_remat)\nfrom alpa.model.wide_resnet import get_wide_resnet, TrainState\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.util import print_used_time, compute_param_number\nfrom benchmark_parallel_utils import (\n    get_pipeshard_parallel_method, get_shard_parallel_method,\n    compile_and_benchmark_pipeshard_training_executable,\n    compile_and_benchmark_shard_training_executable)\n\n\ndef compute_metrics(logits, labels):\n    metrics = {\n        \"loss\": cross_entropy_loss(logits, labels),\n        \"accuracy\": jnp.mean(jnp.argmax(logits, -1) == labels),\n    }\n    return metrics\n\n\ndef cross_entropy_loss(logits, labels):\n    num_classes = logits.shape[-1]\n    one_hot_labels = common_utils.onehot(labels, num_classes=num_classes)\n    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)\n    return jnp.mean(xentropy)\n\n\ndef create_learning_rate_fn():\n    \"\"\"Create learning rate schedule.\"\"\"\n    base_learning_rate = 0.1\n    warmup_epochs = 5.0\n    steps_per_epoch = 10000\n    num_epochs = 100.0\n\n    warmup_fn = optax.linear_schedule(init_value=0.,\n                                      end_value=base_learning_rate,\n                                      transition_steps=warmup_epochs *\n                                      steps_per_epoch)\n    cosine_epochs = max(num_epochs - warmup_epochs, 1)\n    cosine_fn = optax.cosine_decay_schedule(init_value=base_learning_rate,\n                                            decay_steps=cosine_epochs *\n                                            
steps_per_epoch)\n    schedule_fn = optax.join_schedules(\n        schedules=[warmup_fn, cosine_fn],\n        boundaries=[warmup_epochs * steps_per_epoch])\n    return schedule_fn\n\n\ndef create_train_state(rngkey, model, input_images, learning_rate_fn):\n    params = model.init_dummy(rngkey, input_images)\n    params, batch_stats = params[\"params\"], params[\"batch_stats\"]\n\n    # dynamic_scale = optim.DynamicScale()\n    dynamic_scale = None\n\n    tx = optax.sgd(\n        learning_rate=learning_rate_fn,\n        momentum=0.9,\n        nesterov=True,\n    )\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              batch_stats=batch_stats,\n                              dynamic_scale=None)\n    return state\n\n\ndef get_train_step(learning_rate_fn,\n                   use_remat,\n                   num_remat_layers,\n                   method,\n                   grad_func=None):\n\n    if grad_func is None:\n        grad_func = alpa.grad\n\n    @parallelize(method=method)\n    def train_step(state, batch):\n\n        def loss_fn(params):\n            logits, new_model_state = state.apply_fn(\n                {\n                    \"params\": params,\n                    \"batch_stats\": state.batch_stats\n                },\n                batch[\"images\"],\n                mutable=[\"batch_stats\"])\n            loss = cross_entropy_loss(logits, batch[\"labels\"])\n            # weight_penalty_params = jax.tree_leaves(params)\n            # weight_decay = 0.0001\n            # weight_l2 = sum(\n            #     [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])\n            # weight_penalty = weight_decay * 0.5 * weight_l2\n            metrics = {\n                \"loss\": loss,\n                \"accuracy\": jnp.mean(jnp.argmax(logits, -1) == batch[\"labels\"]),\n                \"lr\": learning_rate_fn(step)\n            }\n         
   return loss, (new_model_state, metrics)\n\n        if isinstance(method, ShardParallel) and use_remat:\n            loss_fn = automatic_remat(loss_fn, layer_num=num_remat_layers)\n\n        step = state.step\n        dynamic_scale = state.dynamic_scale\n\n        if dynamic_scale:\n            # TODO(lmzheng): handle gradient accumulation for this\n            grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)\n            dynamic_scale, is_fin, aux, grads = grad_fn(state.params)\n            # dynamic loss takes care of averaging gradients across replicas\n        else:\n            grad_fn = grad_func(loss_fn, has_aux=True)\n            grads, aux = grad_fn(state.params)\n        new_model_state, metrics = aux\n\n        new_state = state.apply_gradients(\n            grads=grads, batch_stats=new_model_state[\"batch_stats\"])\n        if dynamic_scale:\n            # if is_fin == False the gradients contain Inf/NaNs and optimizer\n            # state and params should be restored (= skip this step).\n            new_state = new_state.replace(\n                opt_state=jax.tree_multimap(partial(jnp.where, is_fin),\n                                            new_state.opt_state,\n                                            state.opt_state),\n                params=jax.tree_multimap(partial(jnp.where, is_fin),\n                                         new_state.params, state.params))\n            metrics[\"scale\"] = dynamic_scale.scale\n\n        return new_state, metrics\n\n    return train_step\n\n\ndef prepare_wresnet_input_and_model(benchmark_case):\n    print_used_time(None)\n    # Model configs\n    (batch_size, model_config, num_micro_batches, parallel_mode,\n     parallel_args) = benchmark_case\n    (image_size, num_layers, num_channels, width_factor, dtype) = model_config\n    if dtype == \"fp32\":\n        dtype = jnp.float32\n    elif dtype == \"fp16\":\n        dtype = jnp.float16\n    else:\n        raise ValueError(f\"Invalid dtype: 
{dtype}\")\n\n    # Prepare input batch\n    num_classes = 1024\n    batch = {\n        \"images\":\n            jnp.ones((batch_size, image_size, image_size, 3), dtype=dtype),\n        \"labels\":\n            jnp.ones((batch_size), dtype=jnp.int32),\n    }\n    print_used_time(\"Prepare input\")\n\n    # Init train state\n    model = get_wide_resnet(num_layers, width_factor, num_channels, num_classes,\n                            dtype)\n\n    rngkey = jax.random.PRNGKey(0)\n    learning_rate_fn = create_learning_rate_fn()\n    state = create_train_state(rngkey, model, batch[\"images\"], learning_rate_fn)\n    print_used_time(\"Create train state\")\n    return state, batch, learning_rate_fn\n\n\ndef benchmark_wresnet_3d_internal(benchmark_case,\n                                  niter,\n                                  num_hosts,\n                                  num_devices_per_host,\n                                  profile_driver_time=False):\n    # Connect to the cluster\n    virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n        host_ids=list(range(num_hosts)),\n        num_devices_per_host=num_devices_per_host)\n    set_global_virtual_physical_mesh(virtual_mesh)\n\n    # Parallel configs\n    allow_mixed_mesh_shape = True\n    (method, _, _, _) = get_pipeshard_parallel_method(\n        benchmark_case,\n        virtual_mesh.num_devices_per_host,\n        allow_mixed_mesh_shape=allow_mixed_mesh_shape)\n\n    use_grad_acc = benchmark_case.num_micro_batches > 1\n    grad_func = alpa.grad if use_grad_acc else jax.grad\n    state, batch, learning_rate_fn = prepare_wresnet_input_and_model(\n        benchmark_case)\n    train_step = get_train_step(learning_rate_fn,\n                                False,\n                                None,\n                                method,\n                                grad_func=grad_func)\n\n    (latencies, max_mem_allocated, compilation_times,\n     executable) = 
compile_and_benchmark_pipeshard_training_executable(\n         benchmark_case.parallel_mode,\n         niter,\n         train_step,\n         state, (batch,),\n         profile_driver_time=profile_driver_time)\n\n    # Profile submesh executables\n    # del state\n    # del metrics\n    # for i, profiled in enumerate(executable.profile_all_executables()):\n    #     pstr = f\"Mesh {i}: \"\n    #     for k in profiled:\n    #         pstr += f\"Exec {k}: {profiled[k][0]}s; \"\n    #     print(pstr)\n\n    # Compute statistics\n    num_gpus = virtual_mesh.num_devices\n    tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12\n    parameter_count = compute_param_number(state.params)\n\n    (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,\n     logical_mesh_shapes, autosharding_option_dicts) = get_last_dp_result()\n    metadata = {\n        \"compilation_times\": compilation_times,\n        \"compute_cost_file_name\": compute_cost_file_name,\n        \"forward_stage_layer_ids\": forward_stage_layer_ids,\n        \"submesh_shapes\": submesh_shapes,\n        \"logical_mesh_shapes\": logical_mesh_shapes,\n        \"autosharding_option_dicts\": autosharding_option_dicts,\n    }\n\n    return parameter_count, max_mem_allocated, latencies, tflops, metadata\n\n\ndef benchmark_wresnet_2d_internal(physical_mesh,\n                                  benchmark_case,\n                                  niter,\n                                  profile_driver_time=False):\n    # Model configs\n    method, grad_func = get_shard_parallel_method(benchmark_case, physical_mesh)\n\n    use_grad_acc = benchmark_case.num_micro_batches > 1\n    grad_func = alpa.grad if use_grad_acc else jax.grad\n    state, batch, learning_rate_fn = prepare_wresnet_input_and_model(\n        benchmark_case)\n    train_step = get_train_step(learning_rate_fn,\n                                False,\n                                None,\n                                method,\n   
                             grad_func=grad_func)\n\n    (latencies, ilp_objective, peak_mem,\n     executable) = compile_and_benchmark_shard_training_executable(\n         physical_mesh,\n         niter,\n         train_step,\n         state, (batch,),\n         profile_driver_time=profile_driver_time)\n\n    # Compute statistics\n    num_gpus = physical_mesh.num_devices\n    tflops = executable.flop_count / num_gpus / np.mean(latencies) / 1e12\n    parameter_count = compute_param_number(state.params)\n    metadata = {\n        \"ilp_objective\": ilp_objective,\n    }\n    return parameter_count, peak_mem, latencies, tflops, metadata\n"
  },
  {
    "path": "benchmark/alpa/benchmark_parallel_utils.py",
    "content": "\"\"\"Options of a benchmark case.\"\"\"\nfrom collections import namedtuple\nimport json\nimport os\nimport time\nfrom typing import Optional, Dict, Any, List\n\nimport numpy as np\nimport jax\nfrom jax._src.tree_util import tree_flatten, tree_leaves, tree_unflatten\n\nimport alpa\nfrom alpa import (AutoShardingOption, ShardParallel, PipeshardParallel,\n                  ManualStageOption, AutoStageOption, AutoLayerOption,\n                  global_config, PhysicalDeviceMesh)\nfrom alpa.timer import timers\nfrom alpa.util import (print_used_time, to_str_round,\n                       count_communication_primitives, GB)\n\nBenchmarkCase = namedtuple(\"BenchmarkCase\", [\n    \"batch_size\", \"model_config\", \"num_micro_batches\", \"parallel_mode\",\n    \"parallel_args\"\n])\n\nShardParallelArgs = namedtuple(\"ShardParallelArgs\", [\n    \"prefer_reduce_scatter\", \"use_remat\", \"logical_mesh_shape\",\n    \"force_batch_dim_mapping\"\n])\n\nUniformParallelArgs = namedtuple(\"UniformParallelArgs\", [\n    \"prefer_reduce_scatter\", \"use_remat\", \"dp\", \"op\", \"pp\",\n    \"force_batch_dim_mapping\"\n])\n\nSearchParallelArgs = namedtuple(\"SearchParallelArgs\", [\n    \"prefer_reduce_scatter\", \"use_remat\", \"num_auto_layers\", \"auto_stage_option\"\n])\n\nLoadSolutionParallelArgs = namedtuple(\"LoadSolutionParallelArgs\", [\n    \"prefer_reduce_scatter\", \"use_remat\", \"num_auto_layers\",\n    \"forward_stage_layer_ids\", \"submesh_physical_shapes\",\n    \"submesh_logical_shapes\", \"submesh_autosharding_option_dicts\"\n])\n\n\ndef get_pipeshard_parallel_method(benchmark_case: BenchmarkCase,\n                                  num_devices_per_host: Optional[int] = None,\n                                  allow_mixed_mesh_shape: bool = False,\n                                  use_fine_grained_remat: bool = False,\n                                  pipeline_schedule: str = \"1f1b\"):\n    \"\"\"Create the parallel method of a benchmark 
case.\n\n    Args:\n        benchmark_case: The benchmark case.\n        num_devices_per_host: The number of devices per host, used in uniform\n          parallel mode.\n        allow_mixed_mesh_shape: Whether to allow the mixed mesh shape in\n          the autosharding pass.\n    \"\"\"\n\n    num_micro_batches = benchmark_case.num_micro_batches\n    parallel_mode = benchmark_case.parallel_mode\n    parallel_args = benchmark_case.parallel_args\n\n    if parallel_mode == \"search\":\n        assert isinstance(parallel_args, SearchParallelArgs)\n        (prefer_reduce_scatter, use_remat, num_auto_layers,\n         auto_stage_option) = parallel_args\n        add_manual_layer_marker = None\n        num_manual_pipeline_stages = None\n        add_manual_remat = None\n        remat_mode = \"coarse_grained_remat\" if use_remat else \"none\"\n        auto_stage_option[\"cached_profile_result\"] = None\n        method = PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            default_auto_sharding_option=AutoShardingOption(\n                prefer_reduce_scatter=prefer_reduce_scatter,\n                allow_mixed_mesh_shape=allow_mixed_mesh_shape,\n            ),\n            pipeline_schedule=pipeline_schedule,\n            layer_option=AutoLayerOption(layer_num=num_auto_layers,\n                                         remat_mode=remat_mode),\n            stage_option=AutoStageOption(**auto_stage_option))\n    elif parallel_mode == \"load_solution\":\n        assert isinstance(parallel_args, LoadSolutionParallelArgs)\n        (prefer_reduce_scatter, use_remat, num_auto_layers,\n         forward_stage_layer_ids, submesh_physical_shapes,\n         submesh_logical_shapes,\n         submesh_autosharding_option_dicts) = parallel_args\n        add_manual_layer_marker = None\n        num_manual_pipeline_stages = None\n        add_manual_remat = None\n        if use_remat:\n            remat_mode = (\"fine_grained_remat\"\n                          if 
use_fine_grained_remat else \"coarse_grained_remat\")\n        else:\n            remat_mode = \"none\"\n        model_num_layers = benchmark_case.model_config.num_layers\n        method = PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            default_auto_sharding_option=AutoShardingOption(\n                prefer_reduce_scatter=prefer_reduce_scatter,\n                allow_mixed_mesh_shape=allow_mixed_mesh_shape,\n            ),\n            pipeline_schedule=pipeline_schedule,\n            layer_option=AutoLayerOption(\n                layer_num=num_auto_layers,\n                remat_mode=remat_mode,\n                fine_grained_remat_layer_num=model_num_layers),\n            stage_option=ManualStageOption(forward_stage_layer_ids,\n                                           submesh_physical_shapes,\n                                           submesh_logical_shapes,\n                                           submesh_autosharding_option_dicts))\n    elif parallel_mode == \"uniform\":\n        assert isinstance(parallel_args, UniformParallelArgs)\n        (prefer_reduce_scatter, use_remat, dp, op, pp,\n         force_batch_dim_mapping) = parallel_args\n        as_option = AutoShardingOption(\n            prefer_reduce_scatter=prefer_reduce_scatter,\n            allow_mixed_mesh_shape=allow_mixed_mesh_shape,\n        )\n        if force_batch_dim_mapping:\n            as_option.force_batch_dim_to_mesh_dim = 0\n        add_manual_layer_marker = True\n        add_manual_remat = use_remat\n\n        logical_mesh_shape = (dp, op)\n        num_manual_pipeline_stages = pp\n        num_mesh_devices = np.prod(logical_mesh_shape)\n        assert num_devices_per_host is not None\n        if num_mesh_devices <= num_devices_per_host:\n            physical_mesh_shape = (1, num_mesh_devices)\n        else:\n            assert num_mesh_devices % num_devices_per_host == 0\n            physical_mesh_shape = (num_mesh_devices // num_devices_per_host,\n   
                                num_devices_per_host)\n\n        method = PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            default_auto_sharding_option=as_option,\n            pipeline_schedule=pipeline_schedule,\n            layer_option=\"manual\",\n            stage_option=ManualStageOption(\n                forward_stage_layer_ids=[[i] for i in range(pp)],\n                submesh_physical_shapes=[physical_mesh_shape] * pp,\n                submesh_logical_shapes=[logical_mesh_shape] * pp,\n                submesh_autosharding_option_dicts=[{}] * pp))\n    else:\n        raise ValueError(f\"Invalid parallel mode: {parallel_mode}\")\n\n    return (method, add_manual_remat, add_manual_layer_marker,\n            num_manual_pipeline_stages)\n\n\ndef get_shard_parallel_method(benchmark_case: BenchmarkCase,\n                              physical_mesh: PhysicalDeviceMesh,\n                              logical_mesh_options: Dict[str, Any] = None):\n    \"\"\"Create the parallel method of a benchmark case.\n\n    Args:\n        benchmark_case: The benchmark case.\n        num_devices_per_host: The number of devices per host, used in uniform\n          parallel mode.\n        allow_mixed_mesh_shape: Whether to allow the mixed mesh shape in\n          the autosharding pass.\n    \"\"\"\n    print_used_time(None)\n\n    num_micro_batches = benchmark_case.num_micro_batches\n    parallel_mode = benchmark_case.parallel_mode\n    parallel_args = benchmark_case.parallel_args\n\n    if isinstance(parallel_args, ShardParallelArgs):\n        (prefer_reduce_scatter, use_remat, logical_mesh_shape,\n         force_batch_dim_mapping) = parallel_args\n    elif isinstance(parallel_args, UniformParallelArgs):\n        (prefer_reduce_scatter, use_remat, dp, op, pp,\n         force_batch_dim_mapping) = parallel_args\n        assert pp == 1, \"Do not support pipeline parallelism for shard parallel\"\n        logical_mesh_shape = (dp, op)\n    else:\n       
 raise ValueError(f\"Unsupported parallel mode: {parallel_mode}\")\n\n    # Parallel configs\n    if num_micro_batches > 1:\n        grad_func = alpa.grad\n    else:\n        num_micro_batches = None\n        grad_func = jax.grad\n\n    as_option = AutoShardingOption()\n    if force_batch_dim_mapping:  # Always map batch dim to mesh dim 0\n        as_option.force_batch_dim_to_mesh_dim = 0\n    as_option.prefer_reduce_scatter = prefer_reduce_scatter\n    if parallel_mode == \"zero-3\":\n        as_option.force_zero_stage_3 = True\n    elif parallel_mode in [\"shard-largest\"]:\n        as_option.force_simple_heuristic = \"largest\"\n\n    if logical_mesh_options is None:\n        logical_mesh_options = {}\n    logical_mesh = physical_mesh.get_logical_mesh(logical_mesh_shape,\n                                                  **logical_mesh_options)\n    method = ShardParallel(devices=logical_mesh,\n                           num_micro_batches=num_micro_batches,\n                           auto_sharding_option=as_option)\n    print_used_time(\"Setup device mesh\")\n\n    return method, grad_func\n\n\ndef benchmark_training_executable(niter,\n                                  train_step,\n                                  executable,\n                                  state,\n                                  other_train_step_inputs,\n                                  profile_driver_time=False):\n    print_used_time(None)\n\n    # Benchmark step time\n    warmup = 2 if niter >= 5 else 1\n\n    if profile_driver_time:\n        # Benchmark latency with driver overhead\n        global_config.use_dummy_value_for_benchmarking = False\n        global_config.shard_parallel_sync_for_timer = False\n        print(\"Warmup\")\n        for i in range(warmup):\n            state = train_step(state, *other_train_step_inputs)\n        executable.sync()\n        niter -= warmup\n        print(\"Benchmark\")\n        tic = time.time()\n        for i in range(niter):\n            state 
= train_step(state, *other_train_step_inputs)\n        executable.sync()\n        e2e_latency = (time.time() - tic) / niter\n        latencies = [e2e_latency]\n        print(f\"latency with driver overhead: {e2e_latency:.3f}\")\n    else:\n        # Benchmark latency without driver overhead\n        for i in range(niter):\n            print(f\"Iteration {i} ...\")\n            state = train_step(state, *other_train_step_inputs)\n            if isinstance(state, tuple):\n                # In case the train_step returns extra info (e.g. loss),\n                # Get the actual state out.\n                state = state[0]\n            executable.sync()\n\n        latencies = executable.get_execution_time_costs()[warmup:]\n\n    print_used_time(\"Benchmark\")\n\n    return latencies\n\n\ndef benchmark_inference_executable(niter,\n                                   infer_step,\n                                   executable,\n                                   params,\n                                   other_infer_step_inputs,\n                                   profile_driver_time=False):\n    print_used_time(None)\n\n    # Benchmark step time\n    warmup = 2 if niter >= 5 else 1\n\n    if profile_driver_time:\n        # Benchmark latency with streaming\n        for i in range(warmup):\n            _ = infer_step(params, *other_infer_step_inputs)\n        executable.sync()\n        niter -= warmup\n\n        # Benchmark latency\n        losses = []\n        start_time = time.time()\n        latencies = []\n        for i in range(niter):\n            print(f\"Iteration {i} ...\")\n            loss = infer_step(params, *other_infer_step_inputs)\n            loss.prefetch()\n            losses.append(loss)\n        for i, loss in enumerate(losses):\n            _ = loss._value\n            end_time = time.time()\n            latencies.append(end_time - start_time)\n            start_time = end_time\n    else:\n        for i in range(niter):\n            print(f\"Iteration 
{i} ...\")\n            _ = infer_step(params, *other_infer_step_inputs)\n            executable.sync()\n\n        latencies = executable.get_execution_time_costs()[warmup:]\n\n    print_used_time(\"Benchmark\")\n\n    return latencies\n\n\ndef compile_pipeshard_executable(parallel_mode, train_step, state,\n                                 other_train_step_inputs):\n    print_used_time(None)\n\n    executable = train_step.get_executable(state, *other_train_step_inputs)\n    print_used_time(\"Compile (driver)\")\n\n    if parallel_mode == \"search\":\n        compilation_times = {\n            k: timers(k).elapsed(mode=\"sum\") for k in [\n                \"stage-construction\", \"stage-construction-dp\",\n                \"stage-construction-compilation\", \"stage-construction-profiling\"\n            ]\n        }\n        print(\n            f\"compilation time breakdown: {to_str_round(compilation_times, 2)}\")\n    else:\n        compilation_times = None\n\n    executable.dump_debug_info(\"tmp\")\n    executable.sync()\n    print_used_time(\"Compile (worker)\")\n    return executable, compilation_times\n\n\ndef compile_shard_executable(physical_mesh, train_step, state,\n                             other_train_step_inputs):\n    print_used_time(None)\n    executable = train_step.get_executable(state, *other_train_step_inputs)\n    print_used_time(\"Compile (driver)\")\n\n    physical_mesh.sync_workers()\n    print_used_time(\"Compile (workers)\")\n\n    # Check sharding strategy\n    alloc_mem = executable.get_total_allocation_size()\n    ilp_objective = executable.auto_sharding_objective or 0.0\n    executable.dump_debug_info(\"tmp\")\n    hlo_text = executable.get_hlo_text()\n    (n_total, n_all_reduce, n_all_gather, n_reduce_scatter,\n     n_all_to_all) = count_communication_primitives(hlo_text)\n\n    print(f\"#total: {n_total}, #all-reduce: {n_all_reduce}, \"\n          f\"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, \"\n          
f\"#all-to-all: {n_all_to_all}\")\n    print(f\"alloc_mem: {alloc_mem / GB:.2f} GB\")\n    return executable, ilp_objective, alloc_mem\n\n\ndef compile_and_benchmark_pipeshard_training_executable(\n        parallel_mode,\n        niter,\n        train_step,\n        state,\n        other_train_step_inputs,\n        profile_driver_time=False):\n    executable, compilation_times = compile_pipeshard_executable(\n        parallel_mode, train_step, state, other_train_step_inputs)\n    latencies = benchmark_training_executable(\n        niter,\n        train_step,\n        executable,\n        state,\n        other_train_step_inputs,\n        profile_driver_time=profile_driver_time)\n    max_mem_allocated = executable.mesh_group.get_max_memory_allocated()\n\n    return latencies, max_mem_allocated, compilation_times, executable\n\n\ndef compile_and_benchmark_shard_training_executable(physical_mesh,\n                                                    niter,\n                                                    train_step,\n                                                    state,\n                                                    other_train_step_inputs,\n                                                    profile_driver_time=False):\n    executable, ilp_objective, alloc_mem = compile_shard_executable(\n        physical_mesh, train_step, state, other_train_step_inputs)\n    latencies = benchmark_training_executable(\n        niter,\n        train_step,\n        executable,\n        state,\n        other_train_step_inputs,\n        profile_driver_time=profile_driver_time)\n    peak_mem = max(physical_mesh.get_max_memory_allocated(), alloc_mem)\n    return latencies, ilp_objective, peak_mem, executable\n\n\ndef compile_and_benchmark_pipeshard_inference_executable(\n        parallel_mode,\n        niter,\n        infer_step,\n        params,\n        other_inference_step_inputs,\n        profile_driver_time=False):\n    executable, compilation_times = 
compile_pipeshard_executable(\n        parallel_mode, infer_step, params, other_inference_step_inputs)\n\n    # Preshard params\n    executable.mesh_group.reset_memory_stats()\n    params_ps = executable.get_input_placement_specs()[0]\n    flat_params, in_tree = tree_flatten(params)\n    flat_ps = tree_leaves(params_ps)\n    params = tree_unflatten(\n        in_tree,\n        executable.mesh_group.shard_args_to_arrays(flat_ps, flat_params))\n    print_used_time(\"Preshard (driver)\")\n    per_stage_weight_mem = executable.mesh_group.get_max_memory_allocated_per_mesh(\n    )\n\n    latencies = benchmark_inference_executable(\n        niter,\n        infer_step,\n        executable,\n        params,\n        other_inference_step_inputs,\n        profile_driver_time=profile_driver_time)\n    max_mem_allocated = executable.mesh_group.get_max_memory_allocated()\n    per_stage_peak_mem = executable.mesh_group.get_max_memory_allocated_per_mesh(\n    )\n\n    return latencies, max_mem_allocated, compilation_times, executable, per_stage_weight_mem, per_stage_peak_mem\n\n\ndef compute_avg_stage_latencies(timelines: List[tuple]):\n    stage_latencies = []\n    for request_timeline in timelines:\n        sorted_timeline = sorted(request_timeline, key=lambda x: x[0])\n        stage_borders = [sorted_timeline[0][0]]\n        for _, e, _, _ in sorted_timeline:\n            stage_borders.append(e)\n        stage_latency = [\n            stage_borders[i + 1] - stage_borders[i]\n            for i in range(len(stage_borders) - 1)\n        ]\n        stage_latencies.append(stage_latency)\n    return np.mean(stage_latencies, axis=0)\n"
  },
  {
    "path": "benchmark/alpa/gather_gpu_stat.py",
    "content": "\"\"\"Gather gpu utilization from all nodes.\"\"\"\n\nimport os\nimport tempfile\n\nimport gpustat\nimport ray\n\n\ndef call_nvidia_smi():\n    gpus = gpustat.new_query().gpus\n    return [g.utilization for g in gpus]\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\")\n\n    host_info = []\n    for node in ray.nodes():\n        for key in node[\"Resources\"]:\n            if key.startswith(\"node:\"):\n                host_info.append(node)\n\n    results = []\n    for i in range(len(host_info)):\n        # Launch a ray actor\n        node_resource = \"node:\" + host_info[i][\"NodeManagerAddress\"]\n        func = ray.remote(resources={node_resource: 1e-3})(call_nvidia_smi)\n        results.append(func.remote())\n    results = ray.get(results)\n\n    for i in range(len(host_info)):\n        print(host_info[i][\"NodeManagerAddress\"])\n        print(results[i])\n"
  },
  {
    "path": "benchmark/alpa/gen_prof_database.py",
    "content": "\"\"\"Generate the profiling result database.\n\nUsage:\nAWS p3.16:\npython3 gen_prof_database.py --max-comm-size-intra-node 32 --max-comm-size-inter-node 29\n\nAWS p4.24:\npython3 gen_prof_database.py --efa --max-comm-size-intra-node 33 --max-comm-size-inter-node 30 --max-fail-retry 8\n\"\"\"\n\nimport ray\nimport argparse\n\nimport jax\nimport alpa\nfrom alpa import DeviceCluster, ProfilingResultDatabase, global_config\nfrom alpa.util import run_cmd\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--cluster-key\", type=str, default=\"default\")\n    parser.add_argument(\"--efa\", action=\"store_true\")\n    parser.add_argument(\"--filename\",\n                        type=str,\n                        default=\"prof_database.pkl\",\n                        help=\"The filename of the output database\")\n    parser.add_argument(\"--max-comm-size-intra-node\",\n                        type=int,\n                        required=True,\n                        help=\"Run profiling for communication up to 2^x bytes \"\n                        \"within a node, where x is this argument\")\n    parser.add_argument(\"--max-comm-size-inter-node\",\n                        type=int,\n                        required=True,\n                        help=\"Run profiling for communication up to 2^x bytes \"\n                        \"cross nodes, where x is this argument\")\n    parser.add_argument(\n        \"--cache-filename\",\n        type=str,\n        default=\"/home/ubuntu/efs/alpa/benchmark/alpa/tmp/hlo_op_cost_dict.pkl\",\n        help=\"The filename of the temporary cache. 
This should be an \"\n        \"absolute path on a network file system that can be accessed by \"\n        \"ray workers on all nodes.\")\n    parser.add_argument(\"--max-fail-retry\", type=int, default=5)\n    args = parser.parse_args()\n\n    run_cmd(\"mkdir -p tmp\")\n    if args.efa:\n        global_config.use_aws_efa = True\n\n    # Initialize a useless jax GPU backend in the driver script.\n    # This GPU backend takes 300MB GPU memory to store the CUDA context.\n    # This simulates the environment of our benchmark scripts and\n    # can make the profiling of available memory more accurate.\n    # TODO(lmzheng): Modify jax so it does not allocate this useless CUDA context.\n    jax.config.update('jax_platform_name', 'cpu')\n    _ = jax.numpy.ones(1)\n\n    # Connect to a ray cluster\n    alpa.init(cluster=\"ray\")\n    cluster = alpa.get_global_cluster()\n\n    prof_database = cluster.profile_all(args.cluster_key,\n                                        args.max_comm_size_intra_node,\n                                        args.max_comm_size_inter_node,\n                                        max_fail_retry=args.max_fail_retry,\n                                        cache_filename=args.cache_filename,\n                                        dot_range=range(0, 8192, 128))\n    prof_database.save(args.filename)\n    print(f\"Save profiling database to {args.filename}\")\n"
  },
  {
    "path": "benchmark/alpa/gen_serving_database.py",
    "content": "\"\"\"\nUsage:\npython3 run_exp.py gpt_inference\npython3 gen_serving_database.py\n\"\"\"\n\nimport argparse\n\nfrom alpa_serve.profiling import ProfilingDatabase\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--input\", type=str, default=\"inference_prof_res.tsv\")\n    parser.add_argument(\"--output\", type=str, default=\"profiling_result.pkl\")\n    parser.add_argument(\"--new\", action=\"store_true\")\n    args = parser.parse_args()\n\n    database = ProfilingDatabase(args.output, args.new)\n    database.update_from_csv(args.input)\n    database.materialize()\n"
  },
  {
    "path": "benchmark/alpa/inspect_prof_database.py",
    "content": "\"\"\"Inspect and edit a profiling database.\"\"\"\nimport argparse\n\nfrom alpa import DeviceCluster, ProfilingResultDatabase\nfrom alpa.util import run_cmd\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--filename\", type=str, default=\"prof_database.pkl\")\n    args = parser.parse_args()\n\n    prof_database = ProfilingResultDatabase()\n    prof_database.load(args.filename)\n\n    # Do some editing\n    #prof_database.insert_dummy_mesh_result(\"default\", (8, 8))\n    #prof_database.save(args.filename)\n\n    # Print results\n    print(\"Meshes:\")\n    print(list(prof_database.data.keys()))\n    print()\n\n    mesh_result = prof_database.query(\"default\", (2, 8))\n    print(mesh_result)\n"
  },
  {
    "path": "benchmark/alpa/resharding/README.md",
    "content": "# Benchmark\nThis folder contains benchmarking code for cross mesh resharding, corresponding to the experiment section in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). \n\nTo make the benchmark feasible in a short amount of time, this documentation provides: Instructions for benchmarking on an AWS p3.8xlarge cluster. You can use these to quickly run cross mesh resharding using Alpa and get the statistics of the performance. The statistics may be different from that in our papaer if your cluster is not an AWS p3.8xlarge cluster. \nThere are two types of experiments for benchmarking:\n- Single device to multiple devices microbenchmark: corronspond to section 5.1.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). \n- Multiple devices to multiple devices microbenchmark: corronspond to section 5.1.2 and 5.3.1 in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322). \n\n## Benchmark Steps\n\n### Cluster Preparation\n\nPrepare 5 AWS p3.8xlarge instances and put them in the same Placement Group. \n\n### Start a Ray Cluster\nAlpa uses a distributed framework Ray to manage the cluster and distributed workers.\nHere, we provide instructions for manually launching a ray cluster.\nYou can also refer to the Ray [documentation](https://docs.ray.io/en/latest/cluster/quickstart.html#) for more methods on launching and managing ray clusters. \n\n1. Pick one node as the head node and run the command below on it\n    ```\n    ray start --head\n    ```\n2. For all other 4 nodes, connect them to the head node following the instructions printed by the previous command. \n    ```\n    # The command should look like this, but with the ip address and password printed by the previous command. 
\n    ray start --address='172.31.31.37:6379' --redis-password='5241590000000000'\n    ```\n\nYou can check the cluster status by \n```\nray status\n```\nYou should be able to see the number of CPUs and GPUs available on your cluster. We should have 20 GPUs to proceed. \nAll nodes should have alpa installed.\n\n### Single device to multiple devices microbenchmark\nRun all benchmark tests with all GPUs in your cluster. \n```\npython3 benchmark.py --suite 1-to-m\n```\nThe result will be saved in `tmp/1_to_m_result.json`. In this set of experiments, the sender mesh has only 1 GPU. We vary the number of GPUs in the receiver mesh. In the first half of benchmark tests, the receiver mesh has 1 node and the number of GPUs in this node varies from 1 to 4. In the second half of benchmark tests, the number of GPUs per node is fixed at 2, but the number of nodes in the receiver mesh grows from 1 to 4. For more details, please refer to `perf_1_to_m_suite` in `suite.py`.\n\nIf you only want to run one test case,\n```\npython3 benchmark_cross_mesh_resharding.py --suite 1-to-m --n-nodes 1 --gpu-per-node 4 --resharding-mode send_recv --resharding-loadbalance-mode normal\n```\nHere, I take dst mesh to be (1, 4) as an example and you could also choose other cases.\nYou could use `--resharding-mode`, `--resharding-loadbalance-mode`, `--use-local-allgather` flags \nto specify the configurations for cross mesh resharding. \n\n### Multiple devices to multiple devices microbenchmark\nSimilar to the previous subsection. \n```\npython3 benchmark.py --suite n-to-m\n```\nThe result will be saved in `tmp/n_to_m_result.json`. In this set of experiments, we move to more complicated cases where both the sender mesh and receiver mesh have multiple nodes. 
For more details, please refer to `perf_n_to_m_suite` in `suite.py`.\n\nIf you only want to run one test case,\n```\npython3 benchmark_cross_mesh_resharding.py --suite n-to-m --case case1 --resharding-mode send_recv --resharding-loadbalance-mode normal\n```\nHere, I take case1 as an example and you could choose other cases by referring to `suite.py`. Same as above, you could \nspecify the configurations for cross mesh resharding.\n\n## Result\n\nBy using the above benchmark scripts, you could compare the time spent among different resharding configurations.\nYou can then check the conclusions in [On Optimizing the Communication of Model Parallelism](https://arxiv.org/abs/2211.05322) against \nthese statistics.\n"
  },
  {
    "path": "benchmark/alpa/resharding/benchmark.py",
    "content": "\"\"\"The entry point of intra-op + inter-op parallelism benchmark.\"\"\"\nimport argparse\nimport json\nimport multiprocessing as mp\nimport os\nimport time\n\nfrom benchmark_cross_mesh_resharding import benchmark_one_case_internal\nimport suite\n\n\ndef benchmark_and_write_to_namespace(result_namespace, *args, **kwargs):\n    result = benchmark_one_case_internal(*args, **kwargs)\n    result_namespace.result = result\n\n\ndef benchmark_one_case(*args, use_separate_process=False, **kwargs):\n    if not use_separate_process:\n        return benchmark_one_case_internal(*args, **kwargs)\n    ctx = mp.get_context(\"spawn\")\n    manager = ctx.Manager()\n    result_namespace = manager.Namespace()\n    p = ctx.Process(target=benchmark_and_write_to_namespace,\n                    args=(result_namespace, *args),\n                    kwargs=kwargs)\n    p.start()\n    p.join()\n    if p.exitcode != 0:\n        return -1, -1, [-1], -1, None\n    return result_namespace.result\n\n\ndef benchmark_n_to_m_suite():\n    os.makedirs(\"tmp\", exist_ok=True)\n\n    result_file = \"tmp/n_to_m_result.json\"\n    result = []\n\n    benchmark_cases = suite.perf_n_to_m_suite\n    resharding_config_list = suite.resharding_n_to_m_configs\n\n    # Run all cases\n    for case_name, benchmark_case in benchmark_cases.items():\n        # Run one case\n        for config in resharding_config_list:\n            print(\"Working on {}: {}, config: {}\".format(\n                case_name, str(benchmark_case), str(config)))\n            one_result = benchmark_one_case(\n                benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape,\n                benchmark_case.src_sharding_spec,\n                benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape,\n                config[\"resharding_mode\"], config[\"use_local_allgather\"],\n                config[\"resharding_loadbalance_mode\"])\n\n            print(one_result)\n            result.append(one_result)\n    
        json.dump(result, open(result_file, \"w\"), indent=4)\n\n            time.sleep(0.1)  # for ctrl+c to work\n\n\ndef benchmark_1_to_m_suite():\n    os.makedirs(\"tmp\", exist_ok=True)\n\n    result_file = \"tmp/1_to_m_result.json\"\n    result = []\n\n    benchmark_cases = suite.perf_1_to_m_suite\n    resharding_config_list = suite.resharding_1_to_m_configs\n\n    # Run all cases\n    for case_name, benchmark_case in benchmark_cases.items():\n        # Run one case\n        for config in resharding_config_list:\n            print(\"Working on {}: {}, config: {}\".format(\n                case_name, str(benchmark_case), str(config)))\n            one_result = benchmark_one_case(\n                benchmark_case.src_mesh_shape, benchmark_case.dst_mesh_shape,\n                benchmark_case.src_sharding_spec,\n                benchmark_case.dst_sharding_spec, benchmark_case.tensor_shape,\n                config[\"resharding_mode\"], config[\"use_local_allgather\"],\n                config[\"resharding_loadbalance_mode\"])\n            print(one_result)\n            result.append(one_result)\n            json.dump(result, open(result_file, \"w\"), indent=4)\n\n            time.sleep(0.1)  # for ctrl+c to work\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--suite\",\n                        choices=[\"1-to-m\", \"n-to-m\"],\n                        type=str,\n                        required=True)\n    args = parser.parse_args()\n\n    if args.suite == \"1-to-m\":\n        benchmark_1_to_m_suite()\n    else:\n        benchmark_n_to_m_suite()\n"
  },
  {
    "path": "benchmark/alpa/resharding/benchmark_cross_mesh_resharding.py",
    "content": "\"\"\"Test cross-mesh resharding.\"\"\"\nimport argparse\n\nfrom jax import xla\nfrom jax.core import Var\nfrom jax._src.abstract_arrays import ShapedArray\nfrom jax.interpreters.pxla import spec_to_indices\nimport jax.numpy as jnp\nimport numpy as np\nimport ray\n\nfrom alpa import init\nfrom alpa.device_mesh import (create_remote_array_refs,\n                              get_global_virtual_physical_mesh)\nfrom alpa.mesh_executable import next_mesh_executable_uuid\nfrom alpa.global_env import global_config\nfrom alpa.pipeline_parallel.runtime_emitter import PipelineInstEmitter\nfrom alpa.pipeline_parallel.cross_mesh_resharding import (\n    CollectiveGroup, ReshardingTaskSpec, CrossMeshCommunicator,\n    SymbolicReshardingTask, SymbolicBroadcastReshardingTask)\nfrom alpa.pipeline_parallel.pipeshard_executable import (\n    AllocateZeroWorkerExecutableConfig, PipelineInstruction,\n    PipeshardMeshWorkerExecutable)\nfrom alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray\nfrom alpa.util import get_shard_shape\nfrom alpa.timer import timers\n\nimport suite\n\n\ndef get_device_meshes(src_mesh_shape, dst_mesh_shape):\n    virtual_mesh = get_global_virtual_physical_mesh()\n    src_num_host = src_mesh_shape[0]\n    dst_num_host = dst_mesh_shape[0]\n    assert virtual_mesh.num_hosts >= src_num_host+dst_num_host,\\\n        \"Error: There are not enough nodes for this test case\"\n    src_mesh = virtual_mesh.slice_2d(range(src_num_host),\n                                     [range(src_mesh_shape[1])] *\n                                     src_num_host).get_physical_mesh()\n    dst_host_indices = range(src_num_host, src_num_host + dst_num_host)\n    dst_device_indices = [range(dst_mesh_shape[1])] * dst_num_host\n    dst_mesh = virtual_mesh.slice_2d(dst_host_indices,\n                                     dst_device_indices).get_physical_mesh()\n    return src_mesh, dst_mesh\n\n\ndef get_mean_and_variance(results):\n    assert 
len(results) == 13\n    results = results[3:]\n    mean = np.mean(results)\n    var = np.var(results)\n    return mean, var\n\n\ndef benchmark_one_case_internal(\n    src_mesh_shape,\n    dst_mesh_shape,\n    src_sharding_spec,\n    dst_sharding_spec,\n    tensor_shape,\n    resharding_mode=\"send_recv\",\n    use_local_allgather=True,\n    resharding_loadbalance_mode=\"normal\",\n):\n\n    global_config.resharding_mode = resharding_mode\n    global_config.resharding_loadbalance_mode = resharding_loadbalance_mode\n    global_config.use_local_allgather = use_local_allgather\n\n    init(cluster=\"ray\")\n\n    src_mesh, dst_mesh = get_device_meshes(src_mesh_shape, dst_mesh_shape)\n\n    var = Var(0, \"\", ShapedArray(tensor_shape, jnp.int32))\n\n    # Resharding task spec and send/recv strategy\n    src_loads = {src: 0 for src in src_mesh.device_strs}\n    dst_loads = {dst: 0 for dst in dst_mesh.device_strs}\n    if resharding_mode == \"send_recv\":\n        rewrite_dst_sharding_spec = CrossMeshCommunicator._rewrite_allgather_spec(\n            dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)\n    else:\n        rewrite_dst_sharding_spec = dst_sharding_spec\n    src_array = VirtualDistributedArray(device_mesh=src_mesh,\n                                        aval=var.aval,\n                                        sharding_spec=src_sharding_spec)\n    dst_array = VirtualDistributedArray(device_mesh=dst_mesh,\n                                        aval=var.aval,\n                                        sharding_spec=rewrite_dst_sharding_spec)\n    task_spec = ReshardingTaskSpec(src_array, dst_array, dst_sharding_spec)\n\n    if resharding_mode == \"send_recv\":\n        if global_config.resharding_loadbalance_mode == \"normal\":\n            strategy = (CrossMeshCommunicator.\n                        _generate_send_recv_resharding_strategy_by_loads(\n                            task_spec, src_loads, dst_loads))\n        elif 
global_config.resharding_loadbalance_mode == \"no_loadbalance\":\n            strategy = (\n                CrossMeshCommunicator.\n                _generate_send_recv_resharding_strategy_by_no_load(task_spec))\n        elif global_config.resharding_loadbalance_mode in [\n                \"loadbalance_size\", \"loadbalance_order\"\n        ]:\n            strategy = (CrossMeshCommunicator.\n                        _generate_send_recv_resharding_strategy_by_loadbalance(\n                            task_spec, src_mesh, dst_mesh))\n    else:\n        if global_config.resharding_loadbalance_mode == \"normal\":\n            strategy = (CrossMeshCommunicator.\n                        _generate_broadcast_resharding_strategy_by_loads(\n                            task_spec, src_loads, dst_loads))\n        elif global_config.resharding_loadbalance_mode == \"no_loadbalance\":\n            strategy = (\n                CrossMeshCommunicator.\n                _generate_broadcast_resharding_strategy_by_no_load(task_spec))\n        elif global_config.resharding_loadbalance_mode in [\n                \"loadbalance_size\", \"loadbalance_order\"\n        ]:\n            strategy = (CrossMeshCommunicator.\n                        _generate_broadcast_resharding_strategy_by_loadbalance(\n                            task_spec, src_mesh, dst_mesh))\n\n    task_spec.set_resharding_strategy(strategy)\n\n    # Resharding task. 
Compile send/recv from strategy and allgather.\n    collective_group = CollectiveGroup(task_spec.get_participant_device_strs(),\n                                       src_mesh, dst_mesh)\n    if global_config.eagerly_create_communicators:\n        collective_group.instantiate_now()\n    else:\n        collective_group.instantiate()\n    if resharding_mode == \"send_recv\":\n        task = SymbolicReshardingTask(task_spec, collective_group, src_mesh,\n                                      dst_mesh)\n    else:\n        task = SymbolicBroadcastReshardingTask(task_spec, collective_group,\n                                               src_mesh, dst_mesh)\n\n    if global_config.eagerly_create_communicators:\n        task.create_resharding_communicators()\n\n    # Compile pipeline instructions\n    instruction_lists = {worker: [] for worker in src_mesh.workers}\n    for worker in dst_mesh.workers:\n        instruction_lists[worker] = []\n    executable_config_lists = {worker: [] for worker in dst_mesh.workers}\n    src_uuid = 21474\n    dst_uuid = 21475\n    # allocate the buffer\n    exec_uuid = next_mesh_executable_uuid()\n    config = AllocateZeroWorkerExecutableConfig(\n        exec_uuid, [get_shard_shape(var.aval, rewrite_dst_sharding_spec)],\n        [var.aval.dtype])\n    output_uuids = [dst_uuid]\n    for worker in dst_mesh.workers:\n        executable_config_lists[worker].append(config)\n        in_uuids = []\n        out_uuids = output_uuids\n        instruction_lists[worker].append(\n            PipelineInstruction.run(config.exec_uuid,\n                                    in_uuids,\n                                    out_uuids, {\n                                        \"sync_before\": False,\n                                        \"sync_after\": False\n                                    },\n                                    info=\"allocate zero for recv\"))\n    # Create resharding task\n\n    if resharding_mode == \"send_recv\":\n        
PipelineInstEmitter._compile_resharding_task(src_uuid, task, dst_uuid,\n                                                     instruction_lists)\n    else:\n        PipelineInstEmitter._compile_broadcast_resharding_task(\n            src_mesh, src_uuid, task, dst_uuid, instruction_lists)\n\n    exec_uuids = {}\n\n    # Compile Pipeline Executable\n    for worker in src_mesh.workers:\n        exec_uuid = next_mesh_executable_uuid()\n        # print(worker, exec_uuid)\n        worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable,\n                                     instruction_lists[worker], [src_uuid], [],\n                                     [], [], [],\n                                     [False] * src_mesh.num_devices_per_host)\n        exec_uuids[worker] = exec_uuid\n    for worker in dst_mesh.workers:\n        exec_uuid = next_mesh_executable_uuid()\n        # print(worker, exec_uuid)\n        worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable,\n                                     instruction_lists[worker], [], [dst_uuid],\n                                     executable_config_lists[worker], [], [],\n                                     [False] * dst_mesh.num_devices_per_host)\n        exec_uuids[worker] = exec_uuid\n\n    # Prepare array and shard args\n    test_array = np.arange(np.prod(var.aval.shape),\n                           dtype=var.aval.dtype).reshape(var.aval.shape)\n    indices = spec_to_indices(var.aval.shape, src_sharding_spec)\n    test_array = xla.canonicalize_dtype(test_array)\n    input_refs = src_mesh.shard_args_to_bufs([indices], (False,), (False,),\n                                             None, [test_array])\n    input_refs = np.array(input_refs)\n    input_uuids = [ref.uuid for ref in input_refs]\n    output_refs, output_uuids = create_remote_array_refs(dst_mesh)\n\n    # Run executables\n    time_spend = []\n    for _ in range(13):\n        timers(\"overall_resharding_time\").start()\n        
for worker in src_mesh.workers:\n            worker.run_executable.remote(exec_uuids[worker],\n                                         input_uuids, [],\n                                         sync_for_timer=True,\n                                         collect_trace=False)\n        for worker in dst_mesh.workers:\n            worker.run_executable.remote(exec_uuids[worker], [],\n                                         output_uuids,\n                                         sync_for_timer=True,\n                                         collect_trace=False)\n\n        dst_mesh.sync_workers(sync_all_devices=True)\n        timers(\"overall_resharding_time\").stop()\n        time_spend.append(timers(\"overall_resharding_time\").elapsed(mode=\"sum\"))\n        timers(\"overall_resharding_time\").reset()\n\n    mean_time, var_time = get_mean_and_variance(time_spend)\n    result = {\n        \"src_mesh_shape\": src_mesh_shape,\n        \"dst_mesh_shape\": dst_mesh_shape,\n        \"src_sharding_spec\": str(src_sharding_spec),\n        \"dst_sharding_spec\": str(dst_sharding_spec),\n        \"tensor_shape\": tensor_shape,\n        \"resharding_mode\": resharding_mode,\n        \"use_local_allgather\": use_local_allgather,\n        \"resharding_loadbalance_mode\": resharding_loadbalance_mode,\n        \"exec_time_mean\": mean_time,\n        \"exec_time_var\": var_time\n    }\n\n    # Delete executables\n    for worker in src_mesh.workers:\n        worker.delete_executable.remote(exec_uuids[worker])\n    for worker in dst_mesh.workers:\n        worker.delete_executable.remote(exec_uuids[worker])\n\n    src_mesh.shutdown()\n    dst_mesh.shutdown()\n\n    return result\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--suite\",\n                        type=str,\n                        required=True,\n                        choices=[\"1-to-m\", \"n-to-m\"])\n    parser.add_argument(\"--case\", type=str)\n    
parser.add_argument(\"--n-nodes\", type=int, default=1)\n    parser.add_argument(\"--gpu-per-node\", type=int, default=1)\n    parser.add_argument(\"--resharding-mode\",\n                        type=str,\n                        required=True,\n                        choices=[\"send_recv\", \"broadcast\"])\n    parser.add_argument(\"--resharding-loadbalance-mode\",\n                        type=str,\n                        required=True,\n                        choices=[\n                            \"normal\", \"no_loadbalance\", \"loadbalance_size\",\n                            \"loadbalance_order\"\n                        ])\n    parser.add_argument(\"--use-local-allgather\", action=\"store_true\")\n    parser.add_argument(\"--disable-tqdm\", action=\"store_true\")\n    args = parser.parse_args()\n\n    if args.suite == \"1-to-m\":\n        case = suite.perf_1_to_m_suite[(args.n_nodes, args.gpu_per_node)]\n    else:\n        case = suite.perf_n_to_m_suite[args.case]\n\n    result = benchmark_one_case_internal(\n        case.src_mesh_shape, case.dst_mesh_shape, case.src_sharding_spec,\n        case.dst_sharding_spec, case.tensor_shape, args.resharding_mode,\n        args.use_local_allgather, args.resharding_loadbalance_mode)\n    print(result)\n\n# python benchmark_cross_mesh_resharding.py --case case1 --resharding-mode broadcast --resharding-loadbalance-mode normal\n"
  },
  {
    "path": "benchmark/alpa/resharding/suite.py",
    "content": "\"\"\"Benchmark suites for cross mesh resharding microbenchmarks.\"\"\"\nfrom collections import namedtuple\nfrom jax.interpreters.pxla import (Chunked, NoSharding, Replicated, ShardedAxis,\n                                   ShardingSpec)\n\nBenchmarkCase = namedtuple(\"BenchmarkCase\", [\n    \"src_mesh_shape\", \"dst_mesh_shape\", \"tensor_shape\", \"src_sharding_spec\",\n    \"dst_sharding_spec\"\n])\n\nperf_n_to_m_suite = {\n    \"case1\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n        ),\n    \"case2\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec(\n                [NoSharding(), NoSharding(),\n                 NoSharding()], [Replicated(8)]),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n        ),\n    \"case3\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec(\n                [NoSharding(), Chunked([2]),\n                 NoSharding()], [ShardedAxis(0), Replicated(4)]),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n        ),\n    \"case4\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec(\n                
[NoSharding(), Chunked([8]),\n                 NoSharding()], [ShardedAxis(0)]),\n            ShardingSpec([Chunked(\n                [8]), NoSharding(), NoSharding()], [ShardedAxis(0)]),\n        ),\n    \"case5\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec([Chunked(\n                [4]), NoSharding(), NoSharding()],\n                         [Replicated(2), ShardedAxis(0)]),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n        ),\n    \"case6\":\n        BenchmarkCase(\n            (2, 4),\n            (3, 4),\n            # (1024*3//8, 1024, 170),\n            (1024 * 3, 1024, 170),\n            ShardingSpec([Chunked(\n                [2]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n            ShardingSpec([Chunked(\n                [3]), NoSharding(), NoSharding()],\n                         [ShardedAxis(0), Replicated(4)]),\n        ),\n    \"case7\":\n        BenchmarkCase(\n            (1, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec([Chunked(\n                [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]),\n            ShardingSpec(\n                [NoSharding(), NoSharding(),\n                 NoSharding()], [Replicated(4)]),\n        ),\n    \"case8\":\n        BenchmarkCase(\n            (1, 4),\n            (2, 4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec([Chunked(\n                [4]), NoSharding(), NoSharding()], [ShardedAxis(0)]),\n            ShardingSpec(\n                [NoSharding(), NoSharding(),\n                 NoSharding()], [Replicated(4)]),\n        ),\n    \"case9\":\n        BenchmarkCase(\n            (2, 4),\n            (2, 
4),\n            # (1024 // 8, 1024, 512),\n            (1024, 1024, 512),\n            ShardingSpec(\n                [NoSharding(), Chunked([2]),\n                 NoSharding()], [ShardedAxis(0), Replicated(4)]),\n            ShardingSpec(\n                [NoSharding(), NoSharding(),\n                 Chunked([2])], [ShardedAxis(0), Replicated(4)]),\n        ),\n}\n\nresharding_n_to_m_configs = [\n    {\n        \"resharding_mode\": \"send_recv\",\n        \"resharding_loadbalance_mode\": \"normal\",\n        \"use_local_allgather\": False\n    },\n    {\n        \"resharding_mode\": \"send_recv\",\n        \"resharding_loadbalance_mode\": \"normal\",\n        \"use_local_allgather\": True\n    },\n    {\n        \"resharding_mode\": \"broadcast\",\n        \"resharding_loadbalance_mode\": \"no_loadbalance\",\n        \"use_local_allgather\": False\n    },\n    {\n        \"resharding_mode\": \"broadcast\",\n        \"resharding_loadbalance_mode\": \"loadbalance_size\",\n        \"use_local_allgather\": False\n    },\n    {\n        \"resharding_mode\": \"broadcast\",\n        \"resharding_loadbalance_mode\": \"loadbalance_order\",\n        \"use_local_allgather\": False\n    },\n]\n\nperf_1_to_m_suite = {(n_node, gpu_per_node): BenchmarkCase(\n    (1, 1),\n    (n_node, gpu_per_node),\n    (1 << 28,),\n    ShardingSpec([NoSharding()], [Replicated(1)]),\n    ShardingSpec([NoSharding()], [Replicated(n_node * gpu_per_node)]),\n) for n_node, gpu_per_node in [(1, 1), (1, 2), (1, 3), (1, 4), (2,\n                                                                2), (3,\n                                                                     2), (4, 2)]\n                    }\n\nresharding_1_to_m_configs = [\n    {\n        \"resharding_mode\": \"send_recv\",\n        \"resharding_loadbalance_mode\": \"normal\",\n        \"use_local_allgather\": False\n    },\n    {\n        \"resharding_mode\": \"send_recv\",\n        \"resharding_loadbalance_mode\": \"normal\",\n        
\"use_local_allgather\": True\n    },\n    {\n        \"resharding_mode\": \"broadcast\",\n        \"resharding_loadbalance_mode\": \"normal\",\n        \"use_local_allgather\": False\n    },\n]\n"
  },
  {
    "path": "benchmark/alpa/run_exp.py",
    "content": "\"\"\"Run search experiments with mutliple cluster settings.\"\"\"\nimport argparse\nfrom datetime import datetime\nimport os\nimport subprocess\nimport sys\n\nfrom benchmark import benchmark_suite\n\n\ndef run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None):\n    os.environ[\"PYTHONUNBUFFERED\"] = \"1\"\n    now = datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n\n    tee = subprocess.Popen([\"tee\", f\"{now}_{suite_name}.log\"],\n                           stdin=subprocess.PIPE)\n    os.dup2(tee.stdin.fileno(), sys.stdout.fileno())\n    os.dup2(tee.stdin.fileno(), sys.stderr.fileno())\n\n    benchmark_settings = benchmark_settings or {}\n\n    for num_hosts, num_devices_per_host in cluster_settings:\n        num_gpus = num_hosts * num_devices_per_host\n        if exp_name is None:\n            exp_name = f\"{now}_{suite_name}_{num_gpus}_gpus\"\n        benchmark_suite(suite_name,\n                        num_hosts,\n                        num_devices_per_host,\n                        exp_name=exp_name,\n                        disable_tqdm=True,\n                        **benchmark_settings)\n\n\nmodel_search_suites = {\n    \"gpt\": (\"gpt.grid_search_auto\", {}),\n    \"moe\": (\"moe.grid_search_auto\", {}),\n    \"wresnet\": (\"wresnet.grid_search_auto\", {}),\n    \"gpt_inference\": (\"gpt_inference.profile\", {\n        \"niter\": 10,\n        \"profile_stage_execution_time\": True\n    }),\n    \"moe_inference\": (\"moe_inference.profile\", {\n        \"niter\": 10,\n        \"profile_stage_execution_time\": True\n    }),\n    \"gpt_no_embedding_inference\": (\"gpt_no_embedding_inference.profile\", {}),\n    \"gpt_inference_streaming\": (\"gpt_inference.profile\", {\n        \"profile_driver_time\": True\n    }),\n}\ncluster_settings = [(8, 8), (4, 8), (3, 8), (2, 8), (1, 8), (1, 4), (1, 2),\n                    (1, 1)]\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    
parser.add_argument(\"suite\", type=str, choices=model_search_suites.keys())\n    parser.add_argument(\"--exp-name\", type=str, default=None)\n    args = parser.parse_args()\n    run_exp(args.exp_name, cluster_settings, *model_search_suites[args.suite])\n"
  },
  {
    "path": "benchmark/alpa/suite_auto_gpt.py",
    "content": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_gpt import gpt_specs\nfrom benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs,\n                                      LoadSolutionParallelArgs)\n\nmax_global_batch_size = 1024\n\nauto_stage_option = {\n    \"submesh_physical_shape_space\": \"small_power_of_two\",\n    \"submesh_logical_shape_space\": \"all\",\n    \"stage_imbalance_tolerance\": 1.0,\n    \"use_hlo_cost_model\": True,\n    \"profiling_database_filename\": \"prof_database.pkl\",\n}\n\nprefer_reduce_scatter = True\nuse_remat = True\n\n\ndef get_search_cases(model_spec, num_micro_batches_list, num_auto_layers_list):\n    return [\n        BenchmarkCase(\n            max_global_batch_size, model_spec, num_micro_batches, \"search\",\n            SearchParallelArgs(prefer_reduce_scatter, use_remat,\n                               num_auto_layers, auto_stage_option))\n        for num_micro_batches in num_micro_batches_list\n        for num_auto_layers in num_auto_layers_list\n    ]\n\n\ndef get_solution_case(model_spec, num_micro_batches, num_auto_layers,\n                      forward_stage_layer_ids, submesh_physical_shapes,\n                      submesh_logical_shapes,\n                      submesh_autosharding_option_dicts):\n    return [\n        BenchmarkCase(\n            max_global_batch_size, model_spec, num_micro_batches,\n            \"load_solution\",\n            LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat,\n                                     num_auto_layers, forward_stage_layer_ids,\n                                     submesh_physical_shapes,\n                                     submesh_logical_shapes,\n                                     submesh_autosharding_option_dicts))\n    ]\n\n\nforce_dp_dict = {\"force_batch_dim_to_mesh_dim\": 0}\n\n# Temporary debug suite\ntmp_suite = {}\n\n# Performance test with search solutions found for 
p3.16xlarge\nperf_test_suite = {\n    1:\n        get_solution_case(gpt_specs[\"350M\"], 512, 1, [[0]], [(1, 1)], [(1, 1)],\n                          [{}]),\n    2:\n        get_solution_case(gpt_specs[\"760M\"], 128, 6, [[0, 1, 2], [3, 4, 5]],\n                          [(1, 1)] * 2, [(1, 1)] * 2, [force_dp_dict] * 2),\n    4:\n        get_solution_case(gpt_specs[\"1.3B\"], 128, 6, [[0, 1, 2], [3, 4, 5]],\n                          [(1, 2)] * 2, [(2, 1)] * 2, [force_dp_dict] * 2),\n    8:\n        get_solution_case(gpt_specs[\"2.6B\"], 128,\n                          8, [[0, 1], [2, 3], [4, 5, 6, 7]], [(1, 2), (1, 2),\n                                                              (1, 4)], [(2, 1),\n                                                                        (2, 1),\n                                                                        (4, 1)],\n                          [force_dp_dict, {}, {}]),\n    16:\n        get_solution_case(gpt_specs[\"6.7B\"], 64, 8,\n                          [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 8)] * 2,\n                          [(2, 4)] * 2, [force_dp_dict] * 2),\n    32:\n        get_solution_case(\n            gpt_specs[\"15B\"], 128, 16,\n            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],\n            [(1, 8)] * 4, [(2, 4)] * 4, [force_dp_dict] * 4),\n    64:\n        get_solution_case(gpt_specs[\"39B\"], 1024,\n                          16, [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9],\n                               [10], [11], [12], [13], [14], [15]],\n                          [(1, 4)] * 16, [(1, 4)] * 16, [force_dp_dict] * 16),\n}\n\n# Grid search on hyperparameters\ngrid_search_suite = {\n    2: (get_search_cases(gpt_specs[\"760M\"], [32, 64, 128, 256], [6]) +\n        get_search_cases(gpt_specs[\"760M\"], [32, 64], [12])),\n    4: (get_search_cases(gpt_specs[\"1.3B\"], [32, 64, 128], [6]) +\n        get_search_cases(gpt_specs[\"1.3B\"], [32, 64], [12])),\n    8: 
(get_search_cases(gpt_specs[\"2.6B\"], [64, 128, 256], [8]) +\n        get_search_cases(gpt_specs[\"2.6B\"], [64, 128], [16])),\n    16: get_search_cases(gpt_specs[\"6.7B\"], [32, 64, 128, 256], [8]),\n    32: get_search_cases(gpt_specs[\"15B\"], [64, 128, 256, 512], [16]),\n    64: get_search_cases(gpt_specs[\"39B\"], [128, 256, 512, 1024], [8]),\n}\n\n# Small test cases for correctness test\ncorrectness_test_suite = {\n    8: get_search_cases(gpt_specs[\"2.6B\"], [128], [8]),\n}\n"
  },
  {
    "path": "benchmark/alpa/suite_auto_moe.py",
    "content": "\"\"\"Benchmark suites for moe with auto parallelization.\"\"\"\nfrom suite_manual_moe import moe_specs\n# Share parallel options with the GPT suite\nfrom suite_auto_gpt import (get_search_cases, get_solution_case, force_dp_dict)\n\n# Temporary debug suite\ntmp_suite = {}\n\n# Performance test with search solutions found for p3.16xlarge\nperf_test_suite = {\n    1:\n        get_solution_case(moe_specs[\"380M\"], 512, 1, [[0]], [(1, 1)], [(1, 1)],\n                          [{}]),\n    2:\n        get_solution_case(moe_specs[\"690M\"], 32, 8, [[0, 1, 2, 3, 4, 5, 6, 7]],\n                          [(1, 2)], [(2, 1)], [force_dp_dict]),\n    4:\n        get_solution_case(moe_specs[\"1.3B\"], 32, 8,\n                          [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 2)] * 2,\n                          [(2, 1)] * 2, [force_dp_dict] * 2),\n    8:\n        get_solution_case(moe_specs[\"2.4B\"], 32, 8,\n                          [[0, 1, 2, 3], [4, 5, 6, 7]], [(1, 4)] * 2,\n                          [(4, 1)] * 2, [force_dp_dict] * 2),\n    16:\n        get_solution_case(moe_specs[\"10B\"], 16, 8, [[0, 1, 2, 3], [4, 5, 6, 7]],\n                          [(1, 8)] * 2, [(8, 1)] * 2, [{}] * 2),\n    32:\n        get_solution_case(moe_specs[\"27B\"], 128, 8,\n                          [[0], [1], [2], [3], [4], [5], [6], [7]],\n                          [(1, 4)] * 8, [(4, 1)] * 8, [{}] * 8),\n    64:\n        get_solution_case(moe_specs[\"70B\"], 64, 8,\n                          [[0], [1], [2], [3], [4], [5], [6], [7]],\n                          [(1, 8)] * 8, [(8, 1)] * 8, [{}] * 8),\n}\n\n# Grid search on hyperparameters\ngrid_search_suite = {\n    2: (get_search_cases(moe_specs[\"690M\"], [16, 32, 64], [8])),\n    4: (get_search_cases(moe_specs[\"1.3B\"], [16, 32, 64], [8])),\n    8: (get_search_cases(moe_specs[\"2.4B\"], [16, 32, 64], [8])),\n    16: (get_search_cases(moe_specs[\"10B\"], [16, 32, 64], [8])),\n    32: (get_search_cases(moe_specs[\"27B\"], [32, 64, 
128], [4, 8, 16])),\n    64: (get_search_cases(moe_specs[\"70B\"], [64], [8, 16, 32])),\n    # submesh_choices_mode: \"small_power_of_two\", max num_cpus = 20\n}\n"
  },
  {
    "path": "benchmark/alpa/suite_inference_gpt.py",
    "content": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_gpt import gpt_specs\nfrom benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs)\n\nprefer_reduce_scatter = True\nforce_batch_dim_mapping = True\nuse_remat = False\n\nprofile_suite = {}\nforce_dp_dict = {\"force_batch_dim_to_mesh_dim\": 0}\n\n\ndef get_config(model_config,\n               pp_list,\n               dp_list,\n               op_list,\n               num_micro_batch_config,\n               batch_size_config,\n               ignore_one_device_case=False):\n    for pp in pp_list:\n        for dp in dp_list:\n            for op in op_list:\n                num_gpus = pp * dp * op\n                if ignore_one_device_case and num_gpus == 1:\n                    continue\n                for bs in batch_size_config:\n                    for nb in num_micro_batch_config:\n                        total_bs = bs * nb\n                        if num_gpus not in profile_suite:\n                            profile_suite[num_gpus] = []\n                        parallel_args = UniformParallelArgs(\n                            prefer_reduce_scatter, use_remat, dp, op, pp,\n                            force_batch_dim_mapping)\n                        case = BenchmarkCase(total_bs, model_config, nb,\n                                             \"uniform\", parallel_args)\n                        profile_suite[num_gpus].append(case)\n\n\n## general examples:\n#get_config(gpt_specs[\"350M\"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n#get_config(gpt_specs[\"760M\"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n#get_config(gpt_specs[\"1.3B\"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n#get_config(gpt_specs[\"2.6B\"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n#get_config(gpt_specs[\"6.7B\"], [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n#get_config(gpt_specs[\"15B\"],  [1, 2, 4, 8], [1], [1], [1], [1, 4, 16])\n\n## benchmark specific parallel 
method:\n#get_config(gpt_specs[\"6.7B\"], [1], [1], [1, 2, 4, 8], [1, 256], [1, 4, 16, 64])\n#get_config(gpt_specs[\"6.7B\"], [1], [1, 2, 4, 8], [1], [1, 256], [1, 4, 16, 64],\n#           ignore_one_device_case=True)\n#get_config(gpt_specs[\"6.7B\"], [1, 2, 4, 8], [1], [1], [1, 256], [1, 4, 16, 64],\n#           ignore_one_device_case=True)\n\n## generate inference profiling results\nget_config(gpt_specs[\"1.3B\"], [1, 2, 4, 8], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(gpt_specs[\"2.6B\"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(gpt_specs[\"6.7B\"], [1, 2, 4, 8, 16, 32], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(gpt_specs[\"15B\"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\n"
  },
  {
    "path": "benchmark/alpa/suite_inference_moe.py",
    "content": "\"\"\"Benchmark suites for gpt with auto parallelization.\"\"\"\nfrom suite_manual_moe import moe_specs\nfrom benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs)\n\nprefer_reduce_scatter = True\nforce_batch_dim_mapping = True\nuse_remat = False\n\nprofile_suite = {}\nforce_dp_dict = {\"force_batch_dim_to_mesh_dim\": 0}\n\n\ndef get_config(model_config,\n               pp_list,\n               dp_list,\n               op_list,\n               num_micro_batch_config,\n               batch_size_config,\n               ignore_one_device_case=False):\n    for pp in pp_list:\n        for dp in dp_list:\n            for op in op_list:\n                num_gpus = pp * dp * op\n                if ignore_one_device_case and num_gpus == 1:\n                    continue\n                for bs in batch_size_config:\n                    for nb in num_micro_batch_config:\n                        total_bs = bs * nb\n                        if num_gpus not in profile_suite:\n                            profile_suite[num_gpus] = []\n                        parallel_args = UniformParallelArgs(\n                            prefer_reduce_scatter, use_remat, dp, op, pp,\n                            force_batch_dim_mapping)\n                        case = BenchmarkCase(total_bs, model_config, nb,\n                                             \"uniform\", parallel_args)\n                        profile_suite[num_gpus].append(case)\n\n\n## generate inference profiling results\nget_config(moe_specs[\"1.3B\"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(moe_specs[\"2.4B\"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(moe_specs[\"7.1B\"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\nget_config(moe_specs[\"10B\"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1],\n           [1, 2, 4, 8, 16])\n"
  },
  {
    "path": "benchmark/alpa/suite_manual_gpt.py",
    "content": "\"\"\"Benchmark suites for gpt with manual specifications.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_utils import BenchmarkCase, UniformParallelArgs\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size\n# head = num_heads,\n# NB = num_micro_batches, PM = parallel_mode\n# 3D config = 3D parallel config (Data, Operator, Pipeline)\n# RS = prefer_reduce_scatter, Remat = use_rematerialization,\n# FM = force_batch_dim_mapping,\n\nGPTModelConfig = namedtuple(\n    \"GPTModelConfig\",\n    [\"seq_len\", \"hidden_size\", \"num_layers\", \"num_heads\", \"vocab_size\"])\n\ngpt_specs = {\n    #                      S，   H,   L,  head,   V,\n    \"125M\": GPTModelConfig(1024, 768, 12, 12, 51200),\n    \"350M\": GPTModelConfig(1024, 1024, 24, 16, 51200),\n    \"760M\": GPTModelConfig(1024, 1536, 24, 16, 51200),\n    \"1.3B\": GPTModelConfig(1024, 2048, 24, 32, 51200),\n    \"2.6B\": GPTModelConfig(1024, 2560, 32, 32, 51200),\n    \"6.7B\": GPTModelConfig(1024, 4096, 32, 32, 51200),\n    \"15B\": GPTModelConfig(1024, 5120, 48, 40, 51200),\n    \"39B\": GPTModelConfig(1024, 8192, 48, 64, 51200),\n    \"76B\": GPTModelConfig(1024, 10240, 60, 80, 51200),\n}\n\n_ = None\n\n# Temporary debug suite\n# key = the number of gpus, value = a list of cases\n# B, model, NB, PM, (RS, Remat, 3D Config, FM)\ntmp_suite = {\n    1: [\n        BenchmarkCase(16, gpt_specs[\"350M\"], 1, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 1, 1, True))\n    ],\n    8: [\n        BenchmarkCase(128, GPTModelConfig(1024, 4096, 4, 32, 51200),\n                      4, \"uniform\",\n                      UniformParallelArgs(True, True, 4, 1, 2, True)),\n    ],\n}\n\n# Fast performance test on models with fewer layers\n# B, model, NB, PM, (RS, Remat, 3D Config, FM)\nperf_test_fast_2d_suite = {\n    1: [\n        BenchmarkCase(8, GPTModelConfig(1024, 1024, 4, 32, 51200), 1, \"uniform\",\n                      
UniformParallelArgs(False, True, 1, 1, 1, True))\n    ],\n    8: [\n        BenchmarkCase(32, GPTModelConfig(1024, 4096, 4, 32, 51200),\n                      1, \"uniform\",\n                      UniformParallelArgs(True, True, 8, 1, 1, True)),\n        BenchmarkCase(128, GPTModelConfig(1024, 4096, 4, 32, 51200),\n                      4, \"uniform\",\n                      UniformParallelArgs(True, True, 8, 1, 1, True)),\n    ],\n}\n\n# Performance test on normal models\n# B, model, NB, PM, (RS, Remat, 3D Config, FM)\nperf_test_suite = {\n    1: [\n        BenchmarkCase(16, gpt_specs[\"350M\"], 1, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 1, 1, True))\n    ],\n    4: [\n        BenchmarkCase(16 * 4, gpt_specs[\"1.3B\"], 1 * 4, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 2, 2, True)),\n    ],\n    8: [\n        BenchmarkCase(32, gpt_specs[\"2.6B\"], 4, \"uniform\",\n                      UniformParallelArgs(True, True, 2, 2, 2, True))\n        #BenchmarkCase(32 * 32, gpt_specs[\"2.6B\"], 2 * 32, \"uniform\",\n        #              UniformParallelArgs(True, True, 2, 2, 2, True)),\n        #BenchmarkCase(32 * 32, gpt_specs[\"2.6B\"], 4 * 32, \"uniform\",\n        #              UniformParallelArgs(True, True, 2, 1, 4, True))\n    ],\n    64: [\n        BenchmarkCase(1024, gpt_specs[\"39B\"], 1024, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 4, 16, True))\n    ],\n}\n"
  },
  {
    "path": "benchmark/alpa/suite_manual_moe.py",
    "content": "\"\"\"Benchmark suites for moe with manual specifications.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_utils import BenchmarkCase, UniformParallelArgs\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size\n# head = num_heads, S_ = expert_group_size, E = expert_number,\n# NB = num_micro_batches, PM = parallel_mode\n# 3D config = 3D parallel config (Data, Operator, Pipeline)\n# RS = prefer_reduce_scatter, Remat = use_rematerialization,\n# FM = force_batch_dim_mapping,\n\nMoEModelConfig = namedtuple(\"MoEModelConfig\", [\n    \"seq_len\", \"hidden_size\", \"num_layers\", \"num_heads\", \"vocab_size\",\n    \"num_experts\", \"expert_group_size\"\n])\n\nmoe_specs = {\n    #                      S,    H,   L, head, V,   E,  S_\n    \"380M\": MoEModelConfig(1024, 768, 8, 16, 32000, 8, 2048),\n    \"690M\": MoEModelConfig(1024, 768, 8, 16, 32000, 16, 2048),\n    \"1.3B\": MoEModelConfig(1024, 768, 16, 16, 32000, 16, 2048),\n    \"2.4B\": MoEModelConfig(1024, 1024, 16, 16, 32000, 16, 2048),\n    \"7.1B\": MoEModelConfig(1024, 1280, 16, 16, 32000, 32, 2048),\n    \"10B\": MoEModelConfig(1024, 1536, 16, 16, 32000, 32, 2048),\n    \"27B\": MoEModelConfig(1024, 2048, 16, 16, 32000, 48, 2048),\n    \"70B\": MoEModelConfig(1024, 2048, 32, 16, 32000, 64, 2048),\n    \"140B\": MoEModelConfig(1024, 2048, 32, 16, 32000, 128, 2048),\n}\n\n# Temporary debug suite\n# key = the number of gpus, value = a list of cases\n# B, model, NB, PM, RS, Remat, 3D Config, FM\ntmp_suite = {\n    1: [\n        BenchmarkCase(8, moe_specs[\"380M\"], 1, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 1, 1, False))\n    ],\n    8: [\n        BenchmarkCase(16, moe_specs[\"1.3B\"], 1, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 4, 2, False))\n    ],\n    16: [\n        # verify cost model vs. 
profiling\n        BenchmarkCase(1024, moe_specs[\"10B\"], 32, \"uniform\",\n                      UniformParallelArgs(True, True, 2, 8, 1, True))\n    ],\n}\n\n# Fast performance test on models with fewer layers\n# B, S, H, L,  #head, V, E, S_, NB, PM, Remat, RS, 3D Config, FM\nperf_test_fast_2d_suite = {\n    1: [\n        BenchmarkCase(8, MoEModelConfig(1024, 1024, 8, 32, 25600, 8, 1024),\n                      1, \"uniform\",\n                      UniformParallelArgs(True, True, 1, 1, 1, True)),\n    ],\n    8: [\n        BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024),\n                      1, \"uniform\",\n                      UniformParallelArgs(False, True, 8, 1, 1, False)),\n        BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024),\n                      1, \"uniform\",\n                      UniformParallelArgs(False, True, 4, 2, 1, False)),\n        BenchmarkCase(16, MoEModelConfig(1024, 1024, 4, 32, 25600, 32, 1024),\n                      1, \"uniform\",\n                      UniformParallelArgs(False, True, 2, 4, 1, False)),\n    ],\n}\n"
  },
  {
    "path": "benchmark/alpa/suite_unet.py",
    "content": "\"\"\"Suites for wresnet benchmarking.\"\"\"\nfrom collections import namedtuple\nimport numpy as np\n\nfrom benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs,\n                                      LoadSolutionParallelArgs)\n\nUNetModelConfig = namedtuple(\n    \"UNetModelConfig\",\n    [\"image_size\", \"channel_size\", \"block_cnt\", \"dtype\", \"num_layers\"])\n\n# block cnt->manual layers: {4: 13, }\nunet_specs = {\n    # #Params: sample size, first channel's size, block cnt, dtype\n    \"470M\": UNetModelConfig(32, 320, 4, np.float32, 13),\n    \"1B\": UNetModelConfig(32, 480, 4, np.float32, 13),\n    \"1.2B\": UNetModelConfig(32, 512, 4, np.float32, 13),\n    \"1.8B\": UNetModelConfig(32, 640, 4, np.float32, 13),\n    \"2B\": UNetModelConfig(32, 672, 4, np.float32, 13),\n}\n\nprefer_reduce_scatter = False\nuse_remat = True\nforce_batch_dim_mapping = False\n\nauto_stage_option = {\n    \"submesh_physical_shape_space\": \"small_power_of_two\",\n    \"submesh_logical_shape_space\": \"single_node_model_parallel\",\n    \"stage_imbalance_tolerance\": 0.25,\n    \"use_hlo_cost_model\": False,\n    \"profiling_database_filename\": None,\n}\n\n\ndef get_num_auto_layers(name):\n    return int(unet_specs[name].block_cnt * 1.5)\n\n\ndef get_search_cases(model_name, max_global_batch_size, num_micro_batches_list):\n    num_auto_layers = get_num_auto_layers(model_name)\n    return [\n        BenchmarkCase(\n            max_global_batch_size, unet_specs[model_name], num_micro_batches,\n            \"search\",\n            SearchParallelArgs(prefer_reduce_scatter, use_remat,\n                               num_auto_layers, auto_stage_option))\n        for num_micro_batches in num_micro_batches_list\n    ]\n\n\ndef get_solution_case(model_name, max_global_batch_size, num_micro_batches,\n                      forward_stage_layer_ids, submesh_physical_shapes,\n                      submesh_logical_shapes,\n                      
submesh_autosharding_option_dicts):\n    num_auto_layers = get_num_auto_layers(model_name)\n    return [\n        BenchmarkCase(\n            max_global_batch_size, unet_specs[model_name], num_micro_batches,\n            \"load_solution\",\n            LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat,\n                                     num_auto_layers, forward_stage_layer_ids,\n                                     submesh_physical_shapes,\n                                     submesh_logical_shapes,\n                                     submesh_autosharding_option_dicts))\n    ]\n\n\n# B = batch_size, I = image_size,\n# L = num_layers, C = num_base_channels, W = width_factor,\n# NB = num_micro_batches, PM = parallel_mode\n# L_Shape = logical_mesh_shape\n# RS = prefer_reduce_scatter, Remat = use_rematerialization,\n# FM = force_batch_dim_mapping,\n\nforce_dp_dict = {\"force_batch_dim_to_mesh_dim\": 0}\n\n# Performance test with shard parallel\ntmp_suite = {}\n\n# Performance test with shard parallel\n# key = the number of gpus, value = a list of cases\n# B,    I,   L,   C,   W, dtype,  NB, PM,          RS,    Remat, L_shape, FM\nperf_test_2d_suite = {}\n\n# Performance test with search solutions found for p3.16xlarge\nperf_test_auto_suite = {\n    2:\n        get_solution_case(\"470M\", 256, 4,\n                          [list(range(7)), list(range(7, 13))], [(1, 1)] * 2,\n                          [(1, 1)] * 2, [{}] * 2),\n    4:\n        get_solution_case(\"1B\", 2048, 32,\n                          [list(range(8)), list(range(8, 13))], [(1, 2)] * 2,\n                          [(1, 2)] * 2, [{}] * 2),\n    8:\n        get_solution_case(\"2B\", 2048, 32,\n                          [list(range(9)), list(range(9, 13))], [(1, 4)] * 2,\n                          [(1, 4)] * 2, [{}] * 2),\n}\n\n# Grid search on hyperparameters\n# key = the number of gpus, value = a list of cases\n# model_name, B, NB\ngrid_search_auto_suite = {\n    4: get_search_cases(\"1B\", 
256, [\n        16,\n    ])\n}\n"
  },
  {
    "path": "benchmark/alpa/suite_wresnet.py",
    "content": "\"\"\"Suites for wresnet benchmarking.\"\"\"\nfrom collections import namedtuple\nfrom benchmark_parallel_utils import (BenchmarkCase, SearchParallelArgs,\n                                      LoadSolutionParallelArgs,\n                                      ShardParallelArgs)\n\n# B = batch_size, I = image_size,\n# L = num_layers, C = num_base_channels, W = width_factor,\n# NB = num_micro_batches, PM = parallel_mode\n# L_Shape = logical_mesh_shape\n# RS = prefer_reduce_scatter, Remat = use_rematerialization,\n# FM = force_batch_dim_mapping,\n\nWResNetModelConfig = namedtuple(\n    \"WResNetModelConfig\",\n    [\"image_size\", \"num_layers\", \"num_channels\", \"width_factor\", \"dtype\"])\n\nwresnet_specs = {\n    #                      I,   L,   C,   W,  dtype,\n    \"250M\": WResNetModelConfig(224, 50, 160, 2, \"fp32\"),\n    \"500M\": WResNetModelConfig(224, 50, 224, 2, \"fp32\"),\n    \"1B\": WResNetModelConfig(224, 50, 320, 2, \"fp32\"),\n    \"2B\": WResNetModelConfig(224, 50, 448, 2, \"fp32\"),\n    \"4B\": WResNetModelConfig(224, 50, 640, 2, \"fp32\"),\n    \"6.8B\": WResNetModelConfig(224, 50, 320, 16, \"fp32\"),\n    \"13B\": WResNetModelConfig(224, 101, 320, 16, \"fp32\"),\n}\n\nprefer_reduce_scatter = True\nuse_remat = True\n\nauto_stage_option = {\n    \"submesh_physical_shape_space\": \"small_power_of_two\",\n    \"submesh_logical_shape_space\": \"single_node_model_parallel\",\n    \"stage_imbalance_tolerance\": 0.25,\n    \"use_hlo_cost_model\": False,\n    \"profiling_database_filename\": None,\n}\n\n\ndef get_num_auto_layers(model_name):\n    if wresnet_specs[model_name].num_layers == 50:\n        return 16  # number of residual blocks\n    elif wresnet_specs[model_name].num_layers == 101:\n        return 33\n    else:\n        raise ValueError(\"Unsupported number of layers: {}\".format(\n            wresnet_specs[model_name].num_layers))\n\n\ndef get_search_cases(model_name, max_global_batch_size, num_micro_batches_list):\n    
num_auto_layers = get_num_auto_layers(model_name)\n    return [\n        BenchmarkCase(\n            max_global_batch_size, wresnet_specs[model_name], num_micro_batches,\n            \"search\",\n            SearchParallelArgs(prefer_reduce_scatter, use_remat,\n                               num_auto_layers, auto_stage_option))\n        for num_micro_batches in num_micro_batches_list\n    ]\n\n\ndef get_solution_case(model_name, max_global_batch_size, num_micro_batches,\n                      forward_stage_layer_ids, submesh_physical_shapes,\n                      submesh_logical_shapes,\n                      submesh_autosharding_option_dicts):\n    num_auto_layers = get_num_auto_layers(model_name)\n    return [\n        BenchmarkCase(\n            max_global_batch_size, wresnet_specs[model_name], num_micro_batches,\n            \"load_solution\",\n            LoadSolutionParallelArgs(prefer_reduce_scatter, use_remat,\n                                     num_auto_layers, forward_stage_layer_ids,\n                                     submesh_physical_shapes,\n                                     submesh_logical_shapes,\n                                     submesh_autosharding_option_dicts))\n    ]\n\n\nforce_dp_dict = {\"force_batch_dim_to_mesh_dim\": 0}\n\n# Performance test with shard parallel\ntmp_suite = {}\n\n# Performance test with shard parallel\n# key = the number of gpus, value = a list of cases\n# B,    I,   L,   C,   W, dtype,  NB, PM,          RS,    Remat, L_shape, FM\nperf_test_2d_suite = {\n    1: [\n        BenchmarkCase(32, WResNetModelConfig(224, 50, 160, 2, \"fp32\"),\n                      1, \"2d_shard\",\n                      ShardParallelArgs(False, False, (1, 1), False)),\n        BenchmarkCase(1536, WResNetModelConfig(224, 50, 160, 2, \"fp32\"),\n                      48, \"2d_shard\",\n                      ShardParallelArgs(False, False, (1, 1), False)),\n    ],\n    4: [\n        BenchmarkCase(32, WResNetModelConfig(224, 50, 320, 2, 
\"fp32\"),\n                      1, \"2d_shard\",\n                      ShardParallelArgs(False, False, (4, 1), False)),\n        BenchmarkCase(1536, WResNetModelConfig(224, 50, 320, 2, \"fp32\"),\n                      48, \"2d_shard\",\n                      ShardParallelArgs(False, False, (4, 1), False)),\n        BenchmarkCase(64, WResNetModelConfig(224, 50, 320, 2, \"fp32\"),\n                      1, \"2d_shard\",\n                      ShardParallelArgs(False, False, (4, 1), False)),\n        BenchmarkCase(1536, WResNetModelConfig(224, 50, 320, 2, \"fp32\"),\n                      24, \"2d_shard\",\n                      ShardParallelArgs(False, False, (4, 1), False)),\n    ],\n    8: [\n        BenchmarkCase(64, WResNetModelConfig(224, 50, 320, 2, \"fp32\"),\n                      1, \"2d_shard\",\n                      ShardParallelArgs(False, False, (8, 1), False)),\n    ],\n}\n\n# Performance test with search solutions found for p3.16xlarge\nperf_test_auto_suite = {\n    1:\n        get_solution_case(\"250M\", 1536, 24, [list(range(16))], [(1, 1)],\n                          [(1, 1)], [{}]),\n    2:\n        get_solution_case(\"500M\", 1536, 24, [list(range(16))], [(1, 2)],\n                          [(1, 2)], [{}]),\n    4:\n        get_solution_case(\"1B\", 1536, 24, [list(range(16))], [(1, 4)], [(1, 4)],\n                          [{}]),\n    8:\n        get_solution_case(\n            \"2B\", 1536, 24,\n            [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15]],\n            [(1, 4), (1, 4)], [(4, 1), (1, 4)], [{}, force_dp_dict]),\n    16:\n        get_solution_case(\n            \"4B\", 1536, 32,\n            [[0, 1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15]],\n            [(1, 4), (1, 4),\n             (1, 8)], [(4, 1), (4, 1),\n                       (8, 1)], [force_dp_dict, force_dp_dict, {}]),\n    32:\n        get_solution_case(\n            \"6.8B\", 1536,\n            32, [[0, 1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 
12], [13, 14, 15]],\n            [(1, 8), (1, 8), (1, 8),\n             (1, 8)], [(8, 1), (8, 1), (8, 1),\n                       (8, 1)], [force_dp_dict, {}, {}, {}]),\n    64:\n        get_solution_case(\n            \"13B\", 1520, 38,\n            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15],\n             [16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27, 28],\n             [29, 30, 31, 32]], [(1, 8), (1, 8), (1, 8), (1, 8), (1, 8), (1, 8),\n                                 (1, 8), (1, 8)],\n            [(8, 1), (1, 8), (8, 1), (1, 8), (8, 1), (8, 1), (1, 8),\n             (8, 1)],\n            [{}, force_dp_dict, {}, force_dp_dict, {}, {}, force_dp_dict, {}]),\n}\n\n# Grid search on hyperparameters\n# key = the number of gpus, value = a list of cases\ngrid_search_auto_suite = {\n    1: get_search_cases(\"250M\", 1536, [24, 32]),\n    2: get_search_cases(\"500M\", 1536, [24, 32]),\n    4: get_search_cases(\"1B\", 1536, [24, 32]),\n    8: get_search_cases(\"2B\", 1536, [24, 32]),\n    16: get_search_cases(\"4B\", 1536, [24, 32]),\n    32: (get_search_cases(\"6.8B\", 1520, [38]) +\n         get_search_cases(\"6.8B\", 1512, [42])),\n    64: get_search_cases(\"13B\", 1520, [38]),\n}\n"
  },
  {
    "path": "benchmark/alpa/util.py",
    "content": "import os\nimport time\n\nimport numpy as np\n\nGB = 1 << 30\n\n\ndef write_tsv(heads, values, filename, print_line=True):\n    \"\"\"Write tsv data to a file.\"\"\"\n    assert len(heads) == len(values)\n\n    values = [str(x) for x in values]\n\n    with open(filename, \"a\") as fout:\n        fout.write(\"\\t\".join(values) + \"\\n\")\n\n    if print_line:\n        line = \"\"\n        for i in range(len(heads)):\n            line += heads[i] + \": \" + values[i] + \"  \"\n        print(line)\n\n\ndef benchmark_func(run_func, sync_func=None, warmup=1, repeat=3, number=5):\n    \"\"\"Benchmark the execution time of a function.\"\"\"\n    costs = []\n\n    # Warmup\n    for i in range(warmup):\n        run_func()\n\n    # Benchmark\n    for i in range(repeat):\n        if sync_func:\n            sync_func()\n        tic = time.time()\n\n        for j in range(number):\n            run_func()\n\n        if sync_func:\n            sync_func()\n        costs.append(time.time() - tic)\n\n    return np.array(costs) / number\n\n\ndef run_cmd(cmd):\n    print(cmd)\n    return os.system(cmd)\n\n\ndef get_torch_memory_usage(print_info=False):\n    \"\"\"Get accurate gpu memory usage by querying torch runtime\"\"\"\n    import torch\n    allocated = torch.cuda.memory_allocated(0)\n    reserved = torch.cuda.memory_reserved(0)\n    if print_info:\n        print(\"allocated: %.2f GB\" % (allocated / GB), flush=True)\n        print(\"reserved:  %.2f GB\" % (reserved / GB), flush=True)\n    return allocated\n\n\ndef compute_gpt_tflops(batch_size,\n                       seq_len,\n                       num_layers,\n                       hidden_size,\n                       vocab_size,\n                       num_gpus,\n                       latency,\n                       backward=True,\n                       checkpoint_activations=False):\n    factor = 24\n    if backward:\n        factor += 48\n    if checkpoint_activations:\n        factor += 24\n\n    
total_flop = factor * batch_size * seq_len * (hidden_size ** 2) * num_layers * \\\n          (1 + seq_len / (6 * hidden_size)) \\\n          + 6 * batch_size * seq_len * hidden_size * vocab_size\n    # Note: The above formula does not count the first embedding table lookup\n    # because it is a sparse operation.\n    # If we use dense dot to compute the first embedding table lookup,\n    # then the last term in total_flops should be\n    # \"+ 10 * batch_size * seq_len * hidden_size * vocab_size\".\n    tflops = total_flop / latency / num_gpus / 1e12\n    return tflops\n\n\ndef compute_moe_tflops(batch_size,\n                       seq_len,\n                       num_layers,\n                       hidden_size,\n                       group_size,\n                       vocab_size,\n                       num_expert,\n                       num_gpus,\n                       latency,\n                       mlp_factor=8,\n                       checkpoint_activations=False):\n    factor = 4 if checkpoint_activations else 3\n    # num_layers / 2 attention block\n    pure_transformer = batch_size * seq_len * (hidden_size ** 2) * (8 + 4 * mlp_factor) +\\\n        4 * batch_size * (seq_len ** 2) * hidden_size\n    pure_transformer = pure_transformer * factor\n\n    # num_layers / 2 attention-moe block\n    # transformer\n    moe_transformer = batch_size * seq_len * (hidden_size ** 2) * 8  +\\\n        4 * batch_size * (seq_len ** 2) * hidden_size\n    # expert FFNs:\n    # moe_transformer += 2 * batch_size * seq_len * (hidden_size ** 2) * mlp_factor * 2\n    moe_transformer += 8 * batch_size * seq_len * (hidden_size**2) * mlp_factor\n\n    # softmax\n    moe_transformer += 2 * batch_size * seq_len * hidden_size * num_expert\n    # top-2 gating\n    moe_transformer += 2 * (batch_size * seq_len) * 2 * group_size\n    # dispatch + combine\n    moe_transformer += 2 * batch_size * seq_len * hidden_size * 2 * group_size * 2\n\n    moe_transformer = moe_transformer * 
factor\n\n    # vocab\n    embedding = 6 * batch_size * seq_len * hidden_size * vocab_size\n\n    total_flop = pure_transformer * num_layers / 2 + \\\n                 moe_transformer * num_layers / 2 + embedding\n    tflops = total_flop / latency / num_gpus / 1e12\n    return tflops\n\n\ndef compute_gpt_parameter_count(num_layers, hidden_size, vocab_size):\n    return num_layers * (\n        # self-attention\n        hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) +\n        # mlp\n        hidden_size * (4 * hidden_size + 1) + hidden_size * 4 *\n        (hidden_size + 1) +\n        # layer norm\n        hidden_size * 4) + vocab_size * (hidden_size + 1)\n\n\ndef compute_moe_parameter_count(num_layers,\n                                hidden_size,\n                                vocab_size,\n                                num_expert,\n                                mlp_factor=8,\n                                tie_embedding=True):\n    pure_transformer = \\\n        hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) + \\\n        hidden_size * (mlp_factor * hidden_size + 1) + hidden_size * mlp_factor * (hidden_size + 1) + \\\n        hidden_size * 4\n    moe_transformer = \\\n        hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) + \\\n        num_expert * (hidden_size * (mlp_factor * hidden_size + 1) + hidden_size * mlp_factor * (hidden_size + 1)) + \\\n        hidden_size * 4\n\n    # embedding\n    embedding_factor = 1 if tie_embedding else 2\n    embedding = embedding_factor * vocab_size * (hidden_size + 1)\n\n    if num_expert == 1:\n        return pure_transformer * num_layers + embedding\n    else:\n        half = num_layers / 2\n        return half * pure_transformer + half * moe_transformer + embedding\n"
  },
  {
    "path": "benchmark/cupy/profile_communication.py",
    "content": "\"\"\"\nBenchmark the communication bandwidth with Ray + NCCL.\nWe use the python binding cupy.nccl to call NCCL.\n\nUsage:\n  python3 profile_communication.py\n\"\"\"\n\nimport argparse\nimport time\nimport os\n\nimport cupy as cp\nfrom cupy.cuda import nccl\nimport numpy as np\nimport ray\n\nMB = 1 << 20\nGB = 1 << 30\n\n\ndef do_all_reduce(comm, in_buffer, out_buffer):\n    comm.allReduce(\n        in_buffer.data.ptr,\n        out_buffer.data.ptr,\n        in_buffer.size,\n        nccl.NCCL_FLOAT32,\n        0,\n        cp.cuda.Stream.null.ptr,\n    )\n\n\ndef do_all_gather(comm, in_buffer, out_buffer):\n    comm.allGather(\n        in_buffer.data.ptr,\n        out_buffer.data.ptr,\n        in_buffer.size,\n        nccl.NCCL_FLOAT32,\n        cp.cuda.Stream.null.ptr,\n    )\n\n\ndef do_send_recv(comm, buf, is_sender):\n    if is_sender:\n        comm.send(buf.data.ptr, buf.size, nccl.NCCL_FLOAT32,\n                  1, cp.cuda.Stream.null.ptr)\n    else:\n        comm.recv(buf.data.ptr, buf.size, nccl.NCCL_FLOAT32,\n                  0, cp.cuda.Stream.null.ptr)\n\n\n@ray.remote(num_gpus=1)\nclass GpuHost:\n    def __init__(self, rank, world_size, nccl_uuid_list):\n        self.rank = rank\n        self.world_size = world_size\n        self.nccl_uuid_list = nccl_uuid_list\n        self.ct = 0\n\n    def init_communicator(self, groups):\n        if np.max(groups) >= self.world_size:\n            return None\n        if len(set(np.ravel(groups))) < len(np.ravel(groups)):\n            return None\n\n        comm = None\n        for group in groups:\n            nccl_uuid = self.nccl_uuid_list[self.ct]\n            self.ct += 1\n            for device_id in group:\n                if self.rank == device_id:\n                    assert comm is None\n                    comm = cp.cuda.nccl.NcclCommunicator(\n                        len(group), nccl_uuid, group.index(self.rank))\n\n        cp.cuda.Device(0).synchronize()\n        return comm\n\n    def 
profile_allreduce(self, size, dtype, groups):\n        comm = self.init_communicator(groups)\n        if comm is None:\n            return\n\n        in_buffer = cp.ones(int(size), dtype)\n        out_buffer = cp.ones(int(size), dtype)\n\n        do_all_reduce(comm, in_buffer, out_buffer)\n        do_all_reduce(comm, in_buffer, out_buffer)\n\n        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)\n        cp.cuda.Device(0).synchronize()\n        tic = time.time()\n        for i in range(number):\n            do_all_reduce(comm, in_buffer, out_buffer)\n        cp.cuda.Device(0).synchronize()\n        toc = time.time()\n\n        if self.rank == 0:\n            num_devices = len(groups[0])\n            time_cost = (toc - tic) / number\n            array_size = size * dtype().nbytes\n            communication_size = 2 * array_size * (num_devices - 1) / num_devices\n            bandwidth = communication_size / time_cost\n            print(f\"AllReduce: {groups}\\tBytes: {array_size / GB:.5f} GB\\t\"\n                  f\"Time: {time_cost:.5f} s\\tBandwidth: {bandwidth / (1<<30):.2f} GB/s\")\n\n    def profile_allgather(self, size, dtype, groups):\n        comm = self.init_communicator(groups)\n        if comm is None:\n            return\n\n        in_buffer = cp.ones(int(size) // len(groups[0]), dtype)\n        out_buffer = cp.ones(int(size), dtype)\n\n        do_all_gather(comm, in_buffer, out_buffer)\n\n        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)\n        cp.cuda.Device(0).synchronize()\n        tic = time.time()\n        for i in range(number):\n            do_all_gather(comm, in_buffer, out_buffer)\n        cp.cuda.Device(0).synchronize()\n        toc = time.time()\n\n        if self.rank == 0:\n            num_devices = len(groups[0])\n            time_cost = (toc - tic) / number\n            array_size = size * dtype().nbytes\n            communication_size = array_size * (num_devices - 1) / 
num_devices\n            bandwidth = communication_size / time_cost\n            print(f\"AllGather: {groups}\\tBytes: {array_size / GB:.5f} GB\\t\"\n                  f\"Time: {time_cost:.5f} s\\tBandwidth: {bandwidth / (1<<30):.2f} GB/s\")\n\n    def profile_send_recv(self, size, dtype, from_rank, to_rank):\n        groups = [[from_rank, to_rank]]\n        comm = self.init_communicator(groups)\n        if comm is None:\n            return\n\n        buf = cp.ones(int(size), dtype)\n        do_send_recv(comm, buf, self.rank == from_rank)\n        do_send_recv(comm, buf, self.rank == from_rank)\n\n        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)\n        cp.cuda.Device(0).synchronize()\n        tic = time.time()\n        for i in range(number):\n            do_send_recv(comm, buf, self.rank == from_rank)\n        cp.cuda.Device(0).synchronize()\n        toc = time.time()\n\n        if self.rank == from_rank:\n            time_cost = (toc - tic) / number\n            array_size = size * dtype().nbytes\n            communication_size = array_size\n            bandwidth = communication_size / time_cost\n            print(f\"SendRecv: {groups}\\tBytes: {array_size / GB:.5f} GB\\t\"\n                  f\"Time: {time_cost:.5f} s\\tBandwidth: {bandwidth / (1<<30):.2f} GB/s\")\n\n    def profile_multi_send_recv(self, size, dtype, groups):\n        comm = self.init_communicator(groups)\n        time.sleep(1)\n        comm_sync = self.init_communicator([list(np.ravel(groups))])\n        if comm is None or comm_sync is None:\n            return\n\n        assert all(len(group) == 2 for group in groups)\n\n        senders = set(group[0] for group in groups)\n        receivers = set(group[1] for group in groups)\n\n        buf = cp.ones(int(size), dtype)\n        buf_sync = cp.ones(1, dtype)\n\n        do_send_recv(comm, buf, self.rank in senders)\n        do_send_recv(comm, buf, self.rank in senders)\n        do_all_reduce(comm_sync, buf_sync, 
buf_sync)\n\n        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)\n        cp.cuda.Device(0).synchronize()\n        tic = time.time()\n        for i in range(number):\n            do_send_recv(comm, buf, self.rank in senders)\n        do_all_reduce(comm_sync, buf_sync, buf_sync)\n        cp.cuda.Device(0).synchronize()\n        toc = time.time()\n\n        if self.rank == groups[0][0]:\n            time_cost = (toc - tic) / number\n            array_size = size * dtype().nbytes\n            communication_size = array_size\n            bandwidth = len(groups) * communication_size / time_cost\n            print(f\"SendRecv: {groups}\\tBytes: {array_size / GB:.5f} GB\\t\"\n                  f\"Time: {time_cost:.5f} s\\tBandwidth: {bandwidth / (1<<30):.2f} GB/s\")\n\n    def profile(self):\n        # All-reduce\n        for i in range(29, 30):\n            self.profile_allreduce(1 << i, cp.float32, [list(range(self.world_size))])\n            self.profile_allreduce(1 << i, cp.float32, [list(range(self.world_size//2))])\n\n            #self.profile_allreduce(1 << i, cp.float32, [[0, 3]])\n            #self.profile_allreduce(1 << i, cp.float32, [[0, 4], [1, 5], [2, 6], [3, 7]])\n            #self.profile_allreduce(1 << i, cp.float32, [[0, 2, 4, 6], [1, 3, 5, 7]])\n            #self.profile_allreduce(1 << i, cp.float32, [[0, 1, 2, 3], [4, 5, 6, 7]])\n            #self.profile_allreduce(1 << i, cp.float32, [[0, 1, 2, 3, 4, 5, 6, 7]])\n\n        # single Send-recv\n        for i in range(29, 30):\n            self.profile_send_recv(1 << i, cp.float32, 0, 1)\n            self.profile_send_recv(1 << i, cp.float32, 0, self.world_size - 1)\n\n        # multiple p2p Send-recv\n        for i in range(29, 30):\n            self.profile_multi_send_recv(1 << i, cp.float32, [[0, 1], [2, 3]])\n            self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 4], [1, self.world_size - 3]])\n            self.profile_multi_send_recv(1 << i, 
cp.float32, [[0, self.world_size - 2], [1, self.world_size - 1]])\n            self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 4], [1, self.world_size - 3], [2, self.world_size - 2], [3, self.world_size - 1]])\n            self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 8], [1, self.world_size - 7], [2, self.world_size - 6], [3, self.world_size - 5]])\n            self.profile_multi_send_recv(1 << i, cp.float32, [[0, self.world_size - 8], [1, self.world_size - 7], [2, self.world_size - 6], [3, self.world_size - 5],\n                                                              [4, self.world_size - 4], [5, self.world_size - 3], [6, self.world_size - 2], [7, self.world_size - 1]])\n\n    def sync(self):\n        return\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--efa\", action=\"store_true\",\n        help=\"Use AWS EFS on p3.24 or p4.24 instances\")\n    parser.add_argument(\"--ib\", action=\"store_true\",\n        help=\"Use InfiniBand for NCCL communcation\")\n    parser.add_argument(\"--debug\", action=\"store_true\",\n        help=\"Print nccl debug information\")\n    args = parser.parse_args()\n\n    ray.init(address=\"auto\")\n    num_gpus = int(ray.cluster_resources()[\"GPU\"])\n\n    nccl_uuid_list = [cp.cuda.nccl.get_unique_id() for _ in range(500)]\n\n    workers = []\n    for i in range(num_gpus):\n        if args.efa:\n            env_vars = {\n                \"FI_PROVIDER\": \"efa\",\n                \"FI_EFA_USE_DEVICE_RDMA\": \"1\",\n                \"LD_LIBRARY_PATH\": os.environ.get(\"LD_LIBRARY_PATH\", \"\"),  # For libnccl-net.so\n                \"NCCL_PROTO\": \"simple\",\n            }\n        elif args.ib:\n            env_vars = {\n                \"NCCL_SOCKET_NTHREADS\": \"4\",\n                \"NCCL_NSOCKS_PERTHREAD\": \"4\",\n                \"NCCL_IB_HCA\": \"mlx5,ibp\",  # Change this to align with your IB interface name\n      
          \"LD_LIBRARY_PATH\": os.environ.get(\"LD_LIBRARY_PATH\", \"\"),\n            }\n        else:\n            env_vars = {\n                \"NCCL_SOCKET_NTHREADS\": \"4\",\n                \"NCCL_NSOCKS_PERTHREAD\": \"4\",\n                \"LD_LIBRARY_PATH\": os.environ.get(\"LD_LIBRARY_PATH\", \"\"),\n            }\n\n        if args.debug:\n            env_vars[\"NCCL_DEBUG\"] = \"INFO\"\n\n        workers.append(GpuHost.options(runtime_env={\"env_vars\": env_vars})\\\n                              .remote(i, num_gpus, nccl_uuid_list))\n\n    ray.get([w.profile.remote() for w in workers])\n    ray.get([w.sync.remote() for w in workers])\n"
  },
  {
    "path": "benchmark/cupy/profile_matmul.py",
    "content": "\"\"\"Profile peak TFLOPS on matrix multiplications.\"\"\"\nimport time\nimport cupy as cp\n\ndef benchmark(n, k, m, dtype, init_method=\"ones\"):\n    warmup = 5\n    number = 50\n\n    if init_method == \"zeros\":\n        a = cp.zeros((n, k), dtype)\n        b = cp.zeros((k, m), dtype)\n    elif init_method == \"full\":\n        a = cp.full((n, k), 1e-7, dtype)\n        b = cp.full((k, m), 1e-7, dtype)\n    elif init_method == \"nans\":\n        a = cp.full((n, k), cp.nan, dtype)\n        b = cp.full((k, m), cp.nan, dtype)\n    elif init_method == \"ones\":\n        a = cp.ones((n, k), dtype)\n        b = cp.ones((k, m), dtype)\n    elif init_method == \"ones+randn\":\n        a = cp.ones((n, k), dtype)\n        b = cp.ones((k, m), dtype)\n        ratio = 2\n        a[0:n//ratio, :] = cp.random.randn(n//ratio, k).astype(dtype)\n        b[0:k//ratio, :] = cp.random.randn(k//ratio, m).astype(dtype)\n    elif init_method == \"randn\":\n        a = cp.random.randn(n, k).astype(dtype)\n        b = cp.random.randn(k, m).astype(dtype)\n    elif init_method == \"uniform\":\n        a = cp.random.uniform(-1, 1, (n, k)).astype(dtype)\n        b = cp.random.uniform(-1, 1, (k, m)).astype(dtype)\n    elif init_method == \"uniform+\":\n        a = cp.random.uniform(0, 1, (n, k)).astype(dtype)\n        b = cp.random.uniform(0, 1, (k, m)).astype(dtype)\n    else:\n        raise ValueError(f\"Invalid method: {init_method}\")\n    for i in range(warmup):\n        c = a @ b\n\n    cp.cuda.Device(0).synchronize()\n    tic = time.time()\n    for i in range(number):\n        cp.dot(a, b, c)\n    cp.cuda.Device(0).synchronize()\n    toc = time.time()\n\n    total_flops = 2 * n * k * m\n    cost = (toc - tic) / number\n    shape = (n, k, m, dtype)\n\n    print(f\"shape: {shape}, init_method: {init_method:>8}, \"\n          f\"TFLOP: {total_flops / 1e12:.2f}, \"\n          f\"cost: {cost:3f}, \"\n          f\"TFLOPS : {total_flops / cost / 1e12:.2f}\"\"\")\n\n\nfor n in 
[8192]:\n    for init_method in [\"nans\", \"full\", \"zeros\", \"ones\",\n                        \"randn\", \"uniform\", \"uniform+\", \"ones+randn\"]:\n        benchmark(n, n, n, \"float16\", init_method)\n"
  },
  {
    "path": "benchmark/deepspeed/README.md",
    "content": "# Benchmark DeepSpeed\n\n## Requirements\n1. Install dependencies\n```\n# torch\npip3 install torch==1.8.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html\npip3 install nltk pandas sentencepiece boto3 pybind11 python-config\n\n# Adafactor optimizer\npip3 install torch-optimizer\n\n# pdsh\nsudo apt-get update\nsudo apt-get install pdsh\n\n# Apex\ngit clone https://github.com/NVIDIA/apex\ncd apex\n# Comment out the raised RuntimeError in setup.py if you get errors running the following command.\npip3 install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n```\n\n2. Install DeepSpeed and the DeepSpeed examples\n```\npip3 install deepspeed==0.5.4\ngit clone --recursive https://github.com/microsoft/DeepSpeed.git\necho 'export DEEPSPEED_PATH=~/efs/DeepSpeed' >> ~/.bashrc   # use your own path\nsource ~/.bashrc\n\n# Replace source files (use your own path)\ncp alpa/benchmark/deepspeed/patch/training.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/training.py\ncp alpa/benchmark/deepspeed/patch/gpt2_model.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/gpt2_model.py\ncp alpa/benchmark/deepspeed/patch/transformer.py DeepSpeed/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/megatron/model/transformer.py\n```\n\n3. Download the dataset\n```\nwget deepspeed_dataset.zip  # ask Lianmin to get the file\ntar xzf deepspeed_dataset.zip\ncd deepspeed_dataset/\nln -s $(pwd) ~/efs/alpa/benchmark/deepspeed/data   # use your own path\n```\n\n## Run\n### Single Node\n```\n# GPT\npython3 benchmark_gpt2.py --nproc_per_node 8\n# MOE\npython3 benchmark_gpt2_moe.py --nproc_per_node 8\n```\n\n### Multiple Nodes\n- Modify the [hostfile](https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node) and set up the SSH connections.\n```\npython3 benchmark_gpt2.py --nnodes 2 --nproc_per_node 8\n```\n"
  },
  {
    "path": "benchmark/deepspeed/benchmark_gpt2.py",
    "content": "import argparse\nimport os\nimport random\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size,\n# #head = num_heads, DP = dp_size, TMP = tensor_mp_size, NB = num_micro_batches,\n# CK = checkpoint_activations, DS = use_deepspeed\n\nbenchmark_suite_1_gpu = [\n    #B,    S,    H,    L,  #head,     V,     DP, TMP, NB, CK, DS\n    (16,   512,  1024, 10, 1024//64,  25600, 1,  1,   1,  0,  1),\n    (8,    1024, 1536, 10, 1536//96,  25600, 1,  1,   1,  0,  1),\n]\n\nbenchmark_suite_8_gpu = [\n    #B,    S,    H,    L,  #head,     V,     DP, TMP, NB, CK, DS\n    (256,  512,  1024, 10, 1024//64,  25600, 8,  1,   1,  0,  1),\n    (8,    1024, 4096, 10, 4096//128, 25600, 1,  8,   1,  0,  1),\n    (8,    1024, 4096, 10, 4096//128, 25600, 8,  1,   1,  0,  1),\n]\n\nbenchmark_suite_16_gpu = [\n    #B,    S,    H,    L,  #head,     V,     DP, TMP, NB, CK, DS\n    (512,  512,  1024, 10, 1024//64,  25600, 16, 1,   1,  0,  1),\n    (2048, 512,  1024, 10, 1024//64,  25600, 16, 1,   4,  0,  1),\n    (16,   1024, 4096, 10, 4096//128, 25600, 2,  8,   1,  0,  1),\n    (64,   1024, 4096, 10, 4096//128, 25600, 2,  8,   4,  0,  1),\n    (16,   1024, 4096, 10, 4096//128, 25600, 16, 1,   1,  0,  1),\n    (64,   1024, 4096, 10, 4096//128, 25600, 16, 1,   4,  0,  1),\n]\n\n\ndef update_ds_config(filename, gradient_accumulation_steps):\n    lines = list(open(filename))\n\n    for i in range(len(lines)):\n        if \"gradient_accumulation_steps\" in lines[i]:\n            idx = lines[i].index(\":\")\n            lines[i] = lines[i][:idx] + f\": {gradient_accumulation_steps},\\n\"\n\n    with open(filename, \"w\") as fout:\n        fout.writelines(lines)\n\n\ndef benchmark_all(args):\n    num_gpus = args.nproc_per_node * args.nnodes\n\n    benchmark_suites = {\n        1 : benchmark_suite_1_gpu,\n        8 : benchmark_suite_8_gpu,\n        16 : benchmark_suite_16_gpu,\n    }\n\n    warmup_iter = 2\n    bench_iter = 3\n    
config_file = \"ds_zero_stage_2_config.json\"\n\n    for case in benchmark_suites[num_gpus]:\n        batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size,\\\n        dp_size, tensor_mp_size, num_micro_batches, checkpoint_activations, use_deepspeed\\\n            = case\n\n        assert dp_size * tensor_mp_size == num_gpus\n        assert batch_size % dp_size == 0\n        assert batch_size & num_micro_batches == 0\n\n        gpt_options = (\n            f\"--model-parallel-size {tensor_mp_size} \"\n            f\"--num-layers {num_layers} \"\n            f\"--hidden-size {hidden_size} \"\n            f\"--num-attention-heads {num_heads} \"\n            f\"--seq-length {seq_len} \"\n            f\"--max-position-embeddings {seq_len} \"\n            f\"--batch-size {batch_size // dp_size // num_micro_batches} \"\n            f\"--train-iters {(warmup_iter + bench_iter) * num_micro_batches} \"\n            f\"--lr-decay-iters 320000 \"\n            #f\"--save $CHECKPOINT_PATH \"\n            #f\"--load $CHECKPOINT_PATH \"\n            f\"--data-path data/small-webtext \"\n            f\"--vocab-file data/gpt2-vocab.json \"\n            f\"--merge-file data/gpt2-merges.txt \"\n            f\"--data-impl mmap \"\n            f\"--split 949,50,1 \"\n            f\"--distributed-backend nccl \"\n            f\"--lr 1.5e-4 \"\n            f\"--lr-decay-style cosine \"\n            f\"--min-lr 1.0e-5 \"\n            f\"--weight-decay 1e-2 \"\n            f\"--clip-grad 1.0 \"\n            f\"--warmup 0.01 \"\n            f\"--log-interval 1 \"\n            f\"--save-interval 10000 \"\n            f\"--eval-interval 2000 \"\n            f\"--eval-iters 0 \"\n            f\"--fp16 \"\n            f\"--loss-scale 1.0 \"\n            f\"--scattered-embeddings \"\n            f\"--split-transformers \"\n\n            # Disable fusion optimizations because this makes\n            # loading too slow.\n            #f\"--scaled-upper-triang-masked-softmax-fusion 
\"\n            #f\"--scaled-masked-softmax-fusion \"\n            #f\"--bias-gelu-fusion \"\n            #f\"--bias-dropout-fusion \"\n        )\n\n        if use_deepspeed:\n            gpt_options += (\n                \"--deepspeed \"\n                f\"--deepspeed_config {config_file} \"\n            )\n            update_ds_config(config_file, num_micro_batches)\n\n        if checkpoint_activations:\n            gpt_options += \"--checkpoint-activations \"\n            gpt_options += \"--deepspeed-activation-checkpointing \"\n            gpt_options += \"--checkpoint-num-layers 1 \"\n\n            # Disable other checkpoint optimizations\n            # gpt_options += \"--partition-activations \"\n            # gpt_options += \"--checkpoint-in-cpu \"\n            # gpt_options += \"--synchronize-each-layer \"\n            # gpt_options += \"--ontigious-checkpointing \"\n\n        if args.nnodes > 1:\n            host_options = \"--hostfile hostfile \"\n        else:\n            host_options = \"\"\n\n        work_dir= os.environ[\"DEEPSPEED_PATH\"] + \"/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/\"\n        ret = run_cmd(f\"PYTHONPATH={work_dir} PYTHON_VOCAB_SIZE={vocab_size} deepspeed \"\n                      f\"{host_options}\"\n                      f\"--num_nodes {args.nnodes} \"\n                      f\"--master_port {random.randint(10000, 20000)} \"\n                      f\"--num_gpus {args.nproc_per_node} \"\n                      f\"pretrain_gpt2.py {gpt_options}\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"gpt\")\n    parser.add_argument(\"--nnodes\", type=int, default=1)\n    parser.add_argument(\"--nproc_per_node\", type=int, required=True)\n    args = parser.parse_args()\n\n    benchmark_all(args)\n"
  },
  {
    "path": "benchmark/deepspeed/benchmark_moe.py",
    "content": "import time\n\nfrom datetime import datetime\n\nimport argparse\nimport os\nimport random\n\nfrom util import run_cmd\nfrom benchmark.alpa import suite_manual_moe\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size\n# #head = num_heads, S_ = expert_group_size, E = expert_number,\n# D0 = mesh_dimension_0, D1 = mesh_dimension_1,\n# NB = num_micro_batches, FD = force_data_parallel,\n# CK = use_checkpoint,\n# DS = use_deepspeed\n\n\nbenchmark_suites = {\n    \"paper_moe\": suite_manual_moe.grid_search_manual,\n    \"test_moe\": suite_manual_moe.tmp_suite,\n}\n\n\ndef update_ds_config(filename, gradient_accumulation_steps):\n    lines = list(open(filename))\n\n    for i in range(len(lines)):\n        if \"gradient_accumulation_steps\" in lines[i]:\n            idx = lines[i].index(\":\")\n            lines[i] = lines[i][:idx] + f\": {gradient_accumulation_steps},\\n\"\n\n    with open(filename, \"w\") as fout:\n        fout.writelines(lines)\n\n\ndef benchmark_all(args):\n    num_gpus = args.nproc_per_node * args.nnodes\n\n    try:\n        _ = benchmark_suites[args.suite][num_gpus]\n    except KeyError:\n        print(f\"No available benchmark suite for {args.suite} with {num_gpus} GPUs.\")\n        exit()\n    output_name = args.exp_name + \"-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n\n    warmup_iter = 2\n    bench_iter = 3\n\n    # MOE does not support stage 3\n    config_file = \"ds_zero_stage_2_moe_config.json\"\n\n    for case in benchmark_suites[args.suite][num_gpus]:\n        print(\">>>>>> Alpa benchmark: Working on case {}...\".format(str(case)), flush=True)\n\n        (batch_size, model_config, num_micro_batches, parallel_mode,\n         parallel_args) = case\n        (seq_len, hidden_size, num_layers, num_heads, vocab_size, num_expert,\n         expert_group_size) = model_config\n\n        (prefer_reduce_scatter, checkpoint_activations, dp_size, tensor_mp_size, pipeline_mp_size,\n         _) = 
parallel_args\n\n        # TODO (hao, zhuohan): Figure out how to set ep_size\n\n        use_deepspeed = True\n\n        assert dp_size * tensor_mp_size == num_gpus\n        assert batch_size % dp_size == 0\n        assert batch_size % num_micro_batches == 0\n\n        gpt_options = (\n            f\"--model-parallel-size {tensor_mp_size} \"\n            f\"--num-layers {num_layers} \"\n            f\"--hidden-size {hidden_size} \"\n            f\"--num-attention-heads {num_heads} \"\n            f\"--seq-length {seq_len} \"\n            f\"--max-position-embeddings {seq_len} \"\n            f\"--batch-size {batch_size // dp_size // num_micro_batches} \"\n            f\"--train-iters {(warmup_iter + bench_iter) * num_micro_batches} \"\n            f\"--lr-decay-iters 320000 \"\n            #f\"--save $CHECKPOINT_PATH \"\n            #f\"--load $CHECKPOINT_PATH \"\n            f\"--data-path data/small-webtext \"\n            f\"--vocab-file data/gpt2-vocab.json \"\n            f\"--merge-file data/gpt2-merges.txt \"\n            f\"--data-impl mmap \"\n            f\"--split 949,50,1 \"\n            f\"--distributed-backend nccl \"\n            f\"--lr 1.5e-4 \"\n            f\"--lr-decay-style cosine \"\n            f\"--min-lr 1.0e-5 \"\n            f\"--weight-decay 1e-2 \"\n            f\"--clip-grad 1.0 \"\n            f\"--warmup 0.01 \"\n            f\"--log-interval 1 \"\n            f\"--save-interval 10000 \"\n            f\"--eval-interval 2000 \"\n            f\"--eval-iters 0 \"\n            f\"--fp16 \"\n            f\"--loss-scale 1.0 \"\n            f\"--scattered-embeddings \"\n            f\"--split-transformers \"\n\n            # Disable fusion optimizations because this makes\n            # loading too slow.\n            #f\"--scaled-upper-triang-masked-softmax-fusion \"\n            #f\"--scaled-masked-softmax-fusion \"\n            #f\"--bias-gelu-fusion \"\n            #f\"--bias-dropout-fusion \"\n        )\n\n        if use_deepspeed:\n    
        gpt_options += (\n                \"--deepspeed \"\n                f\"--deepspeed_config {config_file} \"\n            )\n            update_ds_config(config_file, num_micro_batches)\n\n        if checkpoint_activations:\n            gpt_options += \"--checkpoint-activations \"\n            gpt_options += \"--deepspeed-activation-checkpointing \"\n            gpt_options += \"--checkpoint-num-layers 1 \"\n\n            # Disable other checkpoint optimizations\n            # gpt_options += \"--partition-activations \"\n            # gpt_options += \"--checkpoint-in-cpu \"\n            # gpt_options += \"--synchronize-each-layer \"\n            # gpt_options += \"--ontigious-checkpointing \"\n\n        if num_expert > 1:\n            gpt_options += \"--moe \"\n            gpt_options += \"--ep-world-size {} \".format(ep_size)\n            gpt_options += \"--num-experts {} \".format(str(num_expert))\n            gpt_options += \"--top-k 2 \"\n            gpt_options += \"--min-capacity 4 \"\n            gpt_options += \"--noisy-gate-policy None \"\n            gpt_options += \"--moe-param-group \"\n            gpt_options += \"--output_name {}\".format(output_name)\n\n        if args.nnodes > 1:\n            host_options = \"--hostfile hostfile_{}node \".format(args.nnodes)\n        else:\n            host_options = \"\"\n\n        work_dir= os.environ[\"DEEPSPEED_PATH\"] + \"/DeepSpeedExamples/Megatron-LM-v1.1.5-ZeRO3/\"\n        ret = run_cmd(f\"PYTHONPATH={work_dir} PYTHON_VOCAB_SIZE={vocab_size} deepspeed \"\n                      f\"{host_options}\"\n                      f\"--num_nodes {args.nnodes} \"\n                      f\"--master_port {random.randint(30000, 40000)} \"\n                      f\"--num_gpus {args.nproc_per_node} \"\n                      f\"pretrain_gpt2_moe.py {gpt_options}\")\n        print(\">>>>>> Alpa benchmark: sleep for 30 seconds before starting the next case.\", flush=True)\n        time.sleep(30)\n\n\nif __name__ == 
\"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"gpt\")\n    parser.add_argument(\"--nnodes\", type=int, default=1)\n    parser.add_argument(\"--nproc_per_node\", type=int, required=True)\n    parser.add_argument(\"--suite\", type=str, default=\"paper_gpt\")\n    parser.add_argument(\"--exp_name\", type=str, default=\"none\")\n    args = parser.parse_args()\n\n    benchmark_all(args)\n"
  },
  {
    "path": "benchmark/deepspeed/ds_zero_stage_2_config.json",
    "content": "{\n  \"train_batch_size\": 8192,\n  \"gradient_accumulation_steps\": 4,\n  \"steps_per_print\": 1,\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"allgather_partitions\": true,\n    \"reduce_scatter\": true,\n    \"allgather_bucket_size\": 5e8,\n    \"reduce_bucket_size\": 5e8,\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": 0.00015,\n      \"max_grad_norm\": 1.0,\n      \"betas\": [0.9, 0.95]\n    }\n  },\n  \"gradient_clipping\": 1.0,\n  \"fp16\": {\n    \"enabled\": true,\n    \"loss_scale\": 1.0,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"wall_clock_breakdown\": false,\n  \"zero_allow_untested_optimizer\": false\n}\n"
  },
  {
    "path": "benchmark/deepspeed/ds_zero_stage_2_moe_config.json",
    "content": "{\n  \"train_batch_size\": 8192,\n  \"gradient_accumulation_steps\": 4,\n  \"steps_per_print\": 1,\n  \"zero_optimization\": {\n    \"stage\": 2,\n    \"allgather_partitions\": true,\n    \"reduce_scatter\": true,\n    \"allgather_bucket_size\": 5e8,\n    \"reduce_bucket_size\": 5e8,\n    \"overlap_comm\": true,\n    \"contiguous_gradients\": true\n  },\n  \"gradient_clipping\": 1.0,\n  \"fp16\": {\n    \"enabled\": true,\n    \"loss_scale\": 1.0,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"wall_clock_breakdown\": false,\n  \"zero_allow_untested_optimizer\": true\n}\n"
  },
  {
    "path": "benchmark/deepspeed/ds_zero_stage_3_config.json",
    "content": "{\n  \"train_batch_size\": 8192,\n  \"gradient_accumulation_steps\": 1,\n  \"steps_per_print\": 1,\n  \"zero_optimization\": {\n    \"stage\": 3,\n    \"stage3_max_live_parameters\": 1e9,\n    \"stage3_max_reuse_distance\": 1e9,\n    \"stage3_prefetch_bucket_size\": 1e7,\n    \"stage3_param_persitence_threshold\": 1e5,\n    \"reduce_bucket_size\": 1e7,\n    \"contiguous_gradients\": true\n  },\n  \"optimizer\": {\n    \"type\": \"Adam\",\n    \"params\": {\n      \"lr\": 0.00015,\n      \"max_grad_norm\": 1.0,\n      \"betas\": [0.9, 0.95]\n    }\n  },\n  \"gradient_clipping\": 1.0,\n  \"fp16\": {\n    \"enabled\": true,\n    \"loss_scale\": 1.0,\n    \"loss_scale_window\": 1000,\n    \"hysteresis\": 2,\n    \"min_loss_scale\": 1\n  },\n  \"wall_clock_breakdown\": false,\n  \"zero_allow_untested_optimizer\": false\n}\n"
  },
  {
    "path": "benchmark/deepspeed/hostfile",
    "content": "172.31.19.47 slots=8\n172.31.27.46 slots=8\n"
  },
  {
    "path": "benchmark/deepspeed/killall_python.sh",
    "content": "kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}')\n"
  },
  {
    "path": "benchmark/deepspeed/patch/gpt2_model.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"GPT-2 model.\"\"\"\n\nimport torch\n\nfrom megatron import get_args\nfrom megatron import mpu\nfrom megatron.module import MegatronModule\n\nfrom .language_model import parallel_lm_logits\nfrom .language_model import get_language_model\nfrom .utils import init_method_normal\nfrom .utils import scaled_init_method_normal\n\nimport deepspeed\n\ndef gpt2_attention_mask_func(attention_scores, ltor_mask):\n    attention_scores.masked_fill_(ltor_mask, -10000.0)\n    return attention_scores\n\n\nclass GPT2Model(MegatronModule):\n    \"\"\"GPT-2 Language model.\"\"\"\n\n    def __init__(self, num_tokentypes=0, parallel_output=True):\n        super(GPT2Model, self).__init__()\n        args = get_args()\n\n        self.parallel_output = parallel_output\n        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy\n\n        self.language_model, self._language_model_key = get_language_model(\n            attention_mask_func=gpt2_attention_mask_func,\n            num_tokentypes=num_tokentypes,\n            add_pooler=False,\n            init_method=init_method_normal(args.init_method_std),\n            scaled_init_method=scaled_init_method_normal(args.init_method_std,\n                                                         args.num_layers))\n\n\n    def forward(self, input_ids, position_ids, attention_mask, 
labels=None,\n                tokentype_ids=None, layer_past=None, get_key_value=False,\n                forward_method_parallel_output=None, curriculum_seqlen=None):\n        if curriculum_seqlen is not None:\n            args = get_args()\n            args.curriculum_seqlen = curriculum_seqlen\n            if curriculum_seqlen < input_ids.size()[1]:\n                # seqlen-based curriculum learning\n                # input_ids, position_ids, labels have size [batch size, seqlen]\n                input_ids = input_ids[:, :curriculum_seqlen].contiguous()\n                position_ids = position_ids[:, :curriculum_seqlen].contiguous()\n                labels = labels[:, :curriculum_seqlen].contiguous()\n\n                # attention_mask has size [1, 1, seqlen, seqlen]\n                attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()\n\n        # Language model.\n        lm_output = self.language_model(input_ids,\n                                        position_ids,\n                                        attention_mask,\n                                        tokentype_ids=tokentype_ids,\n                                        layer_past=layer_past,\n                                        get_key_value=get_key_value)\n        if get_key_value:\n            lm_output, presents = lm_output\n\n        # Output.\n        parallel_output = self.parallel_output\n        if forward_method_parallel_output is not None:\n            parallel_output = forward_method_parallel_output\n\n        output = parallel_lm_logits(\n            lm_output,\n            self.language_model.embedding.word_embeddings.weight,\n            parallel_output)\n\n        if get_key_value:\n            output = [output, presents]\n\n        if labels is None:\n            return output\n        else:\n            if self.fp16_lm_cross_entropy:\n                assert output.dtype == torch.half\n                loss = 
mpu.vocab_parallel_cross_entropy(output, labels)\n            else:\n                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)\n            return loss\n\n\n    def state_dict_for_save_checkpoint(self, destination=None, prefix='',\n                                       keep_vars=False):\n\n        state_dict_ = {}\n        state_dict_[self._language_model_key] \\\n            = self.language_model.state_dict_for_save_checkpoint(\n                destination, prefix, keep_vars)\n        return state_dict_\n\n    def load_state_dict(self, state_dict, strict=True):\n        \"\"\"Customized load.\"\"\"\n\n        if self._language_model_key in state_dict:\n            state_dict = state_dict[self._language_model_key]\n        self.language_model.load_state_dict(state_dict, strict=strict)\n"
  },
  {
    "path": "benchmark/deepspeed/patch/training.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Pretrain utilities.\"\"\"\n\nfrom datetime import datetime\nimport math\nimport sys\nimport torch\nimport json\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\nfrom apex.optimizers import FusedAdam as Adam\n\nfrom megatron import get_args\nfrom megatron import get_timers\nfrom megatron import get_tensorboard_writer\nfrom megatron import mpu\nfrom megatron import print_rank_0\nfrom megatron.checkpointing import load_checkpoint\nfrom megatron.checkpointing import save_checkpoint\nfrom megatron.fp16 import FP16_Module\nfrom megatron.fp16 import FP16_Optimizer\nfrom megatron.initialize import initialize_megatron\nfrom megatron.learning_rates import AnnealingLR\nfrom megatron.model import DistributedDataParallel as LocalDDP\nfrom megatron.model import get_params_for_weight_decay_optimization\nfrom megatron.model.realm_model import ICTBertModel\nfrom megatron.utils import check_adlr_autoresume_termination\nfrom megatron.utils import make_data_loader\nfrom megatron.utils import report_memory, flops_calculator\n\nimport deepspeed\nfrom deepspeed.runtime.utils import see_memory_usage\n\n\ndef pretrain(train_valid_test_dataset_provider, model_provider,\n             forward_step_func, extra_args_provider=None, args_defaults={}):\n    \"\"\"Main training program.\n\n    This function will 
run the followings in the order provided:\n        1) initialize Megatron.\n        2) setup model, optimizer and lr schedule using the model_provider.\n        3) call train_val_test_data_provider to get train/val/test datasets.\n        4) train the modle using the forward_step_func.\n\n    Arguments:\n        train_valid_test_dataset_provider: a function that takes the size of\n            train/valid/test dataset and returns `train, valid, test` datasets.\n        model_provider: a function that returns a vanilla version of the\n            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.\n        forward_step_func: a function that takes a `data iterator` and `model`,\n            and returns a `loss` scalar with a dictionary with key:values being\n            the info we would like to monitor during training, for example\n            `lm-loss: value`. We also require that this function add\n            `batch generator` to the timers class.\n        extra_args_provider: a function that takes a parser and adds arguments\n            to it. It is used for programs to add their own arguments.\n        args_defaults: a dictionary from argument-name to argument-value. 
It\n            to set already parse arguments.\n    \"\"\"\n\n    # Initalize and get arguments, timers, and Tensorboard writer.\n    initialize_megatron(extra_args_provider=extra_args_provider,\n                        args_defaults=args_defaults)\n\n    args = get_args()\n    timers = get_timers()\n\n    args.curriculum_learning = False\n    if args.deepspeed:\n        args.deepspeed_configuration = json.load(\n            open(args.deepspeed_config, 'r', encoding='utf-8'))\n        if \"curriculum_learning\" in args.deepspeed_configuration:\n            if \"enabled\" in args.deepspeed_configuration[\"curriculum_learning\"]:\n                args.curriculum_learning = args.deepspeed_configuration[\"curriculum_learning\"][\"enabled\"]\n\n    # Model, optimizer, and learning rate.\n    timers('model and optimizer').start()\n    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)\n    timers('model and optimizer').stop()\n\n    # Data stuff.\n    timers('train/valid/test data iterators').start()\n    train_data_iterator, valid_data_iterator, test_data_iterator \\\n        = build_train_valid_test_data_iterators(\n        train_valid_test_dataset_provider)\n    timers('train/valid/test data iterators').stop()\n\n    # Print setup timing.\n    print_rank_0('done with setups ...')\n    timers.log(['model and optimizer', 'train/valid/test data iterators'])\n    print_rank_0('training ...')\n\n    iteration = 0\n    if args.do_train and args.train_iters > 0:\n        iteration = train(forward_step_func,\n                          model, optimizer, lr_scheduler,\n                          train_data_iterator, valid_data_iterator)\n\n    if args.do_valid:\n        prefix = 'the end of training for val data'\n        evaluate_and_print_results(prefix, forward_step_func,\n                                   valid_data_iterator, model,\n                                   iteration, False)\n\n    if args.save and iteration != 0:\n        
save_checkpoint(iteration, model, optimizer, lr_scheduler)\n\n    if args.do_test:\n        # Run on test data.\n        prefix = 'the end of training for test data'\n        evaluate_and_print_results(prefix, forward_step_func,\n                                   test_data_iterator, model,\n                                   0, True)\n\n\ndef get_model(model_provider_func):\n    \"\"\"Build the model.\"\"\"\n    args = get_args()\n\n    # Build model on cpu.\n    model = model_provider_func()\n\n    if args.deepspeed:\n        # DeepSpeed handles CUDA, FP16, and DDP components.\n        return model\n\n    # GPU allocation.\n    model.cuda(torch.cuda.current_device())\n\n    # Fp16 conversion.\n    if args.fp16:\n        model = FP16_Module(model)\n\n    # Wrap model for distributed training.\"\"\"\n    if args.DDP_impl == 'torch':\n        i = torch.cuda.current_device()\n        model = torchDDP(model, device_ids=[i], output_device=i,\n                         process_group=mpu.get_data_parallel_group())\n        return model\n    if args.DDP_impl == 'local':\n        model = LocalDDP(model)\n        return model\n\n    raise NotImplementedError('Unknown DDP implementation specified: {}. 
'\n                              'Exiting.'.format(args.DDP_impl))\n\n\ndef get_optimizer(model):\n    \"\"\"Set up the optimizer.\"\"\"\n    args = get_args()\n\n    # Build parameter groups (weight decay and non-decay).\n    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):\n        model = model.module\n\n    if args.moe_param_group:\n        param_groups = create_moe_param_groups(model)\n    else:\n        param_groups = get_params_for_weight_decay_optimization(model)\n\n    # Add model parallel attribute if it is not set.\n    for param_group in param_groups:\n        for param in param_group['params']:\n            if not hasattr(param, 'model_parallel'):\n                param.model_parallel = False\n\n    if args.cpu_optimizer:\n        if args.cpu_torch_adam:\n            cpu_adam_optimizer = torch.optim.AdamW\n        else:\n            from deepspeed.ops.adam import DeepSpeedCPUAdam\n            cpu_adam_optimizer = DeepSpeedCPUAdam\n        optimizer = cpu_adam_optimizer(param_groups,\n                                       lr=args.lr,\n                                       weight_decay=args.weight_decay)\n    else:\n        # Use torch Adam instead of Fused Adam from NVIDIA which seems to have some issue.\n        #optimizer = Adam(param_groups,\n        if args.moe:\n            import torch_optimizer as topt\n            optimizer = topt.Adafactor(param_groups,\n                                       lr=args.lr,\n                                       weight_decay=args.weight_decay,\n                                       beta1=args.adam_beta1,\n                                       eps2=(1e-30, 1e-3))\n            print(\">>>>>> Alpa benchmark: we're using the {} optimizer.\".format(type(optimizer)))\n        else:\n            optimizer = torch.optim.AdamW(param_groups,\n                                          lr=args.lr,\n                                          weight_decay=args.weight_decay,\n                                       
   betas=(args.adam_beta1, args.adam_beta2),\n                                          eps=args.adam_eps)\n\n    if args.deepspeed:\n        # fp16 wrapper is not required for DeepSpeed.\n        return optimizer\n\n    # Wrap into fp16 optimizer.\n    if args.fp16:\n        optimizer = FP16_Optimizer(optimizer,\n                                   static_loss_scale=args.loss_scale,\n                                   dynamic_loss_scale=args.dynamic_loss_scale,\n                                   dynamic_loss_args={\n                                       'scale_window': args.loss_scale_window,\n                                       'min_scale': args.min_scale,\n                                       'delayed_shift': args.hysteresis})\n\n    return optimizer\n\n\ndef get_learning_rate_scheduler(optimizer):\n    \"\"\"Build the learning rate scheduler.\"\"\"\n    args = get_args()\n\n    # Add linear learning rate scheduler.\n    if args.lr_decay_iters is not None:\n        num_iters = args.lr_decay_iters\n    else:\n        num_iters = args.train_iters\n    num_iters = max(1, num_iters)\n    init_step = 0\n    if args.warmup_iters is not None:\n        warmup_iter = args.warmup_iters\n    else:\n        warmup_iter = args.warmup * num_iters\n    lr_scheduler = AnnealingLR(\n        optimizer,\n        start_lr=args.lr,\n        warmup_iter=warmup_iter,\n        total_iters=num_iters,\n        decay_style=args.lr_decay_style,\n        last_iter=init_step,\n        min_lr=args.min_lr,\n        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,\n        override_lr_scheduler=args.override_lr_scheduler)\n\n    return lr_scheduler\n\n\ndef create_moe_param_groups(model):\n    from deepspeed.moe.utils import is_moe_param\n\n    params_with_weight_decay = {'params': [], 'name': 'weight_decay_params'}\n    moe_params_with_weight_decay = {\n        'params': [],\n        'moe': True,\n        'name': 'weight_decay_moe_params'\n    }\n\n    for module_ in 
model.modules():\n        moe_params_with_weight_decay['params'].extend([\n            p for n, p in list(module_._parameters.items())\n            if p is not None and is_moe_param(p)\n        ])\n        params_with_weight_decay['params'].extend([\n            p for n, p in list(module_._parameters.items())\n            if p is not None and not is_moe_param(p)\n        ])\n\n    return params_with_weight_decay, moe_params_with_weight_decay\n\n\ndef setup_model_and_optimizer(model_provider_func):\n    \"\"\"Setup model and optimizer.\"\"\"\n    args = get_args()\n\n    if args.deepspeed and args.moe:\n        print(\">>>>>>> ep_size {}..\".format(args.ep_world_size))\n        deepspeed.utils.groups.initialize(ep_size=args.ep_world_size, mpu=mpu)\n\n    model = get_model(model_provider_func)\n\n    parameters = filter(lambda p: p.requires_grad, model.parameters())\n    if args.moe_param_group:\n        parameters = create_moe_param_groups(model)\n\n    optimizer = get_optimizer(model)\n    lr_scheduler = get_learning_rate_scheduler(optimizer)\n\n    if args.deepspeed:\n        print_rank_0(\"DeepSpeed is enabled.\")\n        model, optimizer, _, lr_scheduler = deepspeed.initialize(\n            model=model,\n            optimizer=optimizer,\n            args=args,\n            lr_scheduler=lr_scheduler,\n            mpu=mpu,\n            dist_init_required=False,\n            model_parameters=parameters)\n    if args.load is not None:\n        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)\n    else:\n        args.iteration = 0\n\n    # get model without FP16 and/or TorchDDP wrappers\n    unwrapped_model = model\n    while hasattr(unwrapped_model, 'module'):\n        unwrapped_model = unwrapped_model.module\n\n    if args.iteration == 0 and hasattr(unwrapped_model, 'init_state_dict_from_bert'):\n        print(\"Initializing ICT from pretrained BERT model\", flush=True)\n        unwrapped_model.init_state_dict_from_bert()\n\n    return model, 
optimizer, lr_scheduler\n\n\ndef backward_step(optimizer, model, loss):\n    \"\"\"Backward step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Backward pass.\n    timers('backward-backward').start()\n    if args.deepspeed:\n        model.backward(loss)\n    else:\n        optimizer.zero_grad(set_grads_to_None=True)\n        if args.fp16:\n            optimizer.backward(loss, update_master_grads=False)\n        else:\n            loss.backward()\n    timers('backward-backward').stop()\n\n    if args.deepspeed:\n        # DeepSpeed backward propagation already addressed all reduce communication.\n        # Reset the timer to avoid breaking timer logs below.\n        timers('backward-allreduce').reset()\n    else:\n        # All-reduce if needed.\n        if args.DDP_impl == 'local':\n            timers('backward-allreduce').start()\n            model.allreduce_params(reduce_after=False,\n                                   fp32_allreduce=args.fp32_allreduce)\n            timers('backward-allreduce').stop()\n\n    if not args.deepspeed:\n        # Update master gradients.\n        timers('backward-master-grad').start()\n        if args.fp16:\n            optimizer.update_master_grads()\n        timers('backward-master-grad').stop()\n\n        # Clipping gradients helps prevent the exploding gradient.\n        timers('backward-clip-grad').start()\n        if args.clip_grad > 0:\n            if not args.fp16:\n                mpu.clip_grad_norm(model.parameters(), args.clip_grad)\n            else:\n                optimizer.clip_master_grads(args.clip_grad)\n        timers('backward-clip-grad').stop()\n\nimport time\nglobal step_latencies\nstep_latencies = []\n\ndef train_step(forward_step_func, data_iterator,\n               model, optimizer, lr_scheduler):\n    \"\"\"Single training step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    #see_memory_usage(f'before forward {model.global_steps}', force=True)\n    # Forward model for one 
step.\n    timers('forward').start()\n    tic = time.time()\n    loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning)\n    timers('forward').stop()\n\n    #see_memory_usage(f'before backward {model.global_steps}', force=True)\n    # Calculate gradients, reduce across processes, and clip.\n    timers('backward').start()\n    backward_step(optimizer, model, loss)\n    timers('backward').stop()\n\n\n    #see_memory_usage(f'before optimizer {model.global_steps}', force=True)\n    # Update parameters.\n    skipped_iter = 0\n    timers('optimizer').start()\n    if args.deepspeed:\n        model.step()\n    else:\n        optimizer.step()\n        # Update learning rate.\n        if not (args.fp16 and optimizer.overflow):\n            lr_scheduler.step()\n        else:\n            skipped_iter = 1\n    timers('optimizer').stop()\n\n    step_latencies.append(time.time() - tic - timers('batch generator').elapsed(reset=False))\n\n    return loss_reduced, skipped_iter\n\n\ndef training_log(loss_dict, total_loss_dict, learning_rate, iteration,\n                 loss_scale, report_memory_flag, skipped_iter, model=None):\n    \"\"\"Log training information such as losses, timing, ....\"\"\"\n    args = get_args()\n    timers = get_timers()\n    writer = get_tensorboard_writer()\n\n    # Update losses.\n    skipped_iters_key = 'skipped iterations'\n    total_loss_dict[skipped_iters_key] = total_loss_dict.get(\n        skipped_iters_key, 0) + skipped_iter\n    got_nan_key = 'got nan'\n\n    got_nan = False\n    for key in loss_dict:\n        if not skipped_iter:\n            total_loss_dict[key] = total_loss_dict.get(key, 0.) 
+ loss_dict[key]\n        else:\n            value = loss_dict[key].float().sum().item()\n            is_nan = value == float('inf') or \\\n                     value == -float('inf') or \\\n                     value != value\n            got_nan = got_nan or is_nan\n\n    total_loss_dict[got_nan_key] = total_loss_dict.get(\n        got_nan_key, 0) + int(got_nan)\n\n    # Logging.\n    timers_to_log = []\n\n    def add_to_logging(name):\n        if name in timers.timers:\n            timers_to_log.append(name)\n    add_to_logging('forward')\n    add_to_logging('backward')\n    add_to_logging('backward-backward')\n    add_to_logging('backward-allreduce')\n    add_to_logging('backward-master-grad')\n    add_to_logging('backward-clip-grad')\n    add_to_logging('optimizer')\n    add_to_logging('batch generator')\n\n    # Tensorboard values.\n    if writer and torch.distributed.get_rank() == 0:\n        writer.add_scalar('tokens', args.tokens, iteration)\n        writer.add_scalar('learning_rate', learning_rate, iteration)\n        writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens)\n        if args.curriculum_learning:\n            writer.add_scalar('seqlen',\n                              args.curriculum_seqlen, iteration)\n            writer.add_scalar('seqlen/vs tokens',\n                              args.curriculum_seqlen, args.tokens)\n        for key in loss_dict:\n            writer.add_scalar(key, loss_dict[key], iteration)\n            writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens)\n        if args.fp16:\n            writer.add_scalar('loss_scale', loss_scale, iteration)\n        normalizer = iteration % args.log_interval\n        if normalizer == 0:\n            normalizer = args.log_interval\n        timers.write(timers_to_log, writer, iteration,\n                     normalizer=normalizer)\n\n    if iteration % args.log_interval == 0:\n        elapsed_time = timers('interval time').elapsed()\n        if writer and 
torch.distributed.get_rank() == 0:\n            writer.add_scalar('iteration_time',\n                              elapsed_time / args.log_interval, iteration)\n        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,\n                                                       args.train_iters)\n        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(\n            elapsed_time * 1000.0 / args.log_interval)\n        log_string += ' learning rate: {:.3E} |'.format(learning_rate)\n        num_iterations = max(\n            1, args.log_interval - total_loss_dict[skipped_iters_key])\n        for key in total_loss_dict:\n            if key not in [skipped_iters_key, got_nan_key]:\n                avg = total_loss_dict[key].item() / float(num_iterations)\n                log_string += ' {}: {:.6E} |'.format(key, avg)\n                total_loss_dict[key] = 0.0\n        if args.fp16:\n            log_string += ' loss scale: {:.1f} |'.format(loss_scale)\n        log_string += ' number of skipped iterations: {:3d} |'.format(\n            total_loss_dict[skipped_iters_key])\n        log_string += ' number of nan iterations: {:3d} |'.format(\n            total_loss_dict[got_nan_key])\n        total_loss_dict[skipped_iters_key] = 0\n        total_loss_dict[got_nan_key] = 0\n        print_rank_0(log_string)\n        if report_memory_flag:\n            report_memory('after {} iterations'.format(iteration))\n            report_memory_flag = False\n        timers.log(timers_to_log, normalizer=args.log_interval)\n        flops_calculator(model, args, elapsed_time)\n\n    return report_memory_flag\n\n\ndef train(forward_step_func, model, optimizer, lr_scheduler,\n          train_data_iterator, valid_data_iterator):\n    \"\"\"Train the model function.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Turn on training mode which enables dropout.\n    model.train()\n\n    # Tracking loss.\n    total_loss_dict = {}\n\n    # Iterations.\n    
iteration = args.iteration\n\n    timers('interval time').start()\n    report_memory_flag = True\n    data_parallel_size = mpu.get_data_parallel_world_size()\n    global_batch_size = args.batch_size * data_parallel_size\n    while iteration < args.train_iters and \\\n            (args.train_tokens is None or args.tokens < args.train_tokens):\n        loss_dict, skipped_iter = train_step(forward_step_func,\n                                             train_data_iterator,\n                                             model,\n                                             optimizer,\n                                             lr_scheduler)\n        iteration += 1\n        if args.curriculum_learning:\n            args.tokens += global_batch_size * args.curriculum_seqlen\n        else:\n            args.tokens += global_batch_size * args.seq_length\n\n        # Logging.\n        loss_scale = None\n        if args.fp16:\n            loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale\n        report_memory_flag = training_log(loss_dict, total_loss_dict,\n                                          optimizer.param_groups[0]['lr'],\n                                          iteration, loss_scale,\n                                          report_memory_flag, skipped_iter,\n                                          model=model)\n\n        # Autoresume\n        if args.adlr_autoresume and \\\n                (iteration % args.adlr_autoresume_interval == 0):\n            check_adlr_autoresume_termination(iteration, model, optimizer,\n                                              lr_scheduler)\n\n        # Checkpointing\n        if args.save and args.save_interval and \\\n                iteration % args.save_interval == 0:\n            save_checkpoint(iteration, model, optimizer, lr_scheduler)\n\n        # Evaluation\n        # XXX temporarily disabled for ZeRO-3\n        \"\"\"\n        if args.eval_interval and iteration % args.eval_interval == 0 
and \\\n           args.do_valid:\n            prefix = 'iteration {}'.format(iteration)\n            evaluate_and_print_results(prefix, forward_step_func,\n                                       valid_data_iterator, model,\n                                       iteration, False)\n        \"\"\"\n\n        if args.exit_interval and iteration % args.exit_interval == 0:\n            torch.distributed.barrier()\n            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n            rank = torch.distributed.get_rank()\n            print_rank_0('rank: {} | time: {} | exiting the program at '\n                         'iteration {}'.format(rank, time_str, iteration))\n            sys.exit()\n\n    return iteration\n\n\ndef evaluate(forward_step_func, data_iterator, model, verbose=False):\n    \"\"\"Evaluation.\"\"\"\n    args = get_args()\n\n    # Turn on evaluation mode which disables dropout.\n    model.eval()\n\n    total_loss_dict = {}\n\n    with torch.no_grad():\n        iteration = 0\n        while iteration < args.eval_iters:\n            iteration += 1\n            if verbose and iteration % args.log_interval == 0:\n                print_rank_0('Evaluating iter {}/{}'.format(iteration,\n                                                            args.eval_iters))\n            # Forward evaluation.\n            _, loss_dict = forward_step_func(data_iterator, model)\n\n            # When contiguous memory optimizations are enabled, the buffers\n            # allocated by the optimizations are deallocated during backward pass\n            # in the absence of backward pass the buffers should be reset after each\n            # forward pass\n            if args.deepspeed and args.deepspeed_activation_checkpointing:\n                deepspeed.checkpointing.reset()\n\n            # Reduce across processes.\n            for key in loss_dict:\n                total_loss_dict[key] = total_loss_dict.get(key, 0.) 
+ \\\n                                       loss_dict[key]\n    # Move model back to the train mode.\n    model.train()\n\n    for key in total_loss_dict:\n        total_loss_dict[key] /= args.eval_iters\n\n    return total_loss_dict\n\n\ndef evaluate_and_print_results(prefix, forward_step_func,\n                               data_iterator, model,\n                               iteration, verbose=False):\n    \"\"\"Helper function to evaluate and dump results on screen.\"\"\"\n    writer = get_tensorboard_writer()\n    args = get_args()\n\n    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)\n    string = ' validation loss at {} | '.format(prefix)\n    for key in total_loss_dict:\n        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())\n        ppl = math.exp(min(20, total_loss_dict[key].item()))\n        string += '{} PPL: {:.6E} | '.format(key, ppl)\n        if writer and torch.distributed.get_rank() == 0:\n            writer.add_scalar('{} value'.format(key),\n                              total_loss_dict[key].item(),\n                              iteration)\n            writer.add_scalar('{} value/vs tokens'.format(key),\n                              total_loss_dict[key].item(),\n                              args.tokens)\n            writer.add_scalar('{} ppl'.format(key), ppl, iteration)\n            writer.add_scalar('{} ppl/vs tokens'.format(key), ppl, args.tokens)\n\n    length = len(string) + 1\n    print_rank_0('-' * length)\n    print_rank_0(string)\n    print_rank_0('-' * length)\n\n\ndef build_train_valid_test_data_iterators(\n        build_train_valid_test_datasets_provider):\n    \"\"\"XXX\"\"\"\n    args = get_args()\n\n    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)\n\n    print_rank_0('> building train, validation, and test datasets ...')\n    # Data loader only on rank 0 of each model parallel group.\n    if mpu.get_model_parallel_rank() == 0:\n        # 
Rank, size, and global batch size.\n        data_parallel_size = mpu.get_data_parallel_world_size()\n        global_batch_size = args.batch_size * data_parallel_size\n\n        # Number of train/valid/test samples.\n        train_iters = args.train_iters\n        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters\n        test_iters = args.eval_iters\n        train_val_test_num_samples = [train_iters * global_batch_size,\n                                      eval_iters * global_batch_size,\n                                      test_iters * global_batch_size]\n        print_rank_0(' > datasets target sizes (minimum size):')\n        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))\n        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))\n        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))\n\n        # Build the datasets.\n        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(\n            train_val_test_num_samples)\n\n        # Build dataloaders.\n        train_dataloader = make_data_loader(train_ds)\n        valid_dataloader = make_data_loader(valid_ds)\n        test_dataloader = make_data_loader(test_ds)\n\n        # Flags to know if we need to do training/validation/testing.\n        do_train = train_dataloader is not None and args.train_iters > 0\n        do_valid = valid_dataloader is not None and args.eval_iters > 0\n        do_test = test_dataloader is not None and args.eval_iters > 0\n        # Pack the train/valid/test flags into a tensor for broadcasting.\n        flags = torch.cuda.LongTensor(\n            [int(do_train), int(do_valid), int(do_test)])\n    else:\n        flags = torch.cuda.LongTensor([0, 0, 0])\n\n    # Broadcast the flags within the model parallel group.\n    torch.distributed.broadcast(flags,\n                                mpu.get_model_parallel_src_rank(),\n                                group=mpu.get_model_parallel_group())\n    args.do_train = 
flags[0].item()\n    args.do_valid = flags[1].item()\n    args.do_test = flags[2].item()\n\n    # Shift the start iterations.\n    if train_dataloader is not None:\n        train_dataloader.batch_sampler.start_iter = args.iteration % \\\n                                                    len(train_dataloader)\n        print_rank_0('setting training data start iteration to {}'.\n                     format(train_dataloader.batch_sampler.start_iter))\n    if valid_dataloader is not None:\n        start_iter_val = (args.iteration // args.eval_interval) * \\\n                         args.eval_iters\n        valid_dataloader.batch_sampler.start_iter = start_iter_val % \\\n                                                    len(valid_dataloader)\n        print_rank_0('setting validation data start iteration to {}'.\n                     format(valid_dataloader.batch_sampler.start_iter))\n\n    # Build iterators.\n    if train_dataloader is not None:\n        train_data_iterator = iter(train_dataloader)\n    else:\n        train_data_iterator = None\n\n    if valid_dataloader is not None:\n        valid_data_iterator = iter(valid_dataloader)\n    else:\n        valid_data_iterator = None\n\n    if test_dataloader is not None:\n        test_data_iterator = iter(test_dataloader)\n    else:\n        test_data_iterator = None\n\n    return train_data_iterator, valid_data_iterator, test_data_iterator\n"
  },
  {
    "path": "benchmark/deepspeed/patch/transformer.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Transformer.\"\"\"\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nfrom megatron import get_args\nfrom megatron import mpu\nfrom megatron.mpu import LayerNorm\nfrom megatron.module import MegatronModule\nfrom megatron.checkpointing import get_checkpoint_version\nfrom megatron.model.fused_softmax import FusedScaleMaskSoftmax\nfrom megatron.model.fused_bias_gelu import bias_gelu_impl\nfrom megatron.model.utils import openai_gelu, erf_gelu\n\nimport deepspeed\n\n# flags required to enable jit fusion kernels\ntorch._C._jit_set_profiling_mode(False)\ntorch._C._jit_set_profiling_executor(False)\ntorch._C._jit_override_can_fuse_on_cpu(True)\ntorch._C._jit_override_can_fuse_on_gpu(True)\n\n\"\"\" We use the following notation throughout this file:\n     h: hidden size\n     n: number of attention heads\n     p: number of model parallel partitions\n     np: n/p\n     hp: h/p\n     hn: h/n\n     b: batch size\n     s: sequence length\n     l: number of layers\n    Transformer takes input of size [s, b, h] and returns a\n    tensor of the same size. 
We use the following arguments:\n        hyperparameters: transformer hyperparameters\n        attention_mask_func: a function that takes `unmasked-attention-scores`\n            with size [b, np, s, s] and an `attention-mask` and will apply\n            the masking. The function should return a masked score of the\n            same size [b, np, s, s].\n               masked-attention-scores = attention_mask_func(\n                                     unmasked-attention-scores, attention-mask)\n\"\"\"\n\nclass ParallelMLP(MegatronModule):\n    \"\"\"MLP.\n\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, perform nonlinear transformation, and project the\n    state back into h hidden dimension. At the end, dropout is also\n    applied.\n    \"\"\"\n\n    def __init__(self, init_method, output_layer_init_method):\n        super(ParallelMLP, self).__init__()\n        args = get_args()\n\n        # Project to 4h.\n        if not args.memory_centric_tiled_linear:\n            self.dense_h_to_4h = mpu.ColumnParallelLinear(\n                args.hidden_size,\n                8 * args.hidden_size,\n                gather_output=False,\n                init_method=init_method,\n                skip_bias_add=True)\n        else:\n            self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias(\n                in_features=args.hidden_size,\n                out_features=8*args.hidden_size,\n                linear_cls=mpu.ColumnParallelLinear,\n                in_splits=args.tile_factor,\n                out_splits=8*args.tile_factor,\n                combine_out_splits=True,\n                gather_output=False,\n                init_method=init_method,\n                skip_bias_add=True)\n\n        self.bias_gelu_fusion = args.bias_gelu_fusion\n        self.activation_func = F.gelu\n        if args.openai_gelu:\n            self.activation_func = openai_gelu\n        elif args.onnx_safe:\n            self.activation_func = 
erf_gelu\n\n        # Project back to h.\n        if not args.memory_centric_tiled_linear:\n            self.dense_4h_to_h = mpu.RowParallelLinear(\n                8 * args.hidden_size,\n                args.hidden_size,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True)\n        else:\n            self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias(\n                in_features=8*args.hidden_size,\n                out_features=args.hidden_size,\n                linear_cls=mpu.RowParallelLinear,\n                in_splits=8*args.tile_factor,\n                out_splits=args.tile_factor,\n                input_is_already_split=False,\n                combine_out_splits=True,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True)\n         \n    def forward(self, hidden_states):\n\n        # [s, b, 4hp]\n        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)\n\n        if self.bias_gelu_fusion:\n            intermediate_parallel = \\\n                    bias_gelu_impl(intermediate_parallel, bias_parallel)\n        else:\n            intermediate_parallel = \\\n                self.activation_func(intermediate_parallel + bias_parallel)\n\n        # [s, b, h]\n        output, output_bias = self.dense_4h_to_h(intermediate_parallel)\n        return output, output_bias\n\n\nclass LinearReturnBias(torch.nn.Linear):\n    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):\n        super(LinearReturnBias, self).__init__(in_features, out_features,\n                                               bias=bias)\n\n    def forward(self, input):\n        return super().forward(input), self.state_dict()[\"bias\"]\n\n\nclass NormalMLP(MegatronModule):\n    \"\"\"MLP.\n\n    MLP will take the input with h hidden state, project it to 4*h\n    hidden dimension, 
perform nonlinear transformation, and project the\n    state back into h hidden dimension. At the end, dropout is also\n    applied.\n    \"\"\"\n\n    def __init__(self, init_method, output_layer_init_method):\n        super(NormalMLP, self).__init__()\n        args = get_args()\n\n        # Project to 4h.\n        if not args.memory_centric_tiled_linear:\n            self.dense_h_to_4h = mpu.ColumnParallelLinear(\n                args.hidden_size,\n                8 * args.hidden_size,\n                gather_output=False,\n                init_method=init_method,\n                skip_bias_add=True)\n            # self.dense_h_to_4h = LinearReturnBias(\n            #     args.hidden_size,\n            #     8 * args.hidden_size,\n            #     bias=True)\n        else:\n            self.dense_h_to_4h = deepspeed.zero.TiledLinearReturnBias(\n                in_features=args.hidden_size,\n                out_features=8*args.hidden_size,\n                linear_cls=mpu.ColumnParallelLinear,\n                in_splits=args.tile_factor,\n                out_splits=8*args.tile_factor,\n                combine_out_splits=True,\n                gather_output=False,\n                init_method=init_method,\n                skip_bias_add=True)\n\n        self.bias_gelu_fusion = args.bias_gelu_fusion\n        self.activation_func = F.gelu\n        if args.openai_gelu:\n            self.activation_func = openai_gelu\n        elif args.onnx_safe:\n            self.activation_func = erf_gelu\n\n        # Project back to h.\n        if not args.memory_centric_tiled_linear:\n            self.dense_4h_to_h = mpu.RowParallelLinear(\n                8 * args.hidden_size,\n                args.hidden_size,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True)\n            # self.dense_4h_to_h = LinearReturnBias(\n            #     8 * args.hidden_size,\n            #     args.hidden_size,\n          
  #     bias=True)\n        else:\n            self.dense_4h_to_h = deepspeed.zero.TiledLinearReturnBias(\n                in_features=8*args.hidden_size,\n                out_features=args.hidden_size,\n                linear_cls=mpu.RowParallelLinear,\n                in_splits=8*args.tile_factor,\n                out_splits=args.tile_factor,\n                input_is_already_split=False,\n                combine_out_splits=True,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True)\n\n    def forward(self, hidden_states):\n\n        # [s, b, 4hp]\n        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)\n\n        if self.bias_gelu_fusion:\n            intermediate_parallel = \\\n                bias_gelu_impl(intermediate_parallel, bias_parallel)\n        else:\n            intermediate_parallel = \\\n                self.activation_func(intermediate_parallel + bias_parallel)\n\n        # [s, b, h]\n        output, output_bias = self.dense_4h_to_h(intermediate_parallel)\n        return output, output_bias\n\n\nclass ParallelSelfAttention(MegatronModule):\n    \"\"\"Parallel self-attention layer abstract class.\n\n    Self-attention layer takes input with size [b, s, h]\n    and returns output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method,\n                 output_layer_init_method, layer_number):\n        super(ParallelSelfAttention, self).__init__()\n        args = get_args()\n        self.fp16 = args.fp16\n\n        self.attention_mask_func = attention_mask_func\n        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling\n        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32\n        if self.apply_query_key_layer_scaling:\n            self.attention_softmax_in_fp32 = True\n        self.layer_number = max(1, layer_number)\n\n        # Per attention head and per partition 
values.\n        world_size = mpu.get_model_parallel_world_size()\n        self.hidden_size_per_partition = mpu.divide(args.hidden_size,\n                                                    world_size)\n        self.hidden_size_per_attention_head = mpu.divide(\n            args.hidden_size, args.num_attention_heads)\n        self.num_attention_heads_per_partition = mpu.divide(\n            args.num_attention_heads, world_size)\n\n        # Strided linear layer.\n        if not args.memory_centric_tiled_linear:\n            self.query_key_value = mpu.ColumnParallelLinear(\n                args.hidden_size,\n                3 * args.hidden_size,\n                gather_output=False,\n                init_method=init_method)\n        else:\n            self.query_key_value = deepspeed.zero.TiledLinearReturnBias(\n                in_features=args.hidden_size,\n                out_features=3*args.hidden_size,\n                linear_cls=mpu.ColumnParallelLinear,\n                gather_output=False,\n                init_method=init_method,\n                in_splits=args.tile_factor,\n                out_splits=3*args.tile_factor,\n                combine_out_splits=True\n            )\n\n        coeff = None\n        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)\n        if self.apply_query_key_layer_scaling:\n            coeff = self.layer_number\n            self.norm_factor *= coeff\n\n        self.scale_mask_softmax = FusedScaleMaskSoftmax(\n            self.fp16,\n            args.scaled_upper_triang_masked_softmax_fusion,\n            args.scaled_masked_softmax_fusion,\n            self.attention_mask_func,\n            self.attention_softmax_in_fp32,\n            coeff)\n\n        # Dropout. 
Note that for a single iteration, this layer will generate\n        # different outputs on different number of parallel partitions but\n        # on average it should not be partition dependent.\n        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)\n\n        # Output.\n        if not args.memory_centric_tiled_linear:\n            self.dense = mpu.RowParallelLinear(\n                args.hidden_size,\n                args.hidden_size,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True)\n        else:\n            self.dense = deepspeed.zero.TiledLinearReturnBias(\n                in_features=args.hidden_size,\n                out_features=args.hidden_size,\n                linear_cls=mpu.RowParallelLinear,\n                input_is_parallel=True,\n                init_method=output_layer_init_method,\n                skip_bias_add=True,\n                out_splits=args.tile_factor,\n                in_splits=args.tile_factor,\n                combine_out_splits=True\n            )\n\n\n        if deepspeed.checkpointing.is_configured():\n            global get_cuda_rng_tracker, checkpoint\n            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker\n            checkpoint = deepspeed.checkpointing.checkpoint\n\n    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):\n        input_shape = mixed_layer.size();\n        if num_splits_first:\n            \"\"\"[s, b, num_splits * np * hn] \n            -->(view) [s, b, num_splits, np, hn] \n            -->(tranpose) [s, b, np, num_splits, hn] \n            -->(view) [s, b, np * num_splits * hn] \"\"\"\n\n            intermediate_shape = input_shape[:-1] +\\\n                (num_splits, self.num_attention_heads_per_partition,\n                 self.hidden_size_per_attention_head)\n\n            mixed_layer = mixed_layer.view(*intermediate_shape)\n            mixed_layer = 
mixed_layer.transpose(-2, -3).contiguous()\n        else:\n            \"\"\"[s, b, np * hn * num_splits] \n            -->(view) [s, b, np, hn, num_splits] \n            -->(tranpose) [s, b, np, num_splits, hn] \n            -->(view) [s, b, np * num_splits * hn] \"\"\"\n\n            intermediate_shape = input_shape[:-1] +\\\n                (self.num_attention_heads_per_partition,\n                 self.hidden_size_per_attention_head, num_splits)\n\n            mixed_layer = mixed_layer.view(*intermediate_shape)\n            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()\n        mixed_layer = mixed_layer.view(*input_shape)\n        \n        return mixed_layer\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [sq, b, h]\n\n        # =====================\n        # Query, Key, and Value\n        # =====================\n        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]\n        mixed_x_layer, _ = self.query_key_value(hidden_states)\n\n        checkpoint_version = get_checkpoint_version()\n        if checkpoint_version is not None:\n           if checkpoint_version == 0:\n               # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]\n               mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)\n           elif checkpoint_version == 1.0:\n               # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]\n               mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)\n\n        # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]\n        new_tensor_shape = mixed_x_layer.size()[:-1] + \\\n            (self.num_attention_heads_per_partition,\n             3 * self.hidden_size_per_attention_head)\n        mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)\n\n        # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]\n        (query_layer,\n         key_layer,\n         value_layer) = 
mpu.split_tensor_along_last_dim(mixed_x_layer, 3)\n\n        # ==================================\n        # Adjust key and value for inference\n        # ==================================\n\n        if layer_past is not None:\n            past_key, past_value = layer_past\n            key_layer = torch.cat((past_key.type_as(key_layer),\n                                   key_layer), dim=0)\n            value_layer = torch.cat((past_value.type_as(value_layer),\n                                     value_layer), dim=0)\n        if get_key_value:\n            present = (key_layer, value_layer)\n\n\n        # ===================================\n        # Raw attention scores. [b, np, s, s]\n        # ===================================\n        \n        # [b, np, sq, sk]\n        output_size = (query_layer.size(1), \n                       query_layer.size(2), \n                       query_layer.size(0), \n                       key_layer.size(0))\n        \n        # [sq, b, np, hn] -> [sq, b * np, hn]\n        query_layer = query_layer.view(output_size[2],\n                                       output_size[0] * output_size[1], -1)\n        key_layer = key_layer.view(output_size[3],\n                                   output_size[0] * output_size[1], -1)\n\n        # preallocting result tensor: [b * np, sq, sk]\n        matmul_result = torch.empty(\n            output_size[0]*output_size[1], \n            output_size[2], \n            output_size[3],\n            dtype=query_layer.dtype, \n            device=torch.cuda.current_device())\n\n        # Raw attention scores. 
[b * np, sq, sk]\n        matmul_result = torch.baddbmm(matmul_result, \n            query_layer.transpose(0, 1),   # [b * np, sq, hn]\n            key_layer.transpose(0,1).transpose(1, 2),  #[b * np, hn, sk]\n            beta=0.0, alpha=(1.0/self.norm_factor))\n\n        # change view to [b, np, sq, sk]\n        attention_scores = matmul_result.view(*output_size)\n\n\n        # ==================================================\n        # Update attention mask for inference. [b, np, sq, sk]\n        # ==================================================\n\n        if get_key_value:\n            with torch.no_grad():\n                if layer_past is not None:\n                    attention_mask = attention_mask[\n                        ...,\n                        attention_scores.size(3) - 1,\n                        :attention_scores.size(3)].unsqueeze(2)\n                else:\n                    attention_mask = attention_mask[\n                        ...,\n                        :attention_scores.size(3),\n                        :attention_scores.size(3)]\n\n\n        # ===========================\n        # Attention probs and dropout\n        # ===========================\n\n        # attention scores and attention mask [b, np, sq, sk]\n        attention_probs = self.scale_mask_softmax(attention_scores,\n                                                  attention_mask)\n\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        with mpu.get_cuda_rng_tracker().fork():\n            attention_probs = self.attention_dropout(attention_probs)\n\n\n        # =========================\n        # Context layer. 
[sq, b, hp]\n        # =========================\n\n        # value_layer -> context layer.\n        # [sk, b, np, hn] --> [b, np, sq, hn]\n\n        # context layer shape: [b, np, sq, hn]\n        output_size = (value_layer.size(1), \n                       value_layer.size(2), \n                       query_layer.size(0), \n                       value_layer.size(3)) \n\n        # change view [sk, b * np, hn] \n        value_layer = value_layer.view(value_layer.size(0),\n                                       output_size[0] * output_size[1], -1)\n        \n        # change view [b * np, sq, sk]\n        attention_probs = attention_probs.view(output_size[0] * output_size[1],\n                                               output_size[2], -1)\n        \n        # matmul: [b * np, sq, hn]\n        context_layer = torch.bmm(attention_probs, value_layer.transpose(0,1))\n\n        # change view [b, np, sq, hn]\n        context_layer = context_layer.view(*output_size)\n\n        # [b, np, sq, hn] --> [sq, b, np, hn]\n        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()\n\n        # [sq, b, np, hn] --> [sq, b, hp]\n        new_context_layer_shape = context_layer.size()[:-2] + \\\n            (self.hidden_size_per_partition,)\n        context_layer = context_layer.view(*new_context_layer_shape)\n\n\n        # =================\n        # Output. 
[sq, b, h]\n        # =================\n\n        output, bias = self.dense(context_layer)\n\n        if get_key_value:\n            output = [output, present]\n\n        return output, bias\n\n\ndef bias_dropout_add(x, bias, residual, prob, training) :\n    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor\n    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)\n    # print(\">>>>>>>>>>>>>>>> getting dropout: {}, {}\".format(x.shape, bias.shape))\n    out = residual + out\n    return out\n\n\ndef get_bias_dropout_add(training):\n    def _bias_dropout_add(x, bias, residual, prob):\n        return bias_dropout_add(x, bias, residual, prob, training)\n    return _bias_dropout_add\n\n\n@torch.jit.script\ndef bias_dropout_add_fused_train(x, bias, residual, prob) :\n    # type: (Tensor, Tensor, Tensor, float) -> Tensor\n    return bias_dropout_add(x, bias, residual, prob, True)\n\n\n@torch.jit.script\ndef bias_dropout_add_fused_inference(x, bias, residual, prob) :\n    # type: (Tensor, Tensor, Tensor, float) -> Tensor\n    return bias_dropout_add(x, bias, residual, prob, False)\n\n\nclass ParallelTransformerLayer(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method, \n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelTransformerLayer, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        # Memory-saving optimization\n        self.scattered_attn_output = args.scattered_embeddings\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # Self attention.\n        
self.attention = ParallelSelfAttention(attention_mask_func, init_method,\n                                               output_layer_init_method,\n                                               layer_number)\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n        # Layernorm on the input data.\n        self.post_attention_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # MLP\n        self.mlp = ParallelMLP(init_method,\n                               output_layer_init_method)\n\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n\n        # Layer norm at the begining of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, attention_bias = \\\n            self.attention(layernorm_output,\n                           attention_mask,\n                           layer_past=layer_past,\n                           get_key_value=get_key_value)\n\n        if get_key_value:\n            attention_output, presents = attention_output\n\n        if self.scattered_attn_output:\n            attention_output = mpu.scatter_to_model_parallel_region(attention_output)\n            attention_bias = mpu.scatter_to_model_parallel_region(attention_bias)\n    \n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        if self.scattered_attn_output:\n            residual = mpu.scatter_to_model_parallel_region(residual)\n\n        # jit scripting for a nn.module (with dropout) is not \n        # trigerring the fusion kernel. 
For now, we use two \n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            layernorm_input = bias_dropout_add_func(\n                attention_output,\n                attention_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        # Collect the scattered result from the fused dropout.\n        if self.scattered_attn_output:\n            layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input)\n            # Attention output/bias are not used again, so no need to gather\n\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        mlp_output, mlp_bias = self.mlp(layernorm_output)\n        \n        # Second residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            output = bias_dropout_add_func(\n                mlp_output,\n                mlp_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        if get_key_value:\n            output = [output, presents]\n\n        return output\n\nclass ParallelTransformerLayerPart1(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns 
an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method, \n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelTransformerLayerPart1, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # Self attention.\n        self.attention = ParallelSelfAttention(attention_mask_func, init_method,\n                                               output_layer_init_method,\n                                               layer_number)\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n        # Layer norm at the begining of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, attention_bias = \\\n            self.attention(layernorm_output,\n                           attention_mask,\n                           layer_past=layer_past,\n                           get_key_value=get_key_value)\n\n        presents = None\n        if get_key_value:\n            raise NotImplementedError('get_key_value param is not yet supported with split-transformers')\n            attention_output, presents = attention_output\n\n    \n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        if self.scattered_attn_output:\n            residual = 
mpu.scatter_to_model_parallel_region(residual)\n\n        # jit scripting for a nn.module (with dropout) is not \n        # trigerring the fusion kernel. For now, we use two \n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            layernorm_input = bias_dropout_add_func(\n                attention_output,\n                attention_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        return layernorm_input\n\nclass ParallelTransformerLayerPart2(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method, \n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelTransformerLayerPart2, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n        # Layernorm on the input data.\n        self.post_attention_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # MLP\n        self.mlp = ParallelMLP(init_method,\n                               output_layer_init_method)\n\n\n    def forward(self, layernorm_input, 
attention_mask, presents=None, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n        \n        # Collect the scattered result from the fused dropout.\n        if self.scattered_attn_output:\n            layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input)\n            # Attention output/bias are not used again, so no need to gather\n\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        mlp_output, mlp_bias = self.mlp(layernorm_output)\n        \n        # Second residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        # jit scripting for a nn.module (with dropout) is not \n        # trigerring the fusion kernel. For now, we use two \n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            output = bias_dropout_add_func(\n                mlp_output,\n                mlp_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        if get_key_value:\n            output = [output, presents]\n\n        return output\n\nclass ParallelTransformerLayerPart1(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def 
__init__(self, attention_mask_func, init_method, \n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelTransformerLayerPart1, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # Self attention.\n        self.attention = ParallelSelfAttention(attention_mask_func, init_method,\n                                               output_layer_init_method,\n                                               layer_number)\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n\n        # Layer norm at the begining of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, attention_bias = \\\n            self.attention(layernorm_output,\n                           attention_mask,\n                           layer_past=layer_past,\n                           get_key_value=get_key_value)\n\n        presents = None\n        if get_key_value:\n            raise NotImplementedError('get_key_value param is not yet supported with split-transformers')\n            attention_output, presents = attention_output\n    \n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        # jit scripting for a nn.module (with dropout) is not \n        # trigerring the fusion kernel. 
For now, we use two \n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            layernorm_input = bias_dropout_add_func(\n                attention_output,\n                attention_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        return layernorm_input\n\nclass ParallelTransformerLayerPart2(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method, \n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelTransformerLayerPart2, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n        # Layernorm on the input data.\n        self.post_attention_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # MLP\n        self.mlp = ParallelMLP(init_method,\n                               output_layer_init_method)\n\n\n    def forward(self, layernorm_input, attention_mask, presents=None, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n\n        # Layer norm post the 
self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        mlp_output, mlp_bias = self.mlp(layernorm_output)\n        \n        # Second residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        # jit scripting for a nn.module (with dropout) is not \n        # trigerring the fusion kernel. For now, we use two \n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            output = bias_dropout_add_func(\n                mlp_output,\n                mlp_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        if get_key_value:\n            output = [output, presents]\n\n        return output\n\n\nclass ParallelMOETransformerLayer(MegatronModule):\n    \"\"\"A single transformer layer.\n\n    Transformore layer takes input with size [b, s, h] and returns an\n    output of the same size.\n    \"\"\"\n\n    def __init__(self, attention_mask_func, init_method,\n                 output_layer_init_method, layer_number):\n        args = get_args()\n\n        super(ParallelMOETransformerLayer, self).__init__()\n        self.layer_number = layer_number\n\n        self.apply_residual_connection_post_layernorm \\\n            = args.apply_residual_connection_post_layernorm\n\n        # Memory-saving optimization\n        self.scattered_attn_output = 
args.scattered_embeddings\n\n        # Layernorm on the input data.\n        self.input_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # Self attention.\n        self.attention = ParallelSelfAttention(attention_mask_func, init_method,\n                                               output_layer_init_method,\n                                               layer_number)\n        self.hidden_dropout = args.hidden_dropout\n        self.bias_dropout_fusion = args.bias_dropout_fusion\n\n        # Layernorm on the input data.\n        self.post_attention_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        # MoE\n        self.moe = deepspeed.moe.layer.MoE(\n            hidden_size = args.hidden_size,\n            expert=NormalMLP(init_method, output_layer_init_method),\n            num_experts=args.num_experts,\n            k=args.top_k,\n            min_capacity=args.min_capacity,\n            noisy_gate_policy=args.noisy_gate_policy\n        )\n\n        # self.mlp = ParallelMLP(init_method,\n                               # output_layer_init_method)\n\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n        # hidden_states: [b, s, h]\n\n        # Layer norm at the begining of the transformer layer.\n        layernorm_output = self.input_layernorm(hidden_states)\n        # Self attention.\n        attention_output, attention_bias = \\\n            self.attention(layernorm_output,\n                           attention_mask,\n                           layer_past=layer_past,\n                           get_key_value=get_key_value)\n\n        if get_key_value:\n            attention_output, presents = attention_output\n\n        if self.scattered_attn_output:\n            attention_output = mpu.scatter_to_model_parallel_region(attention_output)\n            attention_bias = 
mpu.scatter_to_model_parallel_region(attention_bias)\n\n        # Residual connection.\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        if self.scattered_attn_output:\n            residual = mpu.scatter_to_model_parallel_region(residual)\n\n        # jit scripting for a nn.module (with dropout) is not\n        # trigerring the fusion kernel. For now, we use two\n        # different nn.functional routines to account for varying\n        # dropout semantics during training and inference phases.\n        if self.bias_dropout_fusion:\n            if self.training:\n                bias_dropout_add_func = bias_dropout_add_fused_train\n            else:\n                bias_dropout_add_func = bias_dropout_add_fused_inference\n        else:\n            bias_dropout_add_func = get_bias_dropout_add(self.training)\n\n        #re-enable torch grad to enable fused optimization.\n        with torch.enable_grad():\n            layernorm_input = bias_dropout_add_func(\n                attention_output,\n                attention_bias.expand_as(residual),\n                residual,\n                self.hidden_dropout)\n\n        # Collect the scattered result from the fused dropout.\n        if self.scattered_attn_output:\n            layernorm_input = mpu.gather_from_model_parallel_region(layernorm_input)\n            # Attention output/bias are not used again, so no need to gather\n\n        # Layer norm post the self attention.\n        layernorm_output = self.post_attention_layernorm(layernorm_input)\n\n        # MLP.\n        # moe_output, moe_bias = self.mlp(layernorm_output)\n        # MoE\n        moe_output, _, _ = self.moe(layernorm_output)\n        moe_bias = torch.zeros_like(moe_output, dtype=moe_output.dtype,\n                                    device=moe_output.device)\n\n        # Second residual connection.\n        if 
self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = layernorm_input\n\n        #re-enable torch grad to enable fused optimization.\n        # Note(Hao): moe does not have bias cuz they do not support it.\n        with torch.enable_grad():\n            output = bias_dropout_add_func(\n                moe_output,\n                moe_bias,\n                residual,\n                self.hidden_dropout)\n\n        if get_key_value:\n            output = [output, presents]\n\n        return output\n\n\nclass ParallelTransformer(MegatronModule):\n    \"\"\"Transformer class.\"\"\"\n\n    def __init__(self, attention_mask_func,\n                 init_method, output_layer_init_method):\n        super(ParallelTransformer, self).__init__()\n        args = get_args()\n\n        # Store activation checkpoiting flag.\n        self.checkpoint_activations = args.checkpoint_activations\n        self.checkpoint_num_layers = args.checkpoint_num_layers\n\n        # Number of layers:\n        self.num_layers = args.num_layers\n        self.num_unique_layers = args.num_unique_layers\n        if self.num_unique_layers is None:\n            self.num_unique_layers = self.num_layers\n        assert self.num_layers % self.num_unique_layers == 0, \\\n            'number of layers should be divisible by number of unique layers'\n        self.param_sharing_style = args.param_sharing_style\n\n        # Transformer layers.\n        def build_layer(layer_number):\n            return ParallelTransformerLayer(\n                attention_mask_func, init_method,\n                output_layer_init_method, layer_number)\n\n        def build_layer_part1(layer_number):\n            return ParallelTransformerLayerPart1(\n                attention_mask_func, init_method,\n                output_layer_init_method, layer_number)\n        def build_layer_part2(layer_number):\n            return ParallelTransformerLayerPart2(\n        
        attention_mask_func, init_method,\n                output_layer_init_method, layer_number)\n\n        def build_moe_layer(layer_number):\n            return ParallelMOETransformerLayer(\n                attention_mask_func, init_method,\n                output_layer_init_method, layer_number)\n\n        if args.moe:\n            layers = []\n            assert self.num_unique_layers % 2 == 0\n            for i in range(self.num_layers):\n                if i % 2 == 0:\n                    layers.append(build_layer(i + 1))\n                else:\n                    layers.append(build_moe_layer(i + 1))\n            self.layers = torch.nn.ModuleList(layers)\n        elif args.split_transformers:\n            layers = []\n            for i in range(self.num_unique_layers):\n                layers.append(build_layer_part1(i + 1))\n                layers.append(build_layer_part2(i + 1))\n            self.layers = torch.nn.ModuleList(layers)\n            self.num_layers *= 2\n            self.num_unique_layers *= 2\n        else:\n            self.layers = torch.nn.ModuleList(\n                [build_layer(i + 1) for i in range(self.num_unique_layers)])\n\n        # Print layer ordering.\n        if self.num_layers != self.num_unique_layers:\n            if torch.distributed.get_rank() == 0:\n                print('> will be using the following layer ordering:')\n                for i in range(self.num_layers):\n                    print('   layer id: {:3d} --> unique layer id: '\n                          '{:3d}'.format(i, self._get_layer_index(i)),\n                          flush=True)\n\n        # Final layer norm before output.\n        self.final_layernorm = LayerNorm(\n            args.hidden_size,\n            eps=args.layernorm_epsilon)\n\n        if deepspeed.checkpointing.is_configured():\n            global get_cuda_rng_tracker, checkpoint\n            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker\n            checkpoint = 
deepspeed.checkpointing.checkpoint\n\n    def _get_layer_index(self, layer_number):\n        if self.param_sharing_style == 'grouped':\n            return layer_number % self.num_unique_layers\n        if self.param_sharing_style == 'spaced':\n            return layer_number // (self.num_layers // self.num_unique_layers) \n        assert False, 'should not be here'\n\n    def _get_layer(self, layer_number):\n        return self.layers[self._get_layer_index(layer_number)]\n\n    def _checkpointed_forward(self, hidden_states, attention_mask):\n        \"\"\"Forward method with activation checkpointing.\"\"\"\n        def custom(start, end):\n            def custom_forward(*inputs):\n                x_ = inputs[0]\n                for index in range(start, end):\n                    layer = self._get_layer(index)\n                    x_ = layer(x_, inputs[1])\n                return x_\n            return custom_forward\n\n        # Make sure memory is freed.\n        mpu.reset_checkpointed_activations_memory_buffer()\n        l = 0\n        while l < self.num_layers:\n            hidden_states = mpu.checkpoint(\n                custom(l, l + self.checkpoint_num_layers),\n                hidden_states, attention_mask)\n            l += self.checkpoint_num_layers\n\n        return hidden_states\n\n    def forward(self, hidden_states, attention_mask, layer_past=None,\n                get_key_value=False):\n\n        # Checks\n        if layer_past is not None:\n            assert get_key_value, \\\n                'for not None values in layer_past, ' \\\n                'expected get_key_value to be set'\n        if get_key_value:\n            assert not self.checkpoint_activations, \\\n                'get_key_value does not work with ' \\\n                'activation checkpointing'\n\n        # data format change to avoid explicit tranposes : [b s h] --> [s b h]\n        hidden_states = hidden_states.transpose(0, 1).contiguous()\n\n        if 
self.checkpoint_activations:\n            hidden_states = self._checkpointed_forward(hidden_states,\n                                                       attention_mask)\n        else:\n            if get_key_value:\n                presents = []\n            for index in range(self.num_layers):\n                layer = self._get_layer(index)\n                past = None\n                if layer_past is not None:\n                    past = layer_past[index]\n                hidden_states = layer(hidden_states,\n                                      attention_mask,\n                                      layer_past=past,\n                                      get_key_value=get_key_value)\n                if get_key_value:\n                    hidden_states, present = hidden_states\n                    presents.append(present)\n        # reverting data format change [s b h] --> [b s h]\n        hidden_states = hidden_states.transpose(0, 1).contiguous()\n\n        # Final layer norm.\n        output = self.final_layernorm(hidden_states)\n        if get_key_value:\n            output = [output, presents]\n\n        return output\n"
  },
  {
    "path": "benchmark/deepspeed/pretrain_gpt2.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Pretrain GPT2\"\"\"\n\nimport os\nimport json\n\nimport torch\nimport numpy as np\n\nfrom megatron import get_args\nfrom megatron import print_rank_0\nfrom megatron import get_timers\nfrom megatron import get_tokenizer\nfrom megatron import mpu\nfrom megatron.data.gpt2_dataset import build_train_valid_test_datasets\nfrom megatron.model import GPT2Model\nfrom megatron.training import pretrain\nfrom megatron.utils import get_ltor_masks_and_position_ids\nfrom megatron.utils import reduce_losses, get_parameters_in_billions\nfrom benchmark.deepspeed.pretrain_gpt2_moe import moe_parser\n\n\nimport deepspeed\nfrom deepspeed.runtime.utils import see_memory_usage\n\ndef model_provider():\n    \"\"\"Build the model.\"\"\"\n\n    print_rank_0('building GPT2 model ...')\n    see_memory_usage(f\"Before Building Model\", force=True)\n    args = get_args()\n\n    args.padded_vocab_size = int(os.environ.get(\"PYTHON_VOCAB_SIZE\", 25600))\n\n    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),\n                             remote_device=None if args.remote_device=='none' else args.remote_device,\n                             config=args.deepspeed_config,\n                             enabled=args.zero_stage==3):\n        model = GPT2Model(num_tokentypes=0, parallel_output=True)\n    
see_memory_usage(f\"After Building Model\", force=True)\n\n    if mpu.get_data_parallel_rank() == 0:\n        billion_params = get_parameters_in_billions(model)\n        print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()}\\\n            {round(billion_params, 3)} Billion',\n            flush=True)\n\n    return model\n\n\ndef get_batch(data_iterator):\n    \"\"\"Generate a batch\"\"\"\n    args = get_args()\n    tokenizer = get_tokenizer()\n\n    # Items and their type.\n    keys = ['text']\n    datatype = torch.int64\n\n    # Broadcast data.\n    if data_iterator is not None:\n        data = next(data_iterator)\n    else:\n        data = None\n    data_b = mpu.broadcast_data(keys, data, datatype)\n\n    # Unpack.\n    tokens_ = data_b['text'].long()\n\n    # Hack for our vocab_size modification\n    tokens_ = (tokens_.float() / args.padded_vocab_size).long()\n    tokenizer_eod = args.padded_vocab_size - 1\n\n    labels = tokens_[:, 1:].contiguous()\n    tokens = tokens_[:, :-1].contiguous()\n\n    # Get the masks and postition ids.\n    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(\n        tokens,\n        tokenizer_eod,\n        args.reset_position_ids,\n        args.reset_attention_mask,\n        args.eod_mask_loss)\n\n    return tokens, labels, loss_mask, attention_mask, position_ids\n\n\ndef forward_step(data_iterator, model, curriculum_learning=False):\n    \"\"\"Forward step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Get the batch.\n    timers('batch generator').start()\n    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(\n        data_iterator)\n    timers('batch generator').stop()\n    # Forward model.\n    losses = model(tokens, position_ids, attention_mask, labels=labels)\n    if curriculum_learning and args.curriculum_seqlen < args.seq_length:\n        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()\n    loss_mask = 
loss_mask.view(-1)\n    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()\n\n    # Reduce loss for logging.\n    reduced_loss = reduce_losses([loss])\n\n    return loss, {'lm loss': reduced_loss[0]}\n\n\ndef train_valid_test_datasets_provider(train_val_test_num_samples):\n    \"\"\"Build train, valid, and test datasets.\"\"\"\n    args = get_args()\n\n    print_rank_0('> building train, validation, and test datasets '\n                 'for GPT2 ...')\n    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(\n        data_prefix=args.data_path,\n        data_impl=args.data_impl,\n        splits_string=args.split,\n        train_valid_test_num_samples=train_val_test_num_samples,\n        seq_length=args.seq_length,\n        seed=args.seed,\n        skip_warmup=(not args.mmap_warmup))\n    print_rank_0(\"> finished creating GPT2 datasets ...\")\n\n    return train_ds, valid_ds, test_ds\n\n\nif __name__ == \"__main__\":\n    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,\n             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},\n             extra_args_provider=moe_parser)\n\n    if torch.distributed.get_rank() == 0:\n        import numpy as np\n        from util import compute_gpt_parameter_count, compute_gpt_tflops, write_tsv\n        from megatron.training import step_latencies\n        GB = 1 << 30\n\n\n\n        args = get_args()\n        seq_len = args.seq_length\n        num_layers = args.num_layers\n        hidden_size = args.hidden_size\n        num_heads = args.num_attention_heads\n        vocab_size = args.padded_vocab_size\n        if args.deepspeed:\n            num_micro_batches = json.load(open(\n                args.deepspeed_config))[\"gradient_accumulation_steps\"]\n        else:\n            num_micro_batches = 1\n        batch_size = args.batch_size * mpu.get_data_parallel_world_size() * num_micro_batches\n        warmup_iter = 2\n\n        alloc_mem = 
torch.cuda.max_memory_allocated(0)\n        latencies = np.array(step_latencies[warmup_iter * num_micro_batches:])\\\n                    .reshape((-1, num_micro_batches)).sum(axis=-1)\n        param_count = compute_gpt_parameter_count(\n            num_layers, hidden_size, vocab_size)\n        tflops = compute_gpt_tflops(batch_size, seq_len, num_layers,\n                                    hidden_size, vocab_size,\n                                    torch.distributed.get_world_size(),\n                                    np.mean(latencies))\n        model_config = (batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size)\n        parallel_config = (mpu.get_data_parallel_world_size(),\n                           mpu.get_model_parallel_world_size(),\n                           args.checkpoint_activations,\n                           num_micro_batches,\n                           args.deepspeed)\n\n        # Log results\n        heads = [\"Model\", \"Model Config\", \"Parallel Config\", \"Param Count\",\n                 \"Alloc Mem\", \"ILP Objective\", \"Mean Latency\", \"Std Latency\", \"TFLOPS\"]\n        values = [\"gpt\", model_config, parallel_config,\n                  f\"{param_count/1e9:.3f}\", f\"{alloc_mem/GB:.3f}\", \"-1\",\n                  f\"{np.mean(latencies):.3f}\", f\"{np.std(latencies):.3f}\", f\"{tflops:.2f}\"]\n        write_tsv(heads, values, f\"result_gpt.tsv\")\n"
  },
  {
    "path": "benchmark/deepspeed/pretrain_gpt2_moe.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Pretrain GPT2\"\"\"\n\nimport json\nimport os\nimport torch\n\nimport deepspeed\nfrom deepspeed.runtime.utils import see_memory_usage\nfrom megatron import get_args\nfrom megatron import get_timers\nfrom megatron import get_tokenizer\nfrom megatron import mpu\nfrom megatron import print_rank_0\nfrom megatron.model import GPT2Model\nfrom megatron.training import pretrain\nfrom megatron.utils import get_ltor_masks_and_position_ids\nfrom megatron.utils import reduce_losses, get_parameters_in_billions\nfrom megatron.data.gpt2_dataset import build_train_valid_test_datasets\n\n\ndef moe_parser(parser):\n    #data\n    # cuda\n    # parser.add_argument('--with_cuda',\n    #                     default=False,\n    #                     action='store_true',\n    #                     help='use CPU in case there\\'s no GPU support')\n    # parser.add_argument('--use_ema',\n    #                     default=False,\n    #                     action='store_true',\n    #                     help='whether use exponential moving average')\n\n    # train\n    # parser.add_argument('-b',\n    #                     '--batch_size',\n    #                     default=32,\n    #                     type=int,\n    #                     help='mini-batch size (default: 32)')\n    # parser.add_argument('-e',\n    #                     
'--epochs',\n    #                     default=30,\n    #                     type=int,\n    #                     help='number of total epochs (default: 30)')\n    # parser.add_argument('--local_rank',\n    #                     type=int,\n    #                     default=-1,\n    #                     help='local rank passed from distributed launcher')\n    #\n    # parser.add_argument('--log-interval',\n    #                     type=int,\n    #                     default=2000,\n    #                     help=\"output logging information at a given interval\")\n    group = parser.add_argument_group(title='MOE')\n    group.add_argument(\"--vocab-size\",\n                       default=51200,\n                       type=int,\n                       help=\"vocabulary size\")\n    group.add_argument('--moe',\n                        default=False,\n                        action='store_true',\n                        help='use deepspeed mixture of experts (moe)')\n    group.add_argument('--ep-world-size',\n                        default=1,\n                        type=int,\n                        help='(moe) expert parallel world size')\n    group.add_argument('--num-experts',\n                        default=1,\n                        type=int,\n                        help='(moe) number of total experts')\n    group.add_argument('--top-k',\n                        default=1,\n                        type=int,\n                        help='(moe) gating top 1 and 2 supported')\n    group.add_argument(\n        '--min-capacity',\n        default=0,\n        type=int,\n        help=\n        '(moe) minimum capacity of an expert regardless of the capacity_factor'\n    )\n    group.add_argument(\n        '--noisy-gate-policy',\n        default=None,\n        type=str,\n        help=\n        '(moe) noisy gating (only supported with top-1). 
Valid values are None, RSample, and Jitter'\n    )\n    group.add_argument(\n        '--moe-param-group',\n        default=False,\n        action='store_true',\n        help=\n        '(moe) create separate moe param groups, required when using ZeRO w. MoE'\n    )\n    group.add_argument(\n        '--output_name',\n        default=\"none\",\n        help=\"where to save results.\"\n    )\n    return parser\n\n\ndef model_provider():\n    \"\"\"Build the model.\"\"\"\n\n    print_rank_0('building GPT2 model ...')\n    see_memory_usage(f\"Before Building Model\", force=True)\n    args = get_args()\n\n    args.padded_vocab_size = int(os.environ.get(\"PYTHON_VOCAB_SIZE\", 25600))\n\n    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),\n                             remote_device=None if args.remote_device=='none' else args.remote_device,\n                             config=args.deepspeed_config,\n                             enabled=args.zero_stage==3):\n        model = GPT2Model(num_tokentypes=0, parallel_output=True)\n    see_memory_usage(f\"After Building Model\", force=True)\n\n    if mpu.get_data_parallel_rank() == 0:\n        billion_params = get_parameters_in_billions(model)\n        print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()}\\\n            {round(billion_params, 3)} Billion',\n            flush=True)\n\n    return model\n\n\ndef get_batch(data_iterator):\n    \"\"\"Generate a batch\"\"\"\n    args = get_args()\n    tokenizer = get_tokenizer()\n\n    # Items and their type.\n    keys = ['text']\n    datatype = torch.int64\n\n    # Broadcast data.\n    if data_iterator is not None:\n        data = next(data_iterator)\n    else:\n        data = None\n    data_b = mpu.broadcast_data(keys, data, datatype)\n\n    # Unpack.\n    tokens_ = data_b['text'].long()\n\n    # Hack for our vocab_size modification\n    tokens_ = (tokens_.float() / args.padded_vocab_size).long()\n    tokenizer_eod = 
args.padded_vocab_size - 1\n\n    labels = tokens_[:, 1:].contiguous()\n    tokens = tokens_[:, :-1].contiguous()\n\n    # Get the masks and postition ids.\n    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(\n        tokens,\n        tokenizer_eod,\n        args.reset_position_ids,\n        args.reset_attention_mask,\n        args.eod_mask_loss)\n\n    return tokens, labels, loss_mask, attention_mask, position_ids\n\n\ndef forward_step(data_iterator, model, curriculum_learning=False):\n    \"\"\"Forward step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Get the batch.\n    timers('batch generator').start()\n    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(\n        data_iterator)\n    timers('batch generator').stop()\n    # Forward model.\n    losses = model(tokens, position_ids, attention_mask, labels=labels)\n    if curriculum_learning and args.curriculum_seqlen < args.seq_length:\n        loss_mask = loss_mask[:, :args.curriculum_seqlen].contiguous()\n    loss_mask = loss_mask.view(-1)\n    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()\n\n    # Reduce loss for logging.\n    reduced_loss = reduce_losses([loss])\n\n    return loss, {'lm loss': reduced_loss[0]}\n\n\ndef train_valid_test_datasets_provider(train_val_test_num_samples):\n    \"\"\"Build train, valid, and test datasets.\"\"\"\n    args = get_args()\n\n    print_rank_0('> building train, validation, and test datasets '\n                 'for GPT2 ...')\n    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(\n        data_prefix=args.data_path,\n        data_impl=args.data_impl,\n        splits_string=args.split,\n        train_valid_test_num_samples=train_val_test_num_samples,\n        seq_length=args.seq_length,\n        seed=args.seed,\n        skip_warmup=(not args.mmap_warmup))\n    print_rank_0(\"> finished creating GPT2 datasets ...\")\n\n    return train_ds, valid_ds, test_ds\n\n\nif __name__ == 
\"__main__\":\n\n    pretrain(train_valid_test_datasets_provider, model_provider, forward_step,\n             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},\n             extra_args_provider=moe_parser)\n    args = get_args()\n    rank = torch.distributed.get_rank()\n    if rank == 0:\n        import numpy as np\n        from util import compute_moe_parameter_count, compute_moe_tflops, write_tsv\n        from megatron.training import step_latencies\n        GB = 1 << 30\n\n        args = get_args()\n        seq_len = args.seq_length\n        num_layers = args.num_layers\n        hidden_size = args.hidden_size\n        num_heads = args.num_attention_heads\n        num_experts = args.num_experts\n        vocab_size = args.padded_vocab_size\n        mlp_factor = 8\n        if args.deepspeed:\n            num_micro_batches = json.load(open(\n                args.deepspeed_config))[\"gradient_accumulation_steps\"]\n        else:\n            num_micro_batches = 1\n        batch_size = args.batch_size * mpu.get_data_parallel_world_size() * num_micro_batches\n        warmup_iter = 2\n\n        alloc_mem = torch.cuda.max_memory_allocated(0)\n        latencies = np.array(step_latencies[warmup_iter * num_micro_batches:])\\\n                    .reshape((-1, num_micro_batches)).sum(axis=-1)\n\n        param_count = compute_moe_parameter_count(\n            num_layers, hidden_size, vocab_size, num_experts, mlp_factor=mlp_factor)\n\n        expert_group_size = batch_size * seq_len // num_micro_batches \\\n                            // mpu.get_data_parallel_world_size()\n\n        tflops = compute_moe_tflops(batch_size, seq_len, num_layers,\n                                    hidden_size, expert_group_size,\n                                    vocab_size, num_experts,\n                                    torch.distributed.get_world_size(),\n                                    np.mean(latencies), mlp_factor=mlp_factor)\n        tflops_ckpt = 
compute_moe_tflops(batch_size, seq_len, num_layers,\n                                         hidden_size, expert_group_size ,\n                                         vocab_size, num_experts, torch.distributed.get_world_size(),\n                                         np.mean(latencies), mlp_factor=mlp_factor,\n                                         checkpoint_activations=True)\n        model_config = (batch_size, seq_len, hidden_size, num_layers, num_heads, num_experts)\n        parallel_config = (mpu.get_data_parallel_world_size(),\n                           mpu.get_model_parallel_world_size(),\n                           1,\n                           args.ep_world_size)\n\n        # Log results\n        heads = [\"Type\", \"Model Config\", \"Parallel Config\", \"P-mesh shape\", \"#Microbatch\",\n                 \"Force DP\", \"Remat\", \"Mean Time\", \"Std Time\", \"#Params\", \"TFLOPs\", \"TFLOPs (ckpt)\",\n                 \"Peak Mem\"]\n        values = [\"MOE\", str(model_config), str(parallel_config),\n                  \"N/A\", str(num_micro_batches), \"N/A\",\n                  str(args.checkpoint_activations), f\"{np.mean(latencies):.3f}s\", f\"{np.std(latencies):.3f}\",\n                  f\"{param_count/1e9:.3f}B\", f\"{tflops:.2f}\", f\"{tflops_ckpt:.2f}\",\n                  f\"{alloc_mem/GB:5.3f}G\"]\n        write_tsv(heads, values,f\"moe_deepspeed_{args.output_name}_rank{rank}.tsv\")\n"
  },
  {
    "path": "benchmark/deepspeed/training.py",
    "content": "# coding=utf-8\n# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Pretrain utilities.\"\"\"\n\nfrom datetime import datetime\nimport math\nimport sys\nimport torch\nimport json\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\nfrom apex.optimizers import FusedAdam as Adam\n\nfrom megatron import get_args\nfrom megatron import get_timers\nfrom megatron import get_tensorboard_writer\nfrom megatron import mpu\nfrom megatron import print_rank_0\nfrom megatron.checkpointing import load_checkpoint\nfrom megatron.checkpointing import save_checkpoint\nfrom megatron.fp16 import FP16_Module\nfrom megatron.fp16 import FP16_Optimizer\nfrom megatron.initialize import initialize_megatron\nfrom megatron.learning_rates import AnnealingLR\nfrom megatron.model import DistributedDataParallel as LocalDDP\nfrom megatron.model import get_params_for_weight_decay_optimization\nfrom megatron.model.realm_model import ICTBertModel\nfrom megatron.utils import check_adlr_autoresume_termination\nfrom megatron.utils import make_data_loader\nfrom megatron.utils import report_memory, flops_calculator\n\nimport deepspeed\nfrom deepspeed.runtime.utils import see_memory_usage\n\n\ndef pretrain(train_valid_test_dataset_provider, model_provider,\n             forward_step_func, extra_args_provider=None, args_defaults={}):\n    \"\"\"Main training program.\n\n    This function will 
run the followings in the order provided:\n        1) initialize Megatron.\n        2) setup model, optimizer and lr schedule using the model_provider.\n        3) call train_val_test_data_provider to get train/val/test datasets.\n        4) train the modle using the forward_step_func.\n\n    Arguments:\n        train_valid_test_dataset_provider: a function that takes the size of\n            train/valid/test dataset and returns `train, valid, test` datasets.\n        model_provider: a function that returns a vanilla version of the\n            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.\n        forward_step_func: a function that takes a `data iterator` and `model`,\n            and returns a `loss` scalar with a dictionary with key:values being\n            the info we would like to monitor during training, for example\n            `lm-loss: value`. We also require that this function add\n            `batch generator` to the timers class.\n        extra_args_provider: a function that takes a parser and adds arguments\n            to it. It is used for programs to add their own arguments.\n        args_defaults: a dictionary from argument-name to argument-value. 
It\n            to set already parse arguments.\n    \"\"\"\n\n    # Initalize and get arguments, timers, and Tensorboard writer.\n    initialize_megatron(extra_args_provider=extra_args_provider,\n                        args_defaults=args_defaults)\n\n    args = get_args()\n    timers = get_timers()\n\n    args.curriculum_learning = False\n    if args.deepspeed:\n        args.deepspeed_configuration = json.load(\n            open(args.deepspeed_config, 'r', encoding='utf-8'))\n        if \"curriculum_learning\" in args.deepspeed_configuration:\n            if \"enabled\" in args.deepspeed_configuration[\"curriculum_learning\"]:\n                args.curriculum_learning = args.deepspeed_configuration[\"curriculum_learning\"][\"enabled\"]\n\n    # Model, optimizer, and learning rate.\n    timers('model and optimizer').start()\n    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)\n    timers('model and optimizer').stop()\n\n    # Data stuff.\n    timers('train/valid/test data iterators').start()\n    train_data_iterator, valid_data_iterator, test_data_iterator \\\n        = build_train_valid_test_data_iterators(\n            train_valid_test_dataset_provider)\n    timers('train/valid/test data iterators').stop()\n\n    # Print setup timing.\n    print_rank_0('done with setups ...')\n    timers.log(['model and optimizer', 'train/valid/test data iterators'])\n    print_rank_0('training ...')\n\n    iteration = 0\n    if args.do_train and args.train_iters > 0:\n        iteration = train(forward_step_func,\n                          model, optimizer, lr_scheduler,\n                          train_data_iterator, valid_data_iterator)\n\n    if args.do_valid:\n        prefix = 'the end of training for val data'\n        evaluate_and_print_results(prefix, forward_step_func,\n                                   valid_data_iterator, model,\n                                   iteration, False)\n\n    if args.save and iteration != 0:\n        
save_checkpoint(iteration, model, optimizer, lr_scheduler)\n\n    if args.do_test:\n        # Run on test data.\n        prefix = 'the end of training for test data'\n        evaluate_and_print_results(prefix, forward_step_func,\n                                   test_data_iterator, model,\n                                   0, True)\n\n\ndef get_model(model_provider_func):\n    \"\"\"Build the model.\"\"\"\n    args = get_args()\n\n    # Build model on cpu.\n    model = model_provider_func()\n\n    if args.deepspeed:\n        # DeepSpeed handles CUDA, FP16, and DDP components.\n        return model\n\n    # GPU allocation.\n    model.cuda(torch.cuda.current_device())\n\n    # Fp16 conversion.\n    if args.fp16:\n        model = FP16_Module(model)\n\n    # Wrap model for distributed training.\"\"\"\n    if args.DDP_impl == 'torch':\n        i = torch.cuda.current_device()\n        model = torchDDP(model, device_ids=[i], output_device=i,\n                         process_group=mpu.get_data_parallel_group())\n        return model\n    if args.DDP_impl == 'local':\n        model = LocalDDP(model)\n        return model\n\n    raise NotImplementedError('Unknown DDP implementation specified: {}. 
'\n                              'Exiting.'.format(args.DDP_impl))\n\n\ndef get_optimizer(model):\n    \"\"\"Set up the optimizer.\"\"\"\n    args = get_args()\n\n    # Build parameter groups (weight decay and non-decay).\n    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):\n        model = model.module\n    param_groups = get_params_for_weight_decay_optimization(model)\n\n    # Add model parallel attribute if it is not set.\n    for param_group in param_groups:\n        for param in param_group['params']:\n            if not hasattr(param, 'model_parallel'):\n                param.model_parallel = False\n\n    if args.cpu_optimizer:\n        if args.cpu_torch_adam:\n            cpu_adam_optimizer = torch.optim.AdamW\n        else:\n            from deepspeed.ops.adam import DeepSpeedCPUAdam\n            cpu_adam_optimizer = DeepSpeedCPUAdam\n        optimizer = cpu_adam_optimizer(param_groups,\n                                       lr=args.lr,\n                                       weight_decay=args.weight_decay)\n    else:\n        # Use torch Adam instead of Fused Adam from NVIDIA which seems to have some issue.\n        #optimizer = Adam(param_groups,\n        optimizer = torch.optim.AdamW(param_groups,\n                         lr=args.lr,\n                         weight_decay=args.weight_decay,\n                         betas=(args.adam_beta1, args.adam_beta2),\n                         eps=args.adam_eps)\n\n    if args.deepspeed:\n        # fp16 wrapper is not required for DeepSpeed.\n        return optimizer\n\n    # Wrap into fp16 optimizer.\n    if args.fp16:\n        optimizer = FP16_Optimizer(optimizer,\n                                   static_loss_scale=args.loss_scale,\n                                   dynamic_loss_scale=args.dynamic_loss_scale,\n                                   dynamic_loss_args={\n                                       'scale_window': args.loss_scale_window,\n                                       'min_scale': 
args.min_scale,\n                                       'delayed_shift': args.hysteresis})\n\n    return optimizer\n\n\ndef get_learning_rate_scheduler(optimizer):\n    \"\"\"Build the learning rate scheduler.\"\"\"\n    args = get_args()\n\n    # Add linear learning rate scheduler.\n    if args.lr_decay_iters is not None:\n        num_iters = args.lr_decay_iters\n    else:\n        num_iters = args.train_iters\n    num_iters = max(1, num_iters)\n    init_step = 0\n    if args.warmup_iters is not None:\n        warmup_iter = args.warmup_iters\n    else:\n        warmup_iter = args.warmup * num_iters\n    lr_scheduler = AnnealingLR(\n        optimizer,\n        start_lr=args.lr,\n        warmup_iter=warmup_iter,\n        total_iters=num_iters,\n        decay_style=args.lr_decay_style,\n        last_iter=init_step,\n        min_lr=args.min_lr,\n        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,\n        override_lr_scheduler=args.override_lr_scheduler)\n\n    return lr_scheduler\n\n\ndef setup_model_and_optimizer(model_provider_func):\n    \"\"\"Setup model and optimizer.\"\"\"\n    args = get_args()\n\n    model = get_model(model_provider_func)\n    optimizer = get_optimizer(model)\n    lr_scheduler = get_learning_rate_scheduler(optimizer)\n\n    if args.deepspeed:\n        print_rank_0(\"DeepSpeed is enabled.\")\n\n        model, optimizer, _, lr_scheduler = deepspeed.initialize(\n            model=model,\n            optimizer=optimizer,\n            args=args,\n            lr_scheduler=lr_scheduler,\n            mpu=mpu,\n            dist_init_required=False)\n    if args.load is not None:\n        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)\n    else:\n        args.iteration = 0\n\n    # get model without FP16 and/or TorchDDP wrappers\n    unwrapped_model = model\n    while hasattr(unwrapped_model, 'module'):\n        unwrapped_model = unwrapped_model.module\n\n    if args.iteration == 0 and hasattr(unwrapped_model, 
'init_state_dict_from_bert'):\n        print(\"Initializing ICT from pretrained BERT model\", flush=True)\n        unwrapped_model.init_state_dict_from_bert()\n\n    return model, optimizer, lr_scheduler\n\n\ndef backward_step(optimizer, model, loss):\n    \"\"\"Backward step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Backward pass.\n    timers('backward-backward').start()\n    if args.deepspeed:\n        model.backward(loss)\n    else:\n        optimizer.zero_grad(set_grads_to_None=True)\n        if args.fp16:\n            optimizer.backward(loss, update_master_grads=False)\n        else:\n            loss.backward()\n    timers('backward-backward').stop()\n\n    if args.deepspeed:\n        # DeepSpeed backward propagation already addressed all reduce communication.\n        # Reset the timer to avoid breaking timer logs below.\n        timers('backward-allreduce').reset()\n    else:\n        # All-reduce if needed.\n        if args.DDP_impl == 'local':\n            timers('backward-allreduce').start()\n            model.allreduce_params(reduce_after=False,\n                                   fp32_allreduce=args.fp32_allreduce)\n            timers('backward-allreduce').stop()\n\n    if not args.deepspeed:\n        # Update master gradients.\n        timers('backward-master-grad').start()\n        if args.fp16:\n            optimizer.update_master_grads()\n        timers('backward-master-grad').stop()\n\n        # Clipping gradients helps prevent the exploding gradient.\n        timers('backward-clip-grad').start()\n        if args.clip_grad > 0:\n            if not args.fp16:\n                mpu.clip_grad_norm(model.parameters(), args.clip_grad)\n            else:\n                optimizer.clip_master_grads(args.clip_grad)\n        timers('backward-clip-grad').stop()\n\nimport time\nglobal step_latencies\nstep_latencies = []\n\ndef train_step(forward_step_func, data_iterator,\n               model, optimizer, lr_scheduler):\n    
\"\"\"Single training step.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    #see_memory_usage(f'before forward {model.global_steps}', force=True)\n    # Forward model for one step.\n    timers('forward').start()\n    tic = time.time()\n    loss, loss_reduced = forward_step_func(data_iterator, model, args.curriculum_learning)\n    timers('forward').stop()\n\n    #see_memory_usage(f'before backward {model.global_steps}', force=True)\n    # Calculate gradients, reduce across processes, and clip.\n    timers('backward').start()\n    backward_step(optimizer, model, loss)\n    timers('backward').stop()\n\n\n    #see_memory_usage(f'before optimizer {model.global_steps}', force=True)\n    # Update parameters.\n    skipped_iter = 0\n    timers('optimizer').start()\n    if args.deepspeed:\n        model.step()\n    else:\n        optimizer.step()\n        # Update learning rate.\n        if not (args.fp16 and optimizer.overflow):\n            lr_scheduler.step()\n        else:\n            skipped_iter = 1\n    timers('optimizer').stop()\n\n    step_latencies.append(time.time() - tic - timers('batch generator').elapsed(reset=False))\n\n    return loss_reduced, skipped_iter\n\n\ndef training_log(loss_dict, total_loss_dict, learning_rate, iteration,\n                 loss_scale, report_memory_flag, skipped_iter, model=None):\n    \"\"\"Log training information such as losses, timing, ....\"\"\"\n    args = get_args()\n    timers = get_timers()\n    writer = get_tensorboard_writer()\n\n    # Update losses.\n    skipped_iters_key = 'skipped iterations'\n    total_loss_dict[skipped_iters_key] = total_loss_dict.get(\n        skipped_iters_key, 0) + skipped_iter\n    got_nan_key = 'got nan'\n\n    got_nan = False\n    for key in loss_dict:\n        if not skipped_iter:\n            total_loss_dict[key] = total_loss_dict.get(key, 0.) 
+ loss_dict[key]\n        else:\n            value = loss_dict[key].float().sum().item()\n            is_nan = value == float('inf') or \\\n                     value == -float('inf') or \\\n                     value != value\n            got_nan = got_nan or is_nan\n\n    total_loss_dict[got_nan_key] = total_loss_dict.get(\n        got_nan_key, 0) + int(got_nan)\n\n    # Logging.\n    timers_to_log = []\n\n    def add_to_logging(name):\n        if name in timers.timers:\n            timers_to_log.append(name)\n    add_to_logging('forward')\n    add_to_logging('backward')\n    add_to_logging('backward-backward')\n    add_to_logging('backward-allreduce')\n    add_to_logging('backward-master-grad')\n    add_to_logging('backward-clip-grad')\n    add_to_logging('optimizer')\n    add_to_logging('batch generator')\n\n    # Tensorboard values.\n    if writer and torch.distributed.get_rank() == 0:\n        writer.add_scalar('tokens', args.tokens, iteration)\n        writer.add_scalar('learning_rate', learning_rate, iteration)\n        writer.add_scalar('learning_rate/vs tokens', learning_rate, args.tokens)\n        if args.curriculum_learning:\n            writer.add_scalar('seqlen',\n                args.curriculum_seqlen, iteration)\n            writer.add_scalar('seqlen/vs tokens',\n                args.curriculum_seqlen, args.tokens)\n        for key in loss_dict:\n            writer.add_scalar(key, loss_dict[key], iteration)\n            writer.add_scalar(key + '/vs tokens', loss_dict[key], args.tokens)\n        if args.fp16:\n            writer.add_scalar('loss_scale', loss_scale, iteration)\n        normalizer = iteration % args.log_interval\n        if normalizer == 0:\n            normalizer = args.log_interval\n        timers.write(timers_to_log, writer, iteration,\n                     normalizer=normalizer)\n\n    if iteration % args.log_interval == 0:\n        elapsed_time = timers('interval time').elapsed()\n        if writer and torch.distributed.get_rank() 
== 0:\n            writer.add_scalar('iteration_time',\n                              elapsed_time / args.log_interval, iteration)\n        log_string = ' iteration {:8d}/{:8d} |'.format(iteration,\n                                                       args.train_iters)\n        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(\n            elapsed_time * 1000.0 / args.log_interval)\n        log_string += ' learning rate: {:.3E} |'.format(learning_rate)\n        num_iterations = max(\n            1, args.log_interval - total_loss_dict[skipped_iters_key])\n        for key in total_loss_dict:\n            if key not in [skipped_iters_key, got_nan_key]:\n                avg = total_loss_dict[key].item() / float(num_iterations)\n                log_string += ' {}: {:.6E} |'.format(key, avg)\n                total_loss_dict[key] = 0.0\n        if args.fp16:\n            log_string += ' loss scale: {:.1f} |'.format(loss_scale)\n        log_string += ' number of skipped iterations: {:3d} |'.format(\n            total_loss_dict[skipped_iters_key])\n        log_string += ' number of nan iterations: {:3d} |'.format(\n            total_loss_dict[got_nan_key])\n        total_loss_dict[skipped_iters_key] = 0\n        total_loss_dict[got_nan_key] = 0\n        print_rank_0(log_string)\n        if report_memory_flag:\n            report_memory('after {} iterations'.format(iteration))\n            report_memory_flag = False\n        timers.log(timers_to_log, normalizer=args.log_interval)\n        flops_calculator(model, args, elapsed_time)\n\n    return report_memory_flag\n\n\ndef train(forward_step_func, model, optimizer, lr_scheduler,\n          train_data_iterator, valid_data_iterator):\n    \"\"\"Train the model function.\"\"\"\n    args = get_args()\n    timers = get_timers()\n\n    # Turn on training mode which enables dropout.\n    model.train()\n\n    # Tracking loss.\n    total_loss_dict = {}\n\n    # Iterations.\n    iteration = args.iteration\n\n    
timers('interval time').start()\n    report_memory_flag = True\n    data_parallel_size = mpu.get_data_parallel_world_size()\n    global_batch_size = args.batch_size * data_parallel_size\n    while iteration < args.train_iters and \\\n        (args.train_tokens is None or args.tokens < args.train_tokens):\n        loss_dict, skipped_iter = train_step(forward_step_func,\n                                             train_data_iterator,\n                                             model,\n                                             optimizer,\n                                             lr_scheduler)\n        iteration += 1\n        if args.curriculum_learning:\n            args.tokens += global_batch_size * args.curriculum_seqlen\n        else:\n            args.tokens += global_batch_size * args.seq_length\n\n        # Logging.\n        loss_scale = None\n        if args.fp16:\n            loss_scale = optimizer.cur_scale if args.deepspeed else optimizer.loss_scale\n        report_memory_flag = training_log(loss_dict, total_loss_dict,\n                                          optimizer.param_groups[0]['lr'],\n                                          iteration, loss_scale,\n                                          report_memory_flag, skipped_iter,\n                                          model=model)\n\n        # Autoresume\n        if args.adlr_autoresume and \\\n           (iteration % args.adlr_autoresume_interval == 0):\n            check_adlr_autoresume_termination(iteration, model, optimizer,\n                                              lr_scheduler)\n\n        # Checkpointing\n        if args.save and args.save_interval and \\\n           iteration % args.save_interval == 0:\n            save_checkpoint(iteration, model, optimizer, lr_scheduler)\n\n        # Evaluation\n        # XXX temporarily disabled for ZeRO-3\n        \"\"\"\n        if args.eval_interval and iteration % args.eval_interval == 0 and \\\n           args.do_valid:\n            
prefix = 'iteration {}'.format(iteration)\n            evaluate_and_print_results(prefix, forward_step_func,\n                                       valid_data_iterator, model,\n                                       iteration, False)\n        \"\"\"\n\n        if args.exit_interval and iteration % args.exit_interval == 0:\n            torch.distributed.barrier()\n            time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n            rank = torch.distributed.get_rank()\n            print_rank_0('rank: {} | time: {} | exiting the program at '\n                         'iteration {}'.format(rank, time_str, iteration))\n            sys.exit()\n\n    return iteration\n\n\ndef evaluate(forward_step_func, data_iterator, model, verbose=False):\n    \"\"\"Evaluation.\"\"\"\n    args = get_args()\n\n    # Turn on evaluation mode which disables dropout.\n    model.eval()\n\n    total_loss_dict = {}\n\n    with torch.no_grad():\n        iteration = 0\n        while iteration < args.eval_iters:\n            iteration += 1\n            if verbose and iteration % args.log_interval == 0:\n                print_rank_0('Evaluating iter {}/{}'.format(iteration,\n                                                            args.eval_iters))\n            # Forward evaluation.\n            _, loss_dict = forward_step_func(data_iterator, model)\n\n            # When contiguous memory optimizations are enabled, the buffers\n            # allocated by the optimizations are deallocated during backward pass\n            # in the absence of backward pass the buffers should be reset after each\n            # forward pass\n            if args.deepspeed and args.deepspeed_activation_checkpointing:\n                deepspeed.checkpointing.reset()\n\n            # Reduce across processes.\n            for key in loss_dict:\n                total_loss_dict[key] = total_loss_dict.get(key, 0.) 
+ \\\n                    loss_dict[key]\n    # Move model back to the train mode.\n    model.train()\n\n    for key in total_loss_dict:\n        total_loss_dict[key] /= args.eval_iters\n\n    return total_loss_dict\n\n\ndef evaluate_and_print_results(prefix, forward_step_func,\n                               data_iterator, model,\n                               iteration, verbose=False):\n    \"\"\"Helper function to evaluate and dump results on screen.\"\"\"\n    writer = get_tensorboard_writer()\n    args = get_args()\n\n    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)\n    string = ' validation loss at {} | '.format(prefix)\n    for key in total_loss_dict:\n        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())\n        ppl = math.exp(min(20, total_loss_dict[key].item()))\n        string += '{} PPL: {:.6E} | '.format(key, ppl)\n        if writer and torch.distributed.get_rank() == 0:\n            writer.add_scalar('{} value'.format(key),\n                              total_loss_dict[key].item(),\n                              iteration)\n            writer.add_scalar('{} value/vs tokens'.format(key),\n                              total_loss_dict[key].item(),\n                              args.tokens)\n            writer.add_scalar('{} ppl'.format(key), ppl, iteration)\n            writer.add_scalar('{} ppl/vs tokens'.format(key), ppl, args.tokens)\n\n    length = len(string) + 1\n    print_rank_0('-' * length)\n    print_rank_0(string)\n    print_rank_0('-' * length)\n\n\ndef build_train_valid_test_data_iterators(\n        build_train_valid_test_datasets_provider):\n    \"\"\"XXX\"\"\"\n    args = get_args()\n\n    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)\n\n    print_rank_0('> building train, validation, and test datasets ...')\n    # Data loader only on rank 0 of each model parallel group.\n    if mpu.get_model_parallel_rank() == 0:\n        # Rank, size, and global 
batch size.\n        data_parallel_size = mpu.get_data_parallel_world_size()\n        global_batch_size = args.batch_size * data_parallel_size\n\n        # Number of train/valid/test samples.\n        train_iters = args.train_iters\n        eval_iters = (train_iters // args.eval_interval + 1) * args.eval_iters\n        test_iters = args.eval_iters\n        train_val_test_num_samples = [train_iters * global_batch_size,\n                                      eval_iters * global_batch_size,\n                                      test_iters * global_batch_size]\n        print_rank_0(' > datasets target sizes (minimum size):')\n        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))\n        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))\n        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))\n\n        # Build the datasets.\n        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(\n            train_val_test_num_samples)\n\n        # Build dataloders.\n        train_dataloader = make_data_loader(train_ds)\n        valid_dataloader = make_data_loader(valid_ds)\n        test_dataloader = make_data_loader(test_ds)\n\n        # Flags to know if we need to do training/validation/testing.\n        do_train = train_dataloader is not None and args.train_iters > 0\n        do_valid = valid_dataloader is not None and args.eval_iters > 0\n        do_test = test_dataloader is not None and args.eval_iters > 0\n        # Need to broadcast num_tokens and num_type_tokens.\n        flags = torch.cuda.LongTensor(\n            [int(do_train), int(do_valid), int(do_test)])\n    else:\n        flags = torch.cuda.LongTensor([0, 0, 0])\n\n    # Broadcast num tokens.\n    torch.distributed.broadcast(flags,\n                                mpu.get_model_parallel_src_rank(),\n                                group=mpu.get_model_parallel_group())\n    args.do_train = flags[0].item()\n    
args.do_valid = flags[1].item()\n    args.do_test = flags[2].item()\n\n    # Shift the start iterations.\n    if train_dataloader is not None:\n        train_dataloader.batch_sampler.start_iter = args.iteration % \\\n            len(train_dataloader)\n        print_rank_0('setting training data start iteration to {}'.\n                     format(train_dataloader.batch_sampler.start_iter))\n    if valid_dataloader is not None:\n        start_iter_val = (args.iteration // args.eval_interval) * \\\n            args.eval_iters\n        valid_dataloader.batch_sampler.start_iter = start_iter_val % \\\n            len(valid_dataloader)\n        print_rank_0('setting validation data start iteration to {}'.\n                     format(valid_dataloader.batch_sampler.start_iter))\n\n    # Build iterators.\n    if train_dataloader is not None:\n        train_data_iterator = iter(train_dataloader)\n    else:\n        train_data_iterator = None\n\n    if valid_dataloader is not None:\n        valid_data_iterator = iter(valid_dataloader)\n    else:\n        valid_data_iterator = None\n\n    if test_dataloader is not None:\n        test_data_iterator = iter(test_dataloader)\n    else:\n        test_data_iterator = None\n\n    return train_data_iterator, valid_data_iterator, test_data_iterator\n"
  },
  {
    "path": "benchmark/megatron/README.md",
    "content": "# Benchmark Megatron-LM\n\n## Requirements\n```\n# torch 1.8.0 and CUDA 11.1\npip3 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html\n\npip3 install ninja\n\n# Install Megatron\ngit clone https://github.com/NVIDIA/Megatron-LM.git\ncd Megatron-LM\necho 'export PYTHONPATH=$PYTHONPATH:~/efs/Megatron-LM' >> ~/.bashrc   # use your own path\nsource ~/.bashrc\n\n# Install Apex\ngit clone https://github.com/NVIDIA/apex\ncd apex\n# Comment out the raised RuntimeError in setup.py if you get errors running the following command.\npip3 install -v --no-cache-dir --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext\" ./\n```\n\n## Instructions\n### Single Node\n```\n# MLP\npython3 benchmark_mlp.py --nproc_per_node 4\n# Transfomer layer\npython3 benchmark_transformer_layer.py --nproc_per_node 4\n# GPT\npython3 benchmark_gpt_bert.py --nproc_per_node 1 --suite gpt.tmp\npython3 benchmark_gpt_bert.py --nproc_per_node 8 --suite gpt.tmp\n```\n\n### Multiple Nodes\n```\n# on node 0\npython3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 0 --master_port 11000 --master_addr 172.31.16.139\n# on node 1\npython3 benchmark_gpt_bert.py --suite gpt.tmp --nproc_per_node 8 --nnodes 2 --node_rank 1 --master_port 11000 --master_addr 172.31.16.139\n```\n\nFor other models, replace `benchmark_gpt_bert.py` with the corresponding filenames.\n\n### With nvprof\n```\nnvprof --profile-child-processes python3 benchmark_mlp.py --nproc_per_node 4 &> megatron.prof\n```\n"
  },
  {
    "path": "benchmark/megatron/benchmark_gpt_bert.py",
    "content": "import argparse\nfrom datetime import datetime\n\nfrom util import run_cmd\n\nfrom benchmark.alpa import suite_manual_gpt\n\nbenchmark_suites = {\n    \"gpt.tmp\": suite_manual_gpt.tmp_suite,\n    #\"gpt.grid_search_manual\": suite_manual_gpt.grid_search_manual,\n}\n\ndef benchmark_all(args):\n    num_gpus = args.nproc_per_node * args.nnodes\n\n    try:\n        _ = benchmark_suites[args.suite][num_gpus]\n    except KeyError:\n        print(f\"No available benchmark suite for {args.suite} with {num_gpus} GPUs.\")\n        exit()\n    output_name = args.exp_name + \"-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n    model = args.suite.split(\".\")[0]\n\n    for case in benchmark_suites[args.suite][num_gpus]:\n        case = tuple(tuple(x) if isinstance(x, tuple) else x for x in case)\n        case_str = str((model,) + case)\n\n        if args.nnodes == 1:\n            # Single node\n            ret = run_cmd('python3 -m torch.distributed.launch '\n                         f'--nproc_per_node {args.nproc_per_node} '\n                         'benchmark_gpt_bert_one_case.py '\n                          f'\"{case_str}\" '\n                          f'{output_name}')\n        else:\n            # Multiple nodes\n            ret = run_cmd('python3 -m torch.distributed.launch '\n                         f'--nproc_per_node {args.nproc_per_node} '\n                         f'--nnodes {args.nnodes} '\n                         f'--node_rank {args.node_rank} '\n                         f'--master_addr {args.master_addr} '\n                         f'--master_port {args.master_port} '\n                         'benchmark_gpt_bert_one_case.py '\n                         f'\"{case_str}\" '\n                         f'{output_name}')\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--nproc_per_node\", type=int, required=True)\n    parser.add_argument(\"--nnodes\", type=int, default=1)\n    
parser.add_argument(\"--node_rank\", type=int)\n    parser.add_argument(\"--master_addr\", type=str)\n    parser.add_argument(\"--master_port\", type=str)\n    parser.add_argument(\"--suite\", type=str, default=\"gpt.tmp\")\n    parser.add_argument(\"--exp_name\", type=str, default=\"\")\n    args = parser.parse_args()\n\n    benchmark_all(args)\n"
  },
  {
    "path": "benchmark/megatron/benchmark_gpt_bert_one_case.py",
    "content": "import argparse\nimport gc\nfrom functools import partial\nimport os\nimport sys\nimport time\n\nimport numpy as np\n\nfrom megatron.utils import average_losses_across_data_parallel_group\nfrom megatron.model import BertModel, GPTModel\nfrom megatron.model import ModelType\nfrom megatron import mpu, initialize_megatron, get_args, get_timers\nfrom megatron.training import train_step, setup_model_and_optimizer\nimport torch\n\nfrom util import write_tsv, benchmark_func,\\\n    compute_gpt_tflops, compute_gpt_parameter_count\n\nGB = 1024**3\n\n\ndef get_gpt_functions():\n    args = get_args()\n    micro_batch_size = args.micro_batch_size\n    seq_len = args.encoder_seq_length\n\n    def model_provider(pre_process=True, post_process=True):\n        model = GPTModel(num_tokentypes=0,\n                         parallel_output=True,\n                         pre_process=pre_process,\n                         post_process=post_process)\n        return model\n\n    def loss_func(loss_mask, output_tensor):\n        losses = output_tensor.float()\n        loss_mask = loss_mask.view(-1).float()\n        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()\n\n        # Reduce loss for logging.\n        #averaged_loss = average_losses_across_data_parallel_group([loss])\n        averaged_loss = [0]\n        return loss, {'lm loss': averaged_loss[0]}\n\n    tokens = torch.ones((micro_batch_size, seq_len)).cuda().long()\n    labels = torch.ones((micro_batch_size, seq_len)).cuda().long()\n    loss_mask = torch.ones((micro_batch_size, seq_len)).cuda().int()\n    attention_mask = \\\n        torch.ones(micro_batch_size, 1, seq_len, seq_len).cuda().bool()\n    position_ids = torch.ones((micro_batch_size, seq_len)).cuda().long()\n\n    def forward_step(data_iterator, model):\n        output_tensor = model(tokens,\n                              position_ids,\n                              attention_mask,\n                              labels=labels)\n        
return output_tensor, partial(loss_func, loss_mask)\n\n    return model_provider, loss_func, forward_step\n\n\ndef get_bert_functions():\n    args = get_args()\n    micro_batch_size = args.micro_batch_size\n    seq_len = args.encoder_seq_length\n\n    def model_provider(pre_process=True, post_process=True):\n        num_tokentypes = 2 if args.bert_binary_head else 0\n        model = BertModel(num_tokentypes=num_tokentypes,\n                          add_binary_head=args.bert_binary_head,\n                          parallel_output=True,\n                          pre_process=pre_process,\n                          post_process=post_process)\n\n        return model\n\n    def loss_func(loss_mask, sentence_order, output_tensor):\n        lm_loss_, sop_logits = output_tensor\n\n        lm_loss_ = lm_loss_.float()\n        loss_mask = loss_mask.float()\n        lm_loss = torch.sum(\n            lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()\n\n        if sop_logits is not None:\n            sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),\n                                       sentence_order.view(-1),\n                                       ignore_index=-1)\n            sop_loss = sop_loss.float()\n            loss = lm_loss + sop_loss\n            #averaged_losses = average_losses_across_data_parallel_group(\n            #    [lm_loss, sop_loss])\n            averaged_losses = [0, 0]\n            return loss, {\n                'lm loss': averaged_losses[0],\n                'sop loss': averaged_losses[1]\n            }\n        else:\n            loss = lm_loss\n            #averaged_losses = average_losses_across_data_parallel_group(\n            #    [lm_loss])\n            averaged_losses = [0]\n            return loss, {'lm loss': averaged_losses[0]}\n\n    tokens = torch.ones((micro_batch_size, seq_len)).cuda().long()\n    padding_mask = \\\n        torch.ones(micro_batch_size, seq_len).cuda().bool()\n    types = 
torch.ones((micro_batch_size, seq_len)).cuda().long()\n    lm_labels = torch.ones((micro_batch_size, seq_len)).cuda().long()\n    loss_mask = torch.ones((micro_batch_size, seq_len)).cuda().int()\n    sentence_order = None\n\n    def forward_step(data_iterator, model):\n        if not args.bert_binary_head:\n            types = None\n\n        output_tensor = model(tokens,\n                              padding_mask,\n                              tokentype_ids=types,\n                              lm_labels=lm_labels)\n        return output_tensor, partial(loss_func, loss_mask, sentence_order)\n\n    return model_provider, loss_func, forward_step\n\n\ndef benchmark_gpt_bert_one_case(benchmark_case, output_file_name):\n    # Model configs\n    model_type = \"gpt\"\n    (global_batch_size, model_config, num_micro_batches, parallel_mode,\n     parallel_args) = benchmark_case\n    (seq_len, hidden_size, num_layers, num_heads,\n     vocab_size) = model_config\n    assert parallel_mode == \"uniform\"\n    (prefer_reduce_scatter, use_remat, dp, op, pp,\n     force_batch_dim_mapping) = parallel_args\n\n    dp_size, tensor_mp_size, pipeline_mp_size = dp, op, pp\n    checkpoint_activations = use_remat\n\n    num_gpus = dp_size * tensor_mp_size * pipeline_mp_size\n    assert global_batch_size % (dp_size * num_micro_batches) == 0\n    micro_batch_size = global_batch_size // dp_size // num_micro_batches\n\n    # always use local DDP\n    ddp_impl = True\n\n    # Parallel configs\n    # Initialize megatron\n    sys.argv += [\"--micro-batch-size\", str(micro_batch_size)]\n    sys.argv += [\"--tensor-model-parallel-size\", str(tensor_mp_size)]\n    sys.argv += [\"--pipeline-model-parallel-size\", str(pipeline_mp_size)]\n    sys.argv += [\"--global-batch-size\", str(global_batch_size)]\n    sys.argv += [\"--num-layers\", str(num_layers)]\n    sys.argv += [\"--hidden-size\", str(hidden_size)]\n    sys.argv += [\"--num-attention-heads\", str(num_heads)]\n    sys.argv += 
[\"--seq-length\", str(seq_len)]\n    sys.argv += [\"--max-position-embeddings\", str(seq_len)]\n    sys.argv += [\"--optimizer\", \"adam\"]\n    sys.argv += [\"--train-iters\", \"100\"]\n    sys.argv += [\"--lr\", \"0.00015\"]\n    sys.argv += [\"--bert-no-binary-head\"]\n    sys.argv += [\"--DDP-impl\", \"local\" if ddp_impl else \"torch\"]\n    sys.argv += [\"--fp16\"]\n    sys.argv += [\"--loss-scale\", \"8\"]\n    if checkpoint_activations:\n        sys.argv += [\"--checkpoint-activations\"]\n    # sys.argv += [\"--no-masked-softmax-fusion\"]\n    # sys.argv += [\"--no-async-tensor-model-parallel-allreduce\"]\n    # sys.argv += [\"--no-scatter-gather-tensors-in-pipeline\"]\n    initialize_megatron()\n    args = get_args()\n    args.padded_vocab_size = vocab_size\n    rank = torch.distributed.get_rank()\n\n    # Check initialization\n    assert dp_size == mpu.get_data_parallel_world_size()\n    assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size()\n    assert pipeline_mp_size == mpu.get_pipeline_model_parallel_world_size()\n\n    # Build model\n    if model_type == \"gpt\":\n        model_provider, loss_func, forward_step = get_gpt_functions()\n    elif model_type == \"bert\":\n        model_provider, loss_func, forward_step = get_bert_functions()\n\n    model, optimizer, lr_scheduler = setup_model_and_optimizer(\n        model_provider, model_type=ModelType.encoder_or_decoder)\n\n    parameter_count = compute_gpt_parameter_count(num_layers, hidden_size,\n                                                  vocab_size)\n\n    def run_func():\n        train_step(forward_step, None, model, optimizer, lr_scheduler)\n\n    # Warmup and reset timers\n    run_func()\n    timers = get_timers()\n    names = list(timers.timers.keys())\n    for name in names:\n        timers(name).reset()\n\n    # Benchmark step time\n    repeat = 2\n    number = 1\n    costs = benchmark_func(run_func,\n                           sync_func=None,\n                           
warmup=0,\n                           repeat=repeat,\n                           number=number)\n    timers.log(names, normalizer=repeat * number)\n\n    # Print results\n    if rank == 0:\n        peak_mem = torch.cuda.max_memory_allocated(0)\n        tflops = compute_gpt_tflops(global_batch_size, seq_len, num_layers,\n                                    hidden_size, vocab_size,\n                                    torch.distributed.get_world_size(),\n                                    np.mean(costs))\n        tflops_ckpt = compute_gpt_tflops(global_batch_size, seq_len, num_layers,\n                                         hidden_size, vocab_size,\n                                         torch.distributed.get_world_size(),\n                                         np.mean(costs), True)\n        heads = [\n            \"Type\", \"Model Config\", \"Parallel Config\", \"P-mesh shape\",\n            \"#Microbatch\", \"Force DP\", \"Remat\", \"Mean Time\", \"Std Time\",\n            \"#Params\", \"TFLOPs\", \"TFLOPs (ckpt)\", \"Peak Mem\"\n        ]\n        values = [\n            model_type,\n            str(benchmark_case[1:6]),\n            str((dp_size, tensor_mp_size, pipeline_mp_size)), \"N/A\",\n            str(num_micro_batches), \"N/A\",\n            str(checkpoint_activations), f\"{np.mean(costs):.3f}\",\n            f\"{np.std(costs):.3f}\", f\"{parameter_count/1e9:.3f}\",\n            f\"{tflops:.2f}\", f\"{tflops_ckpt:.2f}\", f\"{peak_mem/GB:5.3f}\"\n        ]\n        write_tsv(heads, values,\n                  f\"{model_type}_megatron_{output_file_name}_rank{rank}.tsv\")\n        print(\"Sleeping for 30 seconds before starting the next case. \")\n        time.sleep(30)\n\n\nif __name__ == \"__main__\":\n    case = eval(sys.argv[-2])\n    output_file_name = sys.argv[-1]\n    del sys.argv[-1]\n    del sys.argv[-1]\n    benchmark_gpt_bert_one_case(case, output_file_name)\n"
  },
  {
    "path": "benchmark/megatron/benchmark_mlp.py",
    "content": "import argparse\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers,\n# #head = num_heads, DP = dp_size, TMP = tensor_mp_size, DPI = ddp_implementation,\n\nbenchmark_suite_4_gpu = [\n    # B,  S,    H,    L,  #head,     DP, TMP, DPI\n    (32,  1024, 2304, 4,  2304//96,  4,  1,   1),\n    (32,  1024, 2304, 4,  2304//96,  2,  2,   1),\n    (32,  1024, 2304, 4,  2304//96,  1,  4,   1),\n\n    # B,  S,    H,    L,  #head,     DP, TMP, DPI\n    (8,   256,  5760, 4,  5760//96,  4,  1,   1),\n    (8,   256,  5760, 4,  5760//96,  2,  2,   1),\n    (8,   256,  5760, 4,  5760//96,  1,  4,   1),\n]\n\n\ndef benchmark_all():\n    for case in benchmark_suite_4_gpu:\n        nproc_per_node = 4\n        case_str = str(case)\n        ret = run_cmd('python3 -m torch.distributed.launch '\n                     f'--nproc_per_node {nproc_per_node} '\n                     'benchmark_mlp_one_case.py '\n                     f'\"{case_str}\"')\n        if ret != 0:\n            return\n\nif __name__ == \"__main__\":\n    benchmark_all()\n\n"
  },
  {
    "path": "benchmark/megatron/benchmark_mlp_one_case.py",
    "content": "import argparse\nimport os\nimport sys\n\nimport numpy as np\nfrom megatron.model.transformer import ParallelTransformerLayer, ParallelMLP\nfrom megatron.model.utils import init_method_normal, scaled_init_method_normal\nfrom megatron.model import DistributedDataParallel as LocalDDP\nfrom megatron import mpu, initialize_megatron, get_args\nimport torch\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\n\nfrom util import write_tsv, benchmark_func\n\nGB = 1024 ** 3\n\n\ndef get_memory_usage(print_info=False):\n    \"\"\"Get accurate gpu memory usage by querying torch runtime\"\"\"\n    rank = torch.distributed.get_rank()\n    device = rank % torch.cuda.device_count()\n    allocated = torch.cuda.memory_allocated(device)\n    reserved = torch.cuda.memory_reserved(device)\n    if print_info:\n        print(\"allocated: %.2f MB\" % (allocated / 1024 / 1024), flush=True)\n        print(\"reserved:  %.2f MB\" % (reserved / 1024 / 1024), flush=True)\n    return allocated\n\n\nclass MultiLayerMLP(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n\n        self.num_layers = num_layers\n\n        init_method_std = 0.02\n        init_method = init_method_normal(init_method_std)\n        scaled_init_method = scaled_init_method_normal(init_method_std, num_layers)\n        for i in range(self.num_layers):\n            setattr(self, f\"layer_{i}\", ParallelMLP(init_method, scaled_init_method))\n\n    def forward(self, x):\n        out = x\n        for i in range(self.num_layers):\n            out, out_bias = getattr(self, f\"layer_{i}\")(out)\n            out = out + out_bias\n        return out\n\n\ndef benchmark_mlp_one_case(benchmark_case):\n    # Model configs\n    batch_size, seq_len, hidden_size, num_layers, num_heads, \\\n        dp_size, tensor_mp_size, ddp_impl = benchmark_case\n\n    # Parallel configs\n    micro_batch_size = batch_size // dp_size\n\n    # Initialize megatron\n    sys.argv += 
[\"--micro-batch-size\", str(micro_batch_size)]\n    sys.argv += [\"--tensor-model-parallel-size\", str(tensor_mp_size)]\n    sys.argv += [\"--global-batch-size\", str(micro_batch_size * dp_size)]\n    sys.argv += [\"--num-layers\", str(num_layers)]\n    sys.argv += [\"--hidden-size\", str(hidden_size)]\n    sys.argv += [\"--num-attention-heads\", str(num_heads)]\n    sys.argv += [\"--max-position-embeddings\", str(seq_len)]\n    sys.argv += [\"--encoder-seq-length\", str(seq_len)]\n    initialize_megatron()\n    rank = torch.distributed.get_rank()\n\n    # Check initialization\n    assert dp_size == mpu.get_data_parallel_world_size()\n    assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size()\n\n    # Build model and input batch\n    model = MultiLayerMLP(num_layers)\n    model.cuda(torch.cuda.current_device())\n\n    i = torch.cuda.current_device()\n    if ddp_impl == 0:\n        model = torchDDP(model, device_ids=[i], output_device=i,\n                         process_group=mpu.get_data_parallel_group())\n    else:\n        model = LocalDDP(model, False, True)\n\n    if rank == 0:\n        print(model)\n\n    weight_mem = get_memory_usage() \n\n    x = torch.randn(micro_batch_size, seq_len, hidden_size).cuda()\n    y = torch.randn(micro_batch_size, seq_len, hidden_size).cuda()\n\n    input_mem = get_memory_usage() - weight_mem\n    before_backward_mem = [None]\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    # Benchmark step time\n    def run_func():\n        if isinstance(model, LocalDDP):\n            model.zero_grad_buffer()\n        else:\n            optimizer.zero_grad()\n\n        output = model(x)\n        loss = ((output - y) ** 2)\n        loss = loss.mean()\n        loss.backward()\n\n        if isinstance(model, LocalDDP):\n            model.allreduce_gradients()\n            for param_group in optimizer.param_groups:\n                for param in param_group['params']:\n                    param.grad = 
param.main_grad\n\n        optimizer.step()\n\n        torch.distributed.barrier()\n\n    def sync_func():\n        torch.cuda.synchronize()\n\n    costs = benchmark_func(run_func, sync_func,\n                           warmup=1, repeat=2, number=5)\n\n    # Print results\n    if rank == 0:\n        peak_mem = torch.cuda.max_memory_allocated(0)\n        heads = [\"Type\", \"Case\", \"WeightMem\", \"PeakMem\", \"Mean Time\", \"Std Time\"]\n        values = [\"mlp\", str(benchmark_case), f\"{weight_mem/GB:.2f}\", f\"{peak_mem/GB:.2f}\",\n                  f\"{np.mean(costs):.3f}\", f\"{np.std(costs):.3f}\"]\n        write_tsv(heads, values, \"result_mlp.tsv\")\n\n\nif __name__ == \"__main__\":\n    case = eval(sys.argv[-1])\n    del sys.argv[-1]\n    benchmark_mlp_one_case(case)\n\n"
  },
  {
    "path": "benchmark/megatron/benchmark_transformer_layer.py",
    "content": "import argparse\n\nfrom util import run_cmd\n\n# B = batch_size, S = seq_len, H = hidden_size, L = num_layers,\n# #head = num_heads, DP = dp_size, TP = tensor_mp_size, PP = pipeline_mp_size,\n# NB = num_micro_batches, DI = ddp_impl, CK = checkpoint_activations\n\nbenchmark_suite_2_gpu = [\n    # B,  S,    H,    L,  #head,     DP, TP, PP, NB, DI, CK\n    # (32,  1024, 1536, 2,  1536//96,  1, 1, 2, 1, 1, 0),\n    # (8,  128, 384, 2,  1536//96,  1,  1, 2, 1, True, False),\n    # (8,  128, 384, 2,  1536//96,  1,  1, 2, 2, True, False),\n    # (8,  128, 384, 2,  1536//96,  1,  1, 2, 4, True, False),\n    # (8,  128, 384, 2,  1536//96,  1,  1, 2, 8, True, False),\n\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 1, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 2, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 4, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 8, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 16, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  1, 2, 32, True, False),\n]\n\n\nbenchmark_suite_4_gpu = [\n    # B,  S,    H,    L,  #head,     DP, TP, PP, NB, DI, CK\n\n    # DP + PP\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 1, True, False),\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 2, True, False),\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 4, True, False),\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 8, True, False),\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 16, True, False),\n    (32,  1024, 1536, 2,  1536//96,  2,  1, 2, 32, True, False), # wrong case\n\n    # MP + PP\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 1, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 2, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 4, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 8, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 16, True, False),\n    (32,  1024, 1536, 2,  1536//96,  1,  2, 2, 32, True, False),\n\n    # DP + PP, 4 layers\n    (32,  1024, 1536, 4,  1536//96, 
 2,  1, 2, 1, True, False),\n    (32,  1024, 1536, 4,  1536//96,  2,  1, 2, 2, True, False),\n    (32,  1024, 1536, 4,  1536//96,  2,  1, 2, 4, True, False),\n    (32,  1024, 1536, 4,  1536//96,  2,  1, 2, 8, True, False),\n    (32,  1024, 1536, 4,  1536//96,  2,  1, 2, 16, True, False),\n    (32,  1024, 1536, 4,  1536//96,  2,  1, 2, 32, True, False), # wrong case\n\n    # MP + PP, 4 layers\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 1, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 2, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 4, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 8, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 16, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  2, 2, 32, True, False),\n\n    # PP, 4 layers\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 1, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 2, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 4, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 8, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 16, True, False),\n    (32,  1024, 1536, 4,  1536//96,  1,  1, 4, 32, True, False),\n]\n\n\nbenchmark_suite_8_gpu = [\n    # B,  S,    H,    L,  #head,     DP, TP, PP, NB, DI, CK\n    # # (32,  1024, 1536, 2,  1536//96,  1,  4, 2, 1, 1, 0),\n    # (32,  1024, 1536, 4,  1536//96,  8, 1, 1, 1, 1, 0),\n    # (32,  1024, 1536, 4,  1536//96,  4, 1, 2, 1, 1, 0),\n    # (32,  1024, 1536, 4,  1536//96,  2, 1, 4, 1, 1, 0),\n    (32,  1024, 1536, 4,  1536//96,  1, 8, 1, 1, 1, 0),\n    (32,  1024, 1536, 4,  1536//96,  1, 2, 4, 1, 1, 0),\n    (32,  1024, 1536, 4,  1536//96,  1, 4, 2, 1, 1, 0),\n    # (32,  128,  5120, 2,  5120//128, 1,  4, 2, 1, 1, 0),\n    # (32,  128,  5120, 2,  5120//128, 4,  1, 2, 1, 1, 0),\n\n]\n\n\ndef benchmark_all(args):\n    num_gpus = args.nproc_per_node * args.nnodes\n\n    benchmark_suites = {\n        2 : benchmark_suite_2_gpu,\n        4 : 
benchmark_suite_4_gpu,\n        8 : benchmark_suite_8_gpu,\n    }\n\n    for case in benchmark_suites[num_gpus]:\n        case_str = str(case)\n\n        if args.master_addr is None:\n            # Single node\n            ret = run_cmd('python3 -m torch.distributed.launch '\n                         f'--nproc_per_node {args.nproc_per_node} '\n                         'benchmark_transformer_layer_one_case.py '\n                         f'\"{case_str}\"')\n        else:\n            # Multiple nodes\n            ret = run_cmd('python3 -m torch.distributed.launch '\n                         f'--nproc_per_node {args.nproc_per_node} '\n                         f'--nnodes {args.nnodes} '\n                         f'--node_rank {args.node_rank} '\n                         f'--master_addr {args.master_addr} '\n                         f'--master_port {args.master_port} '\n                         'benchmark_transformer_layer_one_case.py '\n                         f'\"{case_str}\"')\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--nproc_per_node\", type=int, required=True)\n    parser.add_argument(\"--nnodes\", type=int, default=1)\n    parser.add_argument(\"--node_rank\", type=int)\n    parser.add_argument(\"--master_addr\", type=str)\n    parser.add_argument(\"--master_port\", type=str)\n    args = parser.parse_args()\n\n    benchmark_all(args)\n"
  },
  {
    "path": "benchmark/megatron/benchmark_transformer_layer_one_case.py",
    "content": "import time\n\nimport argparse\nimport os\nimport sys\nimport timeit\nfrom functools import partial\n\nimport numpy as np\n\nfrom benchmark.alpa.benchmark_gpt_bert import compute_tflops\nfrom megatron.model.transformer import ParallelTransformer, ParallelMLP\nfrom megatron.model.utils import init_method_normal, scaled_init_method_normal\nfrom megatron.model import DistributedDataParallel as LocalDDP\nfrom megatron.model import ModelType\nfrom megatron import mpu, initialize_megatron, get_args, get_timers\nfrom megatron.training import train_step, setup_model_and_optimizer\n\nimport torch\n\nfrom util import write_tsv, benchmark_func\n\nGB = 1024 ** 3\n\n# Note(Hao): in order for this to run with Megatron, disable the if-branch\n# here in Megatron: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/training.py#L390\n\n\ndef get_memory_usage(print_info=False):\n    \"\"\"Get accurate gpu memory usage by querying torch runtime\"\"\"\n    rank = torch.distributed.get_rank()\n    device = rank % torch.cuda.device_count()\n    allocated = torch.cuda.memory_allocated(device)\n    reserved = torch.cuda.memory_reserved(device)\n    if print_info:\n        print(\"allocated: %.2f GB\" % (allocated / GB), flush=True)\n        print(\"reserved:  %.2f GB\" % (reserved / GB), flush=True)\n    return allocated\n\n\ndef benchmark_transformer_layer_one_case(benchmark_case):\n    # Model configs\n    global_batch_size, seq_len, hidden_size, num_layers, num_heads, \\\n        dp_size, tensor_mp_size, pipeline_mp_size, num_micro_batches, \\\n        ddp_impl, checkpoint_activations = benchmark_case\n\n    # Parallel configs\n    assert global_batch_size % (dp_size * num_micro_batches) == 0\n    micro_batch_size = global_batch_size // dp_size // num_micro_batches\n\n    # Initialize megatron\n    sys.argv += [\"--micro-batch-size\", str(micro_batch_size)]\n    sys.argv += [\"--tensor-model-parallel-size\", str(tensor_mp_size)]\n    sys.argv += 
[\"--pipeline-model-parallel-size\", str(pipeline_mp_size)]\n    sys.argv += [\"--global-batch-size\", str(global_batch_size)]\n    sys.argv += [\"--num-layers\", str(num_layers)]\n    sys.argv += [\"--hidden-size\", str(hidden_size)]\n    sys.argv += [\"--num-attention-heads\", str(num_heads)]\n    sys.argv += [\"--max-position-embeddings\", str(seq_len)]\n    sys.argv += [\"--encoder-seq-length\", str(seq_len)]\n    sys.argv += [\"--optimizer\", \"adam\"]\n    sys.argv += [\"--train-iters\", \"100\"]\n    sys.argv += [\"--lr\", \"0.00015\"]\n    sys.argv += [\"--DDP-impl\", \"local\" if ddp_impl else \"torch\"]\n    # sys.argv += [\"--no-scatter-gather-tensors-in-pipeline\"]\n    # sys.argv += [\"--fp16\"]\n    if checkpoint_activations:\n        sys.argv += [\"--checkpoint-activations\"]\n\n    initialize_megatron()\n    rank = torch.distributed.get_rank()\n\n    # Check initialization\n    assert dp_size == mpu.get_data_parallel_world_size()\n    assert tensor_mp_size == mpu.get_tensor_model_parallel_world_size()\n    assert pipeline_mp_size == mpu.get_pipeline_model_parallel_world_size()\n\n    args = get_args()\n    micro_batch_size = args.micro_batch_size\n    seq_len = args.encoder_seq_length\n\n    i = torch.cuda.current_device()\n    x = torch.randn(seq_len, micro_batch_size, hidden_size).cuda(i)\n    y = torch.randn(seq_len, micro_batch_size, hidden_size).cuda(i)\n    attention_mask = torch.ones(micro_batch_size, 1, seq_len, seq_len). 
\\\n        to(torch.bool).cuda(i)\n\n\n    def get_transformer_functions():\n        args = get_args()\n\n        def model_provider(pre_process=True, post_process=True):\n            init_method_std = 0.02\n            init_method = init_method_normal(init_method_std)\n            scaled_init_method = scaled_init_method_normal(init_method_std, args.num_layers)\n            model = ParallelTransformer(init_method, scaled_init_method, 0,\n                                        pre_process=False, post_process=False)\n            model.cuda(torch.cuda.current_device())\n            return model\n\n        def loss_func(output_tensor):\n            loss = ((output_tensor - y) ** 2)\n            loss = loss.mean()\n            # averaged_losses = [0]\n            return loss, {\"avg loss\": 0}\n\n        def forward_step(data_iterator, model):\n            # Note(Hao): Megatron PP uses model.module.input_tensor to overwrite\n            # the input tensor to `model()`.\n            if model.module.input_tensor == [None]:\n                model.module.set_input_tensor(x)\n            else:\n                input_tensor = model.module.input_tensor\n                model.module.set_input_tensor(input_tensor[0])\n            output_tensor = model(x, attention_mask)\n            return output_tensor, loss_func\n\n        return model_provider, loss_func, forward_step\n\n    # Build model\n    model_provider, loss_func, forward_step = get_transformer_functions()\n    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider,\n                                                               model_type=ModelType.encoder_or_decoder)\n    if rank == 0:\n        print(model)\n\n    def run_func():\n        train_step(forward_step, None, model, optimizer, lr_scheduler)\n\n    # Warmup and reset timers\n    run_func()\n    timers = get_timers()\n    names = list(timers.timers.keys())\n    for name in names:\n        timers(name).reset()\n\n    def sync_func():\n      
  torch.cuda.synchronize()\n\n    repeat = 10\n    number = 1\n    costs = benchmark_func(run_func, sync_func=sync_func,\n                           warmup=0, repeat=repeat, number=number)\n    timers.log(names, normalizer=repeat * number)\n\n\n    # Print results\n    # if rank == 0:\n    peak_mem = torch.cuda.max_memory_allocated(0)\n    heads = [\"Type\", \"Case\", \"Mesh Shape\", \"#MB\", \"DDP Impl\",\n             \"Peak Mem\", \"Mean Time\", \"Std Time\"]\n    values = [\"transformer-layer\", str(benchmark_case[:-3]),\n              str(benchmark_case[-6:-3]), str(benchmark_case[-3]), str(benchmark_case[-2]),\n              f\"{peak_mem/GB:5.3f}\", f\"{np.mean(costs):.3f}\", f\"{np.std(costs):.3f}\", ]\n    result_tsv = \"result_trans-\" + str(rank) + \".tsv\"\n    write_tsv(heads, values, result_tsv)\n    time.sleep(10)\n\n\nif __name__ == \"__main__\":\n    case = eval(sys.argv[-1])\n    del sys.argv[-1]\n    benchmark_transformer_layer_one_case(case)\n\n"
  },
  {
    "path": "build_jaxlib/.bazelrc",
    "content": "############################################################################\n# All default build options below.\n\n# Sets the default Apple platform to macOS.\nbuild --apple_platform_type=macos\nbuild --macos_minimum_os=10.14\n\n# Make Bazel print out all options from rc files.\nbuild --announce_rc\n\nbuild --define open_source_build=true\n\nbuild --spawn_strategy=standalone\n\nbuild --enable_platform_specific_config\n\nbuild --experimental_cc_shared_library\n\n# Disable enabled-by-default TensorFlow features that we don't care about.\nbuild --define=no_aws_support=true\nbuild --define=no_gcp_support=true\nbuild --define=no_hdfs_support=true\nbuild --define=no_kafka_support=true\nbuild --define=no_ignite_support=true\n\nbuild --define=grpc_no_ares=true\n\nbuild -c opt\n\nbuild --config=short_logs\n\nbuild --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.\n\n# Later Bazel flag values override earlier values; if CUDA/ROCM/TPU are enabled,\n# these values are overridden.\nbuild --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false\nbuild --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false\nbuild --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false\n\n###########################################################################\n\nbuild:posix --copt=-fvisibility=hidden\nbuild:posix --copt=-Wno-sign-compare\nbuild:posix --cxxopt=-std=c++17\nbuild:posix --host_cxxopt=-std=c++17\n\nbuild:avx_posix --copt=-mavx\nbuild:avx_posix --host_copt=-mavx\n\nbuild:avx_windows --copt=/arch=AVX\n\nbuild:avx_linux --copt=-mavx\nbuild:avx_linux --host_copt=-mavx\n\nbuild:native_arch_posix --copt=-march=native\nbuild:native_arch_posix --host_copt=-march=native\n\nbuild:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1\n\nbuild:cuda --repo_env TF_NEED_CUDA=1\n# \"sm\" means we emit only cubin, which is forward compatible within a GPU generation.\n# \"compute\" means we emit both cubin and PTX, which is 
larger but also forward compatible to future GPU generations.\nbuild:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES=\"sm_35,sm_52,sm_60,sm_70,compute_80\"\nbuild:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain\nbuild:cuda --@local_config_cuda//:enable_cuda\nbuild:cuda --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true\nbuild:cuda --define=xla_python_enable_gpu=true\n\nbuild:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain\nbuild:rocm --define=using_rocm=true --define=using_rocm_hipcc=true\nbuild:rocm --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=true\nbuild:rocm --define=xla_python_enable_gpu=true\nbuild:rocm --repo_env TF_NEED_ROCM=1\nbuild:rocm --action_env TF_ROCM_AMDGPU_TARGETS=\"gfx900,gfx906,gfx908\"\n\nbuild:nonccl --define=no_nccl_support=true\n\n# Tensorflow uses M_* math constants that only get defined by MSVC headers if\n# _USE_MATH_DEFINES is defined.\nbuild:windows --copt=/D_USE_MATH_DEFINES\nbuild:windows --host_copt=/D_USE_MATH_DEFINES\n# Make sure to include as little of windows.h as possible\nbuild:windows --copt=-DWIN32_LEAN_AND_MEAN\nbuild:windows --host_copt=-DWIN32_LEAN_AND_MEAN\nbuild:windows --copt=-DNOGDI\nbuild:windows --host_copt=-DNOGDI\n# https://devblogs.microsoft.com/cppblog/announcing-full-support-for-a-c-c-conformant-preprocessor-in-msvc/\n# otherwise, there will be some compiling error due to preprocessing.\nbuild:windows --copt=/Zc:preprocessor\nbuild:windows --cxxopt=/std:c++17\nbuild:windows --host_cxxopt=/std:c++17\n# Generate PDB files, to generate useful PDBs, in opt compilation_mode\n# --copt /Z7 is needed.\nbuild:windows --linkopt=/DEBUG\nbuild:windows --host_linkopt=/DEBUG\nbuild:windows --linkopt=/OPT:REF\nbuild:windows --host_linkopt=/OPT:REF\nbuild:windows --linkopt=/OPT:ICF\nbuild:windows --host_linkopt=/OPT:ICF\nbuild:windows --incompatible_strict_action_env=true\n\nbuild:linux --config=posix\nbuild:linux --copt=-Wno-unknown-warning-option\n# Workaround for gcc 10+ 
warnings related to upb.\n# See https://github.com/tensorflow/tensorflow/issues/39467\nbuild:linux --copt=-Wno-stringop-truncation\nbuild:linux --copt=-Wno-array-parameter\n\nbuild:macos --config=posix\n\n# Suppress all warning messages.\nbuild:short_logs --output_filter=DONT_MATCH_ANYTHING\n\nbuild:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true\nbuild:tpu --define=with_tpu_support=true\n\nbuild:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true\n\n#########################################################################\n# RBE config options below.\n# Flag to enable remote config\ncommon --experimental_repo_remote_exec\n\nbuild:rbe --repo_env=BAZEL_DO_NOT_DETECT_CPP_TOOLCHAIN=1\nbuild:rbe --google_default_credentials\nbuild:rbe --bes_backend=buildeventservice.googleapis.com\nbuild:rbe --bes_results_url=\"https://source.cloud.google.com/results/invocations\"\nbuild:rbe --bes_timeout=600s\nbuild:rbe --define=EXECUTOR=remote\nbuild:rbe --distinct_host_configuration=false\nbuild:rbe --flaky_test_attempts=3\nbuild:rbe --jobs=200\nbuild:rbe --remote_executor=grpcs://remotebuildexecution.googleapis.com\nbuild:rbe --remote_timeout=3600\nbuild:rbe --spawn_strategy=remote,worker,standalone,local\ntest:rbe --test_env=USER=anon\n# Attempt to minimize the amount of data transfer between bazel and the remote\n# workers:\nbuild:rbe --remote_download_toplevel\n\nbuild:rbe_linux --config=rbe\nbuild:rbe_linux --action_env=PATH=\"/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/go/bin\"\nbuild:rbe_linux --host_javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8\nbuild:rbe_linux --javabase=@bazel_toolchains//configs/ubuntu16_04_clang/1.1:jdk8\nbuild:rbe_linux --host_java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8\nbuild:rbe_linux --java_toolchain=@bazel_tools//tools/jdk:toolchain_hostjdk8\n\n# Non-rbe settings we should include because we do not run configure\nbuild:rbe_linux 
--config=avx_linux\nbuild:rbe_linux --linkopt=-lrt\nbuild:rbe_linux --host_linkopt=-lrt\nbuild:rbe_linux --linkopt=-lm\nbuild:rbe_linux --host_linkopt=-lm\n\n# Use the GPU toolchain until the CPU one is ready.\n# https://github.com/bazelbuild/bazel/issues/13623\nbuild:rbe_cpu_linux_base --config=rbe_linux\nbuild:rbe_cpu_linux_base --host_crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_cpu_linux_base --crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_cpu_linux_base --extra_toolchains=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64\"\nbuild:rbe_cpu_linux_base --extra_execution_platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_cpu_linux_base --host_platform=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_cpu_linux_base --platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\n\nbuild:rbe_cpu_linux_py37 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7\"\nbuild:rbe_cpu_linux_py37 --python_path=\"/usr/local/bin/python3.7\"\nbuild:rbe_cpu_linux_py38 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8\"\nbuild:rbe_cpu_linux_py38 --python_path=\"/usr/local/bin/python3.8\"\nbuild:rbe_cpu_linux_py39 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9\"\nbuild:rbe_cpu_linux_py39 --python_path=\"/usr/local/bin/python3.9\"\nbuild:rbe_cpu_linux_py310 --config=rbe_cpu_linux_base 
--repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.10\"\nbuild:rbe_cpu_linux_py310 --python_path=\"/usr/local/bin/python3.10\"\n\nbuild:rbe_linux_cuda_base --config=rbe_linux\nbuild:rbe_linux_cuda_base --config=cuda\nbuild:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1\n\nbuild:rbe_linux_cuda11.1_nvcc_base --config=rbe_linux_cuda_base\nbuild:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDA_VERSION=11\nbuild:rbe_linux_cuda11.1_nvcc_base --action_env=TF_CUDNN_VERSION=8\nbuild:rbe_linux_cuda11.1_nvcc_base --action_env=CUDA_TOOLKIT_PATH=\"/usr/local/cuda-11.1\"\nbuild:rbe_linux_cuda11.1_nvcc_base --action_env=LD_LIBRARY_PATH=\"/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib\"\nbuild:rbe_linux_cuda11.1_nvcc_base --action_env=GCC_HOST_COMPILER_PATH=\"/dt9/usr/bin/gcc\"\ntest:rbe_linux_cuda11.1_nvcc_base --test_env=LD_LIBRARY_PATH=\"/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64\"\nbuild:rbe_linux_cuda11.1_nvcc_base --host_crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.1_nvcc_base --crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.1_nvcc_base --extra_toolchains=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64\"\nbuild:rbe_linux_cuda11.1_nvcc_base --extra_execution_platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.1_nvcc_base --host_platform=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.1_nvcc_base --platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.1_nvcc_base 
--repo_env=TF_CUDA_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_cuda\"\nbuild:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_tensorrt\"\nbuild:rbe_linux_cuda11.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_nccl\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.7 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.7\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.7 --python_path=\"/usr/local/bin/python3.7\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.8 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.8\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.8 --python_path=\"/usr/local/bin/python3.8\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.9 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.9\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.9 --python_path=\"/usr/local/bin/python3.9\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.10 --config=rbe_linux_cuda11.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.1-cudnn8-tensorrt7.2_config_python3.10\"\nbuild:rbe_linux_cuda11.1_nvcc_py3.10 --python_path=\"/usr/local/bin/python3.10\"\n\nbuild:rbe_linux_cuda11.2_nvcc_base --config=rbe_linux_cuda_base\nbuild:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDA_VERSION=11\nbuild:rbe_linux_cuda11.2_nvcc_base --action_env=TF_CUDNN_VERSION=8\nbuild:rbe_linux_cuda11.2_nvcc_base --action_env=CUDA_TOOLKIT_PATH=\"/usr/local/cuda-11.2\"\nbuild:rbe_linux_cuda11.2_nvcc_base --action_env=LD_LIBRARY_PATH=\"/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib\"\nbuild:rbe_linux_cuda11.2_nvcc_base 
--action_env=GCC_HOST_COMPILER_PATH=\"/dt9/usr/bin/gcc\"\ntest:rbe_linux_cuda11.2_nvcc_base --test_env=LD_LIBRARY_PATH=\"/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda-11.1/lib64\"\nbuild:rbe_linux_cuda11.2_nvcc_base --host_crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.2_nvcc_base --crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.2_nvcc_base --extra_toolchains=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64\"\nbuild:rbe_linux_cuda11.2_nvcc_base --extra_execution_platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.2_nvcc_base --host_platform=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.2_nvcc_base --platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda\"\nbuild:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_tensorrt\"\nbuild:rbe_linux_cuda11.2_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_nccl\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.7 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.7\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.7 --python_path=\"/usr/local/bin/python3.7\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.8 --config=rbe_linux_cuda11.2_nvcc_base 
--repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.8\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.8 --python_path=\"/usr/local/bin/python3.8\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.9 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.9\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.9 --python_path=\"/usr/local/bin/python3.9\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.10 --config=rbe_linux_cuda11.2_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_python3.10\"\nbuild:rbe_linux_cuda11.2_nvcc_py3.10 --python_path=\"/usr/local/bin/python3.10\"\n\nbuild:rbe_linux_cuda11.4_nvcc_base --config=rbe_linux_cuda_base\nbuild:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDA_VERSION=11\nbuild:rbe_linux_cuda11.4_nvcc_base --action_env=TF_CUDNN_VERSION=8\nbuild:rbe_linux_cuda11.4_nvcc_base --action_env=CUDA_TOOLKIT_PATH=\"/usr/local/cuda-11.4\"\nbuild:rbe_linux_cuda11.4_nvcc_base --action_env=LD_LIBRARY_PATH=\"/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib\"\nbuild:rbe_linux_cuda11.4_nvcc_base --action_env=GCC_HOST_COMPILER_PATH=\"/dt9/usr/bin/gcc\"\nbuild:rbe_linux_cuda11.4_nvcc_base --host_crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.4_nvcc_base --crosstool_top=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain\"\nbuild:rbe_linux_cuda11.4_nvcc_base --extra_toolchains=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda//crosstool:toolchain-linux-x86_64\"\nbuild:rbe_linux_cuda11.4_nvcc_base --extra_execution_platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.4_nvcc_base 
--host_platform=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.4_nvcc_base --platforms=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_platform//:platform\"\nbuild:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_cuda\"\nbuild:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_tensorrt\"\nbuild:rbe_linux_cuda11.4_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_nccl\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.7 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.7\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.7 --python_path=\"/usr/local/bin/python3.7\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.8 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.8\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.8 --python_path=\"/usr/local/bin/python3.8\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.9 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.9\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.9 --python_path=\"/usr/local/bin/python3.9\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.10 --config=rbe_linux_cuda11.4_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO=\"@ubuntu20.04-gcc9_manylinux2014-cuda11.4-cudnn8.2-tensorrt7.2_config_python3.10\"\nbuild:rbe_linux_cuda11.4_nvcc_py3.10 --python_path=\"/usr/local/bin/python3.10\"\n\n# These you may need to change for your own GCP project.\nbuild:tensorflow_testing_rbe --project_id=tensorflow-testing\ncommon:tensorflow_testing_rbe_linux 
--remote_instance_name=projects/tensorflow-testing/instances/default_instance\nbuild:tensorflow_testing_rbe_linux --config=tensorflow_testing_rbe\n#############################################################################\n\n# Load `.jax_configure.bazelrc` file written by build.py\ntry-import %workspace%/.jax_configure.bazelrc\n\n# Load rc file with user-specific options.\ntry-import %workspace%/.bazelrc.user\n"
  },
  {
    "path": "build_jaxlib/.bazelversion",
    "content": "5.1.1\n"
  },
  {
    "path": "build_jaxlib/WORKSPACE",
    "content": "load(\"@bazel_tools//tools/build_defs/repo:http.bzl\", \"http_archive\")\n\n# To update TensorFlow to a new revision,\n# a) update URL and strip_prefix to the new git commit hash\n# b) get the sha256 hash of the commit by running:\n#    curl -L https://github.com/tensorflow/tensorflow/archive/<git hash>.tar.gz | sha256sum\n#    and update the sha256 with the result.\nhttp_archive(\n    name = \"org_tensorflow\",\n    sha256 = \"9a7a7a87356bdeef5874fae135de380466482b593469035be3609a9cd2c153c4\",\n    strip_prefix = \"tensorflow-cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057\",\n    urls = [\n        \"https://github.com/tensorflow/tensorflow/archive/cb946f223b9b3fa04efdbb7a0e6a9dabb22a7057.tar.gz\",\n    ],\n)\n\n# For development, one often wants to make changes to the TF repository as well\n# as the JAX repository. You can override the pinned repository above with a\n# local checkout by either:\n# a) overriding the TF repository on the build.py command line by passing a flag\n#    like:\n#    python build/build.py --bazel_options=--override_repository=org_tensorflow=/path/to/tensorflow\n#    or\n# b) by commenting out the http_archive above and uncommenting the following:\n# local_repository(\n#    name = \"org_tensorflow\",\n#    path = \"/path/to/tensorflow\",\n# )\n\nload(\"//third_party/ducc:workspace.bzl\", ducc = \"repo\")\nducc()\n\n# Initialize TensorFlow's external dependencies.\nload(\"@org_tensorflow//tensorflow:workspace3.bzl\", \"tf_workspace3\")\ntf_workspace3()\n\nload(\"@org_tensorflow//tensorflow:workspace2.bzl\", \"tf_workspace2\")\ntf_workspace2()\n\nload(\"@org_tensorflow//tensorflow:workspace1.bzl\", \"tf_workspace1\")\ntf_workspace1()\n\nload(\"@org_tensorflow//tensorflow:workspace0.bzl\", \"tf_workspace0\")\ntf_workspace0()\n"
  },
  {
    "path": "build_jaxlib/build/BUILD.bazel",
    "content": "# Copyright 2018 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# JAX is Autograd and XLA\n\nload(\"@bazel_skylib//rules:common_settings.bzl\", \"bool_flag\")\nload(\"@local_config_cuda//cuda:build_defs.bzl\", \"if_cuda\")\nload(\"@local_config_rocm//rocm:build_defs.bzl\", \"if_rocm\")\nload(\"//jaxlib:jax.bzl\", \"if_windows\")\n\nlicenses([\"notice\"])  # Apache 2\n\npackage(default_visibility = [\"//visibility:public\"])\n\nbool_flag(\n    name = \"enable_remote_tpu\",\n    build_setting_default = False,\n)\n\nconfig_setting(\n    name = \"remote_tpu_enabled\",\n    flag_values = {\n        \":enable_remote_tpu\": \"True\",\n    },\n)\n\npy_binary(\n    name = \"build_wheel\",\n    srcs = [\"build_wheel.py\"],\n    data = [\n        \"LICENSE.txt\",\n        \"//jaxlib\",\n        \"//jaxlib:README.md\",\n        \"//jaxlib:setup.py\",\n        \"//jaxlib:setup.cfg\",\n        \"@org_tensorflow//tensorflow/compiler/xla/python:xla_client\",\n    ] + if_windows([\n        \"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll\",\n    ]) + select({\n        \":remote_tpu_enabled\": [\"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client\"],\n        \"//conditions:default\": [],\n    }) + if_cuda([\n        \"//jaxlib/cuda:cuda_gpu_support\",\n        \"@local_config_cuda//cuda:cuda-nvvm\",\n    ]) + if_rocm([\n        \"//jaxlib/rocm:rocm_gpu_support\",\n    ]),\n    deps = 
[\"@bazel_tools//tools/python/runfiles\"],\n)\n"
  },
  {
    "path": "build_jaxlib/build/LICENSE.txt",
    "content": "--------------------------------------------------------------------------------\nLicense for JAX:\n\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the 
Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n\n--------------------------------------------------------------------------------\nLicense for BoringSSL:\nBoringSSL is a fork of OpenSSL. As such, large parts of it fall under OpenSSL\nlicensing. Files that are completely new have a Google copyright and an ISC\nlicense. This license is reproduced at the bottom of this file.\n\nContributors to BoringSSL are required to follow the CLA rules for Chromium:\nhttps://cla.developers.google.com/clas\n\nFiles in third_party/ have their own licenses, as described therein. The MIT\nlicense, for third_party/fiat, which, unlike other third_party directories, is\ncompiled into non-test libraries, is included below.\n\nThe OpenSSL toolkit stays under a dual license, i.e. both the conditions of the\nOpenSSL License and the original SSLeay license apply to the toolkit. See below\nfor the actual license texts. Actually both licenses are BSD-style Open Source\nlicenses. 
In case of any license issues related to OpenSSL please contact\nopenssl-core@openssl.org.\n\nThe following are Google-internal bug numbers where explicit permission from\nsome authors is recorded for use of their work. (This is purely for our own\nrecord keeping.)\n  27287199\n  27287880\n  27287883\n\n  OpenSSL License\n  ---------------\n\n/* ====================================================================\n * Copyright (c) 1998-2011 The OpenSSL Project.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n *\n * 1. Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer. \n *\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in\n *    the documentation and/or other materials provided with the\n *    distribution.\n *\n * 3. All advertising materials mentioning features or use of this\n *    software must display the following acknowledgment:\n *    \"This product includes software developed by the OpenSSL Project\n *    for use in the OpenSSL Toolkit. (http://www.openssl.org/)\"\n *\n * 4. The names \"OpenSSL Toolkit\" and \"OpenSSL Project\" must not be used to\n *    endorse or promote products derived from this software without\n *    prior written permission. For written permission, please contact\n *    openssl-core@openssl.org.\n *\n * 5. Products derived from this software may not be called \"OpenSSL\"\n *    nor may \"OpenSSL\" appear in their names without prior written\n *    permission of the OpenSSL Project.\n *\n * 6. 
Redistributions of any form whatsoever must retain the following\n *    acknowledgment:\n *    \"This product includes software developed by the OpenSSL Project\n *    for use in the OpenSSL Toolkit (http://www.openssl.org/)\"\n *\n * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY\n * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\n * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE OpenSSL PROJECT OR\n * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\n * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT\n * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\n * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)\n * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,\n * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED\n * OF THE POSSIBILITY OF SUCH DAMAGE.\n * ====================================================================\n *\n * This product includes cryptographic software written by Eric Young\n * (eay@cryptsoft.com).  This product includes software written by Tim\n * Hudson (tjh@cryptsoft.com).\n *\n */\n\n Original SSLeay License\n -----------------------\n\n/* Copyright (C) 1995-1998 Eric Young (eay@cryptsoft.com)\n * All rights reserved.\n *\n * This package is an SSL implementation written\n * by Eric Young (eay@cryptsoft.com).\n * The implementation was written so as to conform with Netscapes SSL.\n * \n * This library is free for commercial and non-commercial use as long as\n * the following conditions are aheared to.  The following conditions\n * apply to all code found in this distribution, be it the RC4, RSA,\n * lhash, DES, etc., code; not just the SSL code.  
The SSL documentation\n * included with this distribution is covered by the same copyright terms\n * except that the holder is Tim Hudson (tjh@cryptsoft.com).\n * \n * Copyright remains Eric Young's, and as such any Copyright notices in\n * the code are not to be removed.\n * If this package is used in a product, Eric Young should be given attribution\n * as the author of the parts of the library used.\n * This can be in the form of a textual message at program startup or\n * in documentation (online or textual) provided with the package.\n * \n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n * 1. Redistributions of source code must retain the copyright\n *    notice, this list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n * 3. All advertising materials mentioning features or use of this software\n *    must display the following acknowledgement:\n *    \"This product includes cryptographic software written by\n *     Eric Young (eay@cryptsoft.com)\"\n *    The word 'cryptographic' can be left out if the rouines from the library\n *    being used are not cryptographic related :-).\n * 4. If you include any Windows specific code (or a derivative thereof) from \n *    the apps directory (application code) you must include an acknowledgement:\n *    \"This product includes software written by Tim Hudson (tjh@cryptsoft.com)\"\n * \n * THIS SOFTWARE IS PROVIDED BY ERIC YOUNG ``AS IS'' AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED.  
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS\n * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)\n * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT\n * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY\n * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGE.\n * \n * The licence and distribution terms for any publically available version or\n * derivative of this code cannot be changed.  i.e. this code cannot simply be\n * copied and put under another distribution licence\n * [including the GNU Public Licence.]\n */\n\n\nISC license used for completely new code in BoringSSL:\n\n/* Copyright (c) 2015, Google Inc.\n *\n * Permission to use, copy, modify, and/or distribute this software for any\n * purpose with or without fee is hereby granted, provided that the above\n * copyright notice and this permission notice appear in all copies.\n *\n * THE SOFTWARE IS PROVIDED \"AS IS\" AND THE AUTHOR DISCLAIMS ALL WARRANTIES\n * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF\n * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY\n * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES\n * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION\n * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN\n * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
*/\n\n\nThe code in third_party/fiat carries the MIT license:\n\nCopyright (c) 2015-2016 the fiat-crypto authors (see\nhttps://github.com/mit-plv/fiat-crypto/blob/main/AUTHORS).\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n\nLicenses for support code\n-------------------------\n\nParts of the TLS test suite are under the Go license. This code is not included\nin BoringSSL (i.e. libcrypto and libssl) when compiled, however, so\ndistributing code linked against BoringSSL does not trigger this license:\n\nCopyright (c) 2009 The Go Authors. 
All rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n   * Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n   * Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following disclaimer\nin the documentation and/or other materials provided with the\ndistribution.\n   * Neither the name of Google Inc. nor the names of its\ncontributors may be used to endorse or promote products derived from\nthis software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n\nBoringSSL uses the Chromium test infrastructure to run a continuous build,\ntrybots etc. The scripts which manage this, and the script for generating build\nmetadata, are under the Chromium license. Distributing code linked against\nBoringSSL does not trigger this license.\n\nCopyright 2015 The Chromium Authors. 
All rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n   * Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n   * Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following disclaimer\nin the documentation and/or other materials provided with the\ndistribution.\n   * Neither the name of Google Inc. nor the names of its\ncontributors may be used to endorse or promote products derived from\nthis software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n--------------------------------------------------------------------------------\nLicense for gRPC:\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. 
Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for Abseil:\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        https://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       https://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n\n--------------------------------------------------------------------------------\nLicense for Protocol buffers:\nCopyright 2008, Google Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n    * Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n    * Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following disclaimer\nin the documentation and/or other materials provided with the\ndistribution.\n    * Neither the name of Google Inc. 
nor the names of its\ncontributors may be used to endorse or promote products derived from\nthis software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nCode generated by the Protocol Buffer compiler is owned by the owner\nof the input file used when generating it.  This code is not\nstandalone and requires a support library to be linked with it.  This\nsupport library is itself covered by the above license.\n\n--------------------------------------------------------------------------------\nLicense for RE2:\n// Copyright (c) 2009 The RE2 Authors. All rights reserved.\n//\n// Redistribution and use in source and binary forms, with or without\n// modification, are permitted provided that the following conditions are\n// met:\n//\n//    * Redistributions of source code must retain the above copyright\n// notice, this list of conditions and the following disclaimer.\n//    * Redistributions in binary form must reproduce the above\n// copyright notice, this list of conditions and the following disclaimer\n// in the documentation and/or other materials provided with the\n// distribution.\n//    * Neither the name of Google Inc. 
nor the names of its\n// contributors may be used to endorse or promote products derived from\n// this software without specific prior written permission.\n//\n// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n// \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\n// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\n// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\n// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\n// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\n// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\n// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\n// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n--------------------------------------------------------------------------------\nLicense for DLPack:\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2017 by Contributors\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for double-conversion:\nCopyright 2006-2011, the V8 project authors. All rights reserved.\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n    * Redistributions of source code must retain the above copyright\n      notice, this list of conditions and the following disclaimer.\n    * Redistributions in binary form must reproduce the above\n      copyright notice, this list of conditions and the following\n      disclaimer in the documentation and/or other materials provided\n      with the distribution.\n    * Neither the name of Google Inc. 
nor the names of its\n      contributors may be used to endorse or promote products derived\n      from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n--------------------------------------------------------------------------------\nLicense for Eigen:\nEigen 3.3.90\nThe corresponding source for this library is available at\nhttps://eigen.googlesource.com/mirror/\n\nEigen is primarily MPL2 licensed. 
See COPYING.MPL2 and these links:\n  http://www.mozilla.org/MPL/2.0/\n  http://www.mozilla.org/MPL/2.0/FAQ.html\n\nSome files contain third-party code under BSD, whence\nthe other COPYING.* files here.\n\nIf you want to guarantee that the Eigen code that you are #including\nis licensed under the MPL2 and possibly more permissive licenses (like\nBSD), #define this preprocessor symbol: EIGEN_MPL2_ONLY \nFor example, with most compilers, you could add this to your project\n      CXXFLAGS: -DEIGEN_MPL2_ONLY \nThis will cause a compilation error to be generated if you #include\nany code that is covered by more restrictive licences than MPL2.\n\n----------------------------------------------------------------------\nFollowing applies to:\n./test/sparseqr.cpp\n./test/half_float.cpp\n./test/zerosized.cpp\n./test/nesting_ops.cpp\n./test/sizeoverflow.cpp\n./test/swap.cpp\n./test/product_mmtr.cpp\n./test/stdvector_overload.cpp\n./test/product_symm.cpp\n./test/sparse_block.cpp\n./test/eigen2support.cpp\n./test/upperbidiagonalization.cpp\n./test/numext.cpp\n./test/adjoint.cpp\n./test/AnnoyingScalar.h\n./test/mpl2only.cpp\n./test/stddeque.cpp\n./test/householder.cpp\n./test/product_small.cpp\n./test/product_syrk.cpp\n./test/inplace_decomposition.cpp\n./test/vectorwiseop.cpp\n./test/meta.cpp\n./test/stdvector.cpp\n./test/sparseLM.cpp\n./test/diagonalmatrices.cpp\n./test/stdlist_overload.cpp\n./test/block.cpp\n./test/cholmod_support.cpp\n./test/basicstuff.cpp\n./test/triangular.cpp\n./test/product.h\n./test/vectorization_logic.cpp\n./test/dontalign.cpp\n./test/first_aligned.cpp\n./test/mapped_matrix.cpp\n./test/umfpack_support.cpp\n./test/product_selfadjoint.cpp\n./test/smallvectors.cpp\n./test/corners.cpp\n./test/product_trsolve.cpp\n./test/determinant.cpp\n./test/stdlist.cpp\n./test/unalignedcount.cpp\n./test/qr.cpp\n./test/svd_common.h\n./test/ref.cpp\n./test/symbolic_index.cpp\n./test/geo_transformations.cpp\n./test/geo_eulerangles.cpp\n./test/eigensolver_selfadjoint.cpp\n./tes
t/stddeque_overload.cpp\n./test/jacobisvd.cpp\n./test/nullary.cpp\n./test/inverse.cpp\n./test/integer_types.cpp\n./test/metis_support.cpp\n./test/exceptions.cpp\n./test/packetmath.cpp\n./test/schur_complex.cpp\n./test/type_alias.cpp\n./test/unalignedassert.cpp\n./test/geo_quaternion.cpp\n./test/lu.cpp\n./test/qr_fullpivoting.cpp\n./test/denseLM.cpp\n./test/linearstructure.cpp\n./test/rand.cpp\n./test/conservative_resize.cpp\n./test/eigensolver_generalized_real.cpp\n./test/pastix_support.cpp\n./test/sparse_solver.h\n./test/num_dimensions.cpp\n./test/simplicial_cholesky.cpp\n./test/hessenberg.cpp\n./test/array_reverse.cpp\n./test/special_numbers.cpp\n./test/array_for_matrix.cpp\n./test/product_large.cpp\n./test/resize.cpp\n./test/sparse_solvers.cpp\n./test/selfadjoint.cpp\n./test/schur_real.cpp\n./test/sparse_basic.cpp\n./test/conjugate_gradient.cpp\n./test/real_qz.cpp\n./test/bandmatrix.cpp\n./test/dense_storage.cpp\n./test/permutationmatrices.cpp\n./test/array_cwise.cpp\n./test/qr_colpivoting.cpp\n./test/array_replicate.cpp\n./test/rvalue_types.cpp\n./test/stable_norm.cpp\n./test/geo_homogeneous.cpp\n./test/main.h\n./test/eigensolver_complex.cpp\n./test/product_trmm.cpp\n./test/bicgstab.cpp\n./test/redux.cpp\n./test/klu_support.cpp\n./test/geo_alignedbox.cpp\n./test/is_same_dense.cpp\n./test/sparse_permutations.cpp\n./test/sparse_vector.cpp\n./test/diagonal.cpp\n./test/sparse.h\n./test/mapstride.cpp\n./test/visitor.cpp\n./test/geo_hyperplane.cpp\n./test/bdcsvd.cpp\n./test/product_trmv.cpp\n./test/nestbyvalue.cpp\n./test/array_of_string.cpp\n./test/superlu_support.cpp\n./test/sizeof.cpp\n./test/boostmultiprec.cpp\n./test/commainitializer.cpp\n./test/constructor.cpp\n./test/mixingtypes.cpp\n./test/miscmatrices.cpp\n./test/mapstaticmethods.cpp\n./test/product_notemporary.cpp\n./test/initializer_list_construction.cpp\n./test/incomplete_cholesky.cpp\n./test/geo_parametrizedline.cpp\n./test/indexed_view.cpp\n./test/qtvector.cpp\n./test/sparselu.cpp\n./test/sparse_product.
cpp\n./test/dynalloc.cpp\n./test/fastmath.cpp\n./test/prec_inverse_4x4.cpp\n./test/umeyama.cpp\n./test/reshape.cpp\n./test/product_extra.cpp\n./test/jacobi.cpp\n./test/sparse_ref.cpp\n./test/nomalloc.cpp\n./test/spqr_support.cpp\n./test/lscg.cpp\n./test/cholesky.cpp\n./test/eigensolver_generic.cpp\n./test/geo_orthomethods.cpp\n./test/svd_fill.h\n./test/stl_iterators.cpp\n./Eigen/src/MetisSupport/MetisSupport.h\n./Eigen/src/CholmodSupport/CholmodSupport.h\n./Eigen/src/QR/CompleteOrthogonalDecomposition.h\n./Eigen/src/QR/FullPivHouseholderQR.h\n./Eigen/src/QR/HouseholderQR.h\n./Eigen/src/QR/ColPivHouseholderQR.h\n./Eigen/src/plugins/CommonCwiseUnaryOps.h\n./Eigen/src/plugins/BlockMethods.h\n./Eigen/src/plugins/CommonCwiseBinaryOps.h\n./Eigen/src/plugins/MatrixCwiseUnaryOps.h\n./Eigen/src/plugins/IndexedViewMethods.h\n./Eigen/src/plugins/MatrixCwiseBinaryOps.h\n./Eigen/src/SVD/UpperBidiagonalization.h\n./Eigen/src/SVD/SVDBase.h\n./Eigen/src/SVD/BDCSVD.h\n./Eigen/src/SVD/JacobiSVD.h\n./Eigen/src/SparseLU/SparseLU_relax_snode.h\n./Eigen/src/SparseLU/SparseLU_column_dfs.h\n./Eigen/src/SparseLU/SparseLU_SupernodalMatrix.h\n./Eigen/src/SparseLU/SparseLU_pivotL.h\n./Eigen/src/SparseLU/SparseLU.h\n./Eigen/src/SparseLU/SparseLU_pruneL.h\n./Eigen/src/SparseLU/SparseLU_copy_to_ucol.h\n./Eigen/src/SparseLU/SparseLU_heap_relax_snode.h\n./Eigen/src/SparseLU/SparseLU_kernel_bmod.h\n./Eigen/src/SparseLU/SparseLU_panel_dfs.h\n./Eigen/src/SparseLU/SparseLU_panel_bmod.h\n./Eigen/src/SparseLU/SparseLU_Structs.h\n./Eigen/src/SparseLU/SparseLUImpl.h\n./Eigen/src/SparseLU/SparseLU_Memory.h\n./Eigen/src/SparseLU/SparseLU_column_bmod.h\n./Eigen/src/SparseLU/SparseLU_gemm_kernel.h\n./Eigen/src/SparseLU/SparseLU_Utils.h\n./Eigen/src/OrderingMethods/Eigen_Colamd.h\n./Eigen/src/OrderingMethods/Ordering.h\n./Eigen/src/OrderingMethods/Amd.h\n./Eigen/src/UmfPackSupport/UmfPackSupport.h\n./Eigen/src/Geometry/Umeyama.h\n./Eigen/src/Geometry/Transform.h\n./Eigen/src/Geometry/OrthoMethods.h\n./Eigen/src
/Geometry/Hyperplane.h\n./Eigen/src/Geometry/Homogeneous.h\n./Eigen/src/Geometry/RotationBase.h\n./Eigen/src/Geometry/EulerAngles.h\n./Eigen/src/Geometry/Translation.h\n./Eigen/src/Geometry/Rotation2D.h\n./Eigen/src/Geometry/Scaling.h\n./Eigen/src/Geometry/AlignedBox.h\n./Eigen/src/Geometry/ParametrizedLine.h\n./Eigen/src/Geometry/Quaternion.h\n./Eigen/src/Geometry/AngleAxis.h\n./Eigen/src/Geometry/arch/Geometry_SSE.h\n./Eigen/src/KLUSupport/KLUSupport.h\n./Eigen/src/misc/Kernel.h\n./Eigen/src/misc/RealSvd2x2.h\n./Eigen/src/misc/Image.h\n./Eigen/src/StlSupport/details.h\n./Eigen/src/StlSupport/StdList.h\n./Eigen/src/StlSupport/StdDeque.h\n./Eigen/src/StlSupport/StdVector.h\n./Eigen/src/SparseQR/SparseQR.h\n./Eigen/src/SuperLUSupport/SuperLUSupport.h\n./Eigen/src/Householder/Householder.h\n./Eigen/src/Householder/HouseholderSequence.h\n./Eigen/src/Householder/BlockHouseholder.h\n./Eigen/src/Eigenvalues/SelfAdjointEigenSolver.h\n./Eigen/src/Eigenvalues/EigenSolver.h\n./Eigen/src/Eigenvalues/GeneralizedEigenSolver.h\n./Eigen/src/Eigenvalues/Tridiagonalization.h\n./Eigen/src/Eigenvalues/HessenbergDecomposition.h\n./Eigen/src/Eigenvalues/RealQZ.h\n./Eigen/src/Eigenvalues/RealSchur.h\n./Eigen/src/Eigenvalues/ComplexSchur.h\n./Eigen/src/Eigenvalues/ComplexEigenSolver.h\n./Eigen/src/Eigenvalues/MatrixBaseEigenvalues.h\n./Eigen/src/Eigenvalues/GeneralizedSelfAdjointEigenSolver.h\n./Eigen/src/SparseCholesky/SimplicialCholesky.h\n./Eigen/src/SparseCholesky/SimplicialCholesky_impl.h\n./Eigen/src/Cholesky/LLT.h\n./Eigen/src/Cholesky/LDLT.h\n./Eigen/src/Jacobi/Jacobi.h\n./Eigen/src/PaStiXSupport/PaStiXSupport.h\n./Eigen/src/SPQRSupport/SuiteSparseQRSupport.h\n./Eigen/src/LU/Determinant.h\n./Eigen/src/LU/InverseImpl.h\n./Eigen/src/LU/PartialPivLU.h\n./Eigen/src/LU/arch/Inverse_SSE.h\n./Eigen/src/LU/FullPivLU.h\n./Eigen/src/Core/Map.h\n./Eigen/src/Core/VectorwiseOp.h\n./Eigen/src/Core/VectorBlock.h\n./Eigen/src/Core/Array.h\n./Eigen/src/Core/Assign.h\n./Eigen/src/Core/Dot.h\n./Eige
n/src/Core/NestByValue.h\n./Eigen/src/Core/CoreEvaluators.h\n./Eigen/src/Core/ReturnByValue.h\n./Eigen/src/Core/SelfCwiseBinaryOp.h\n./Eigen/src/Core/GlobalFunctions.h\n./Eigen/src/Core/Transpositions.h\n./Eigen/src/Core/Fuzzy.h\n./Eigen/src/Core/NoAlias.h\n./Eigen/src/Core/CwiseNullaryOp.h\n./Eigen/src/Core/NumTraits.h\n./Eigen/src/Core/IndexedView.h\n./Eigen/src/Core/ArrayWrapper.h\n./Eigen/src/Core/util/SymbolicIndex.h\n./Eigen/src/Core/util/BlasUtil.h\n./Eigen/src/Core/util/Constants.h\n./Eigen/src/Core/util/IntegralConstant.h\n./Eigen/src/Core/util/ReshapedHelper.h\n./Eigen/src/Core/util/StaticAssert.h\n./Eigen/src/Core/util/IndexedViewHelper.h\n./Eigen/src/Core/util/ConfigureVectorization.h\n./Eigen/src/Core/util/ForwardDeclarations.h\n./Eigen/src/Core/util/Meta.h\n./Eigen/src/Core/util/XprHelper.h\n./Eigen/src/Core/util/Macros.h\n./Eigen/src/Core/util/Memory.h\n./Eigen/src/Core/Product.h\n./Eigen/src/Core/Replicate.h\n./Eigen/src/Core/ArrayBase.h\n./Eigen/src/Core/functors/NullaryFunctors.h\n./Eigen/src/Core/functors/StlFunctors.h\n./Eigen/src/Core/functors/AssignmentFunctors.h\n./Eigen/src/Core/functors/UnaryFunctors.h\n./Eigen/src/Core/functors/TernaryFunctors.h\n./Eigen/src/Core/functors/BinaryFunctors.h\n./Eigen/src/Core/Redux.h\n./Eigen/src/Core/EigenBase.h\n./Eigen/src/Core/SolverBase.h\n./Eigen/src/Core/ProductEvaluators.h\n./Eigen/src/Core/Block.h\n./Eigen/src/Core/SolveTriangular.h\n./Eigen/src/Core/ArithmeticSequence.h\n./Eigen/src/Core/MatrixBase.h\n./Eigen/src/Core/PlainObjectBase.h\n./Eigen/src/Core/Transpose.h\n./Eigen/src/Core/IO.h\n./Eigen/src/Core/MathFunctions.h\n./Eigen/src/Core/Stride.h\n./Eigen/src/Core/MathFunctionsImpl.h\n./Eigen/src/Core/StableNorm.h\n./Eigen/src/Core/DiagonalProduct.h\n./Eigen/src/Core/products/GeneralMatrixMatrix.h\n./Eigen/src/Core/products/GeneralMatrixVector.h\n./Eigen/src/Core/products/SelfadjointMatrixVector.h\n./Eigen/src/Core/products/GeneralBlockPanelKernel.h\n./Eigen/src/Core/products/TriangularSolverMatrix.
h\n./Eigen/src/Core/products/SelfadjointMatrixMatrix.h\n./Eigen/src/Core/products/Parallelizer.h\n./Eigen/src/Core/products/SelfadjointRank2Update.h\n./Eigen/src/Core/products/TriangularMatrixMatrix.h\n./Eigen/src/Core/products/TriangularMatrixVector.h\n./Eigen/src/Core/products/SelfadjointProduct.h\n./Eigen/src/Core/products/GeneralMatrixMatrixTriangular.h\n./Eigen/src/Core/products/TriangularSolverVector.h\n./Eigen/src/Core/CwiseUnaryView.h\n./Eigen/src/Core/CommaInitializer.h\n./Eigen/src/Core/DenseStorage.h\n./Eigen/src/Core/DenseBase.h\n./Eigen/src/Core/PartialReduxEvaluator.h\n./Eigen/src/Core/CoreIterators.h\n./Eigen/src/Core/PermutationMatrix.h\n./Eigen/src/Core/CwiseTernaryOp.h\n./Eigen/src/Core/Reverse.h\n./Eigen/src/Core/Reshaped.h\n./Eigen/src/Core/Inverse.h\n./Eigen/src/Core/TriangularMatrix.h\n./Eigen/src/Core/BooleanRedux.h\n./Eigen/src/Core/ForceAlignedAccess.h\n./Eigen/src/Core/Ref.h\n./Eigen/src/Core/StlIterators.h\n./Eigen/src/Core/BandMatrix.h\n./Eigen/src/Core/ConditionEstimator.h\n./Eigen/src/Core/Diagonal.h\n./Eigen/src/Core/DiagonalMatrix.h\n./Eigen/src/Core/AssignEvaluator.h\n./Eigen/src/Core/CwiseBinaryOp.h\n./Eigen/src/Core/Visitor.h\n./Eigen/src/Core/GenericPacketMath.h\n./Eigen/src/Core/SelfAdjointView.h\n./Eigen/src/Core/Random.h\n./Eigen/src/Core/Solve.h\n./Eigen/src/Core/arch/AltiVec/MathFunctions.h\n./Eigen/src/Core/arch/AltiVec/PacketMath.h\n./Eigen/src/Core/arch/AltiVec/Complex.h\n./Eigen/src/Core/arch/MSA/MathFunctions.h\n./Eigen/src/Core/arch/MSA/Complex.h\n./Eigen/src/Core/arch/MSA/PacketMath.h\n./Eigen/src/Core/arch/GPU/Half.h\n./Eigen/src/Core/arch/GPU/PacketMathHalf.h\n./Eigen/src/Core/arch/GPU/MathFunctions.h\n./Eigen/src/Core/arch/GPU/PacketMath.h\n./Eigen/src/Core/arch/GPU/TypeCasting.h\n./Eigen/src/Core/arch/NEON/MathFunctions.h\n./Eigen/src/Core/arch/NEON/Complex.h\n./Eigen/src/Core/arch/NEON/PacketMath.h\n./Eigen/src/Core/arch/NEON/TypeCasting.h\n./Eigen/src/Core/arch/AVX/MathFunctions.h\n./Eigen/src/Core/arch/AVX/TypeC
asting.h\n./Eigen/src/Core/arch/AVX/Complex.h\n./Eigen/src/Core/arch/AVX/PacketMath.h\n./Eigen/src/Core/arch/SYCL/InteropHeaders.h\n./Eigen/src/Core/arch/SYCL/PacketMath.h\n./Eigen/src/Core/arch/SYCL/TypeCasting.h\n./Eigen/src/Core/arch/SYCL/MathFunctions.h\n./Eigen/src/Core/arch/Default/GenericPacketMathFunctions.h\n./Eigen/src/Core/arch/Default/ConjHelper.h\n./Eigen/src/Core/arch/Default/Settings.h\n./Eigen/src/Core/arch/AVX512/MathFunctions.h\n./Eigen/src/Core/arch/AVX512/PacketMath.h\n./Eigen/src/Core/arch/AVX512/Complex.h\n./Eigen/src/Core/arch/SSE/PacketMath.h\n./Eigen/src/Core/arch/SSE/Complex.h\n./Eigen/src/Core/arch/SSE/TypeCasting.h\n./Eigen/src/Core/arch/SSE/MathFunctions.h\n./Eigen/src/Core/arch/ZVector/MathFunctions.h\n./Eigen/src/Core/arch/ZVector/PacketMath.h\n./Eigen/src/Core/arch/ZVector/Complex.h\n./Eigen/src/Core/arch/CUDA/Complex.h\n./Eigen/src/Core/Swap.h\n./Eigen/src/Core/MapBase.h\n./Eigen/src/Core/GeneralProduct.h\n./Eigen/src/Core/Matrix.h\n./Eigen/src/Core/Select.h\n./Eigen/src/Core/CwiseUnaryOp.h\n./Eigen/src/Core/DenseCoeffsBase.h\n./Eigen/src/SparseCore/SparseCwiseUnaryOp.h\n./Eigen/src/SparseCore/TriangularSolver.h\n./Eigen/src/SparseCore/SparseView.h\n./Eigen/src/SparseCore/SparseSolverBase.h\n./Eigen/src/SparseCore/SparseTranspose.h\n./Eigen/src/SparseCore/SparseDenseProduct.h\n./Eigen/src/SparseCore/SparseMap.h\n./Eigen/src/SparseCore/SparseProduct.h\n./Eigen/src/SparseCore/SparseUtil.h\n./Eigen/src/SparseCore/SparsePermutation.h\n./Eigen/src/SparseCore/SparseTriangularView.h\n./Eigen/src/SparseCore/SparseSelfAdjointView.h\n./Eigen/src/SparseCore/SparseMatrixBase.h\n./Eigen/src/SparseCore/AmbiVector.h\n./Eigen/src/SparseCore/SparseAssign.h\n./Eigen/src/SparseCore/SparseRedux.h\n./Eigen/src/SparseCore/SparseDot.h\n./Eigen/src/SparseCore/SparseCwiseBinaryOp.h\n./Eigen/src/SparseCore/SparseCompressedBase.h\n./Eigen/src/SparseCore/SparseSparseProductWithPruning.h\n./Eigen/src/SparseCore/SparseColEtree.h\n./Eigen/src/SparseCore/SparseRef.
h\n./Eigen/src/SparseCore/CompressedStorage.h\n./Eigen/src/SparseCore/MappedSparseMatrix.h\n./Eigen/src/SparseCore/SparseDiagonalProduct.h\n./Eigen/src/SparseCore/SparseFuzzy.h\n./Eigen/src/SparseCore/ConservativeSparseSparseProduct.h\n./Eigen/src/SparseCore/SparseMatrix.h\n./Eigen/src/SparseCore/SparseVector.h\n./Eigen/src/SparseCore/SparseBlock.h\n./Eigen/src/IterativeLinearSolvers/SolveWithGuess.h\n./Eigen/src/IterativeLinearSolvers/IterativeSolverBase.h\n./Eigen/src/IterativeLinearSolvers/BiCGSTAB.h\n./Eigen/src/IterativeLinearSolvers/ConjugateGradient.h\n./Eigen/src/IterativeLinearSolvers/BasicPreconditioners.h\n./Eigen/src/IterativeLinearSolvers/IncompleteCholesky.h\n./Eigen/src/IterativeLinearSolvers/IncompleteLUT.h\n./Eigen/src/IterativeLinearSolvers/LeastSquareConjugateGradient.h\n./unsupported/Eigen/src/Eigenvalues/ArpackSelfAdjointEigenSolver.h\n./unsupported/Eigen/src/SpecialFunctions/arch/GPU/GpuSpecialFunctions.h\n./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsHalf.h\n./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h\n./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsFunctors.h\n./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsArrayAPI.h\n./unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsPacketMath.h\n./unsupported/Eigen/src/Polynomials/Companion.h\n./unsupported/Eigen/src/Polynomials/PolynomialUtils.h\n./unsupported/Eigen/src/Polynomials/PolynomialSolver.h\n./unsupported/Eigen/src/Splines/Spline.h\n./unsupported/Eigen/src/Splines/SplineFwd.h\n./unsupported/Eigen/src/Splines/SplineFitting.h\n./unsupported/Eigen/src/BVH/KdBVH.h\n./unsupported/Eigen/src/BVH/BVAlgorithms.h\n./unsupported/Eigen/src/AutoDiff/AutoDiffJacobian.h\n./unsupported/Eigen/src/AutoDiff/AutoDiffVector.h\n./unsupported/Eigen/src/AutoDiff/AutoDiffScalar.h\n./unsupported/Eigen/src/MatrixFunctions/MatrixSquareRoot.h\n./unsupported/Eigen/src/MatrixFunctions/MatrixPower.h\n./unsupported/Eigen/src/MatrixFunctions/MatrixExponential.h\n./unsupport
ed/Eigen/src/MatrixFunctions/MatrixLogarithm.h\n./unsupported/Eigen/src/MatrixFunctions/StemFunction.h\n./unsupported/Eigen/src/MatrixFunctions/MatrixFunction.h\n./unsupported/Eigen/src/Skyline/SkylineStorage.h\n./unsupported/Eigen/src/Skyline/SkylineMatrixBase.h\n./unsupported/Eigen/src/Skyline/SkylineMatrix.h\n./unsupported/Eigen/src/Skyline/SkylineInplaceLU.h\n./unsupported/Eigen/src/Skyline/SkylineProduct.h\n./unsupported/Eigen/src/Skyline/SkylineUtil.h\n./unsupported/Eigen/src/FFT/ei_kissfft_impl.h\n./unsupported/Eigen/src/FFT/ei_fftw_impl.h\n./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h\n./unsupported/Eigen/src/NonLinearOptimization/HybridNonLinearSolver.h\n./unsupported/Eigen/src/NonLinearOptimization/LevenbergMarquardt.h\n./unsupported/Eigen/src/KroneckerProduct/KroneckerTensorProduct.h\n./unsupported/Eigen/src/NumericalDiff/NumericalDiff.h\n./unsupported/Eigen/src/IterativeSolvers/IncompleteLU.h\n./unsupported/Eigen/src/IterativeSolvers/MINRES.h\n./unsupported/Eigen/src/IterativeSolvers/DGMRES.h\n./unsupported/Eigen/src/IterativeSolvers/Scaling.h\n./unsupported/Eigen/src/IterativeSolvers/GMRES.h\n./unsupported/Eigen/src/MoreVectorization/MathFunctions.h\n./unsupported/Eigen/src/EulerAngles/EulerAngles.h\n./unsupported/Eigen/src/EulerAngles/EulerSystem.h\n./unsupported/Eigen/src/SparseExtra/BlockOfDynamicSparseMatrix.h\n./unsupported/Eigen/src/SparseExtra/DynamicSparseMatrix.h\n./unsupported/Eigen/src/SparseExtra/BlockSparseMatrix.h\n./unsupported/Eigen/src/SparseExtra/RandomSetter.h\n./unsupported/Eigen/src/SparseExtra/MatrixMarketIterator.h\n./unsupported/Eigen/src/SparseExtra/MarketIO.h\n./unsupported/Eigen/CXX11/src/TensorSymmetry/StaticSymmetry.h\n./unsupported/Eigen/CXX11/src/TensorSymmetry/Symmetry.h\n./unsupported/Eigen/CXX11/src/TensorSymmetry/DynamicSymmetry.h\n./unsupported/Eigen/CXX11/src/TensorSymmetry/util/TemplateGroupTheory.h\n./unsupported/Eigen/CXX11/src/util/EmulateCXX11Meta.h\n./unsupported/Eigen/CXX11/src/util/CXX11Meta
.h\n./unsupported/Eigen/CXX11/src/util/MaxSizeVector.h\n./unsupported/Eigen/CXX11/src/util/EmulateArray.h\n./unsupported/Eigen/CXX11/src/util/CXX11Workarounds.h\n./unsupported/Eigen/CXX11/src/ThreadPool/ThreadYield.h\n./unsupported/Eigen/CXX11/src/ThreadPool/NonBlockingThreadPool.h\n./unsupported/Eigen/CXX11/src/ThreadPool/RunQueue.h\n./unsupported/Eigen/CXX11/src/ThreadPool/ThreadCancel.h\n./unsupported/Eigen/CXX11/src/ThreadPool/ThreadPoolInterface.h\n./unsupported/Eigen/CXX11/src/ThreadPool/ThreadLocal.h\n./unsupported/Eigen/CXX11/src/ThreadPool/Barrier.h\n./unsupported/Eigen/CXX11/src/ThreadPool/EventCount.h\n./unsupported/Eigen/CXX11/src/ThreadPool/ThreadEnvironment.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorRef.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorFixedSize.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclRun.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclTuple.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorTraits.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorStorage.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceThreadPool.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContractionThreadPool.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclPlaceHolderExpr.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExprConstructor.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorIntDiv.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclConvertToDeviceExpression.h\n./unsupported/Eigen/CXX11/src/Tensor/Tensor.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceGpu.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorPatch.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorStriding.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorScan.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h\n./unsuppo
rted/Eigen/CXX11/src/Tensor/TensorCustomOp.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorGenerator.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorReductionSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorArgMaxSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorConvolution.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorBase.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorReduction.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorUInt128.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorArgMax.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorMeta.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorExpr.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorIO.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContraction.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDeviceDefault.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorReverse.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorShuffling.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorConvolutionSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclFunctors.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorMap.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractFunctors.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclExtractAccessor.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorConcatenation.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorGpuHipCudaDefines.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorInitializer.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorIndexList.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorGpuHipCudaUndefines.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContractionMapper.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorCostModel.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorForcedEval.h\n./unsupported/Eigen
/CXX11/src/Tensor/TensorGlobalFunctions.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContractionSycl.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorImagePatch.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContractionBlocking.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorMacros.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDevice.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorSyclLeafCount.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorRandom.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorFFT.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorFunctors.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorContractionGpu.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorDimensionList.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorConversion.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorEvalTo.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorAssign.h\n./unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h\n./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h\n./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h\n./unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h\n./unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h\n./unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h\n./unsupported/bench/bench_svd.cpp\n./unsupported/test/cxx11_tensor_image_patch_sycl.cpp\n./unsupported/test/cxx11_tensor_expr.cpp\n./unsupported/test/FFTW.cpp\n./unsupported/test/cxx11_tensor_reverse_sycl.cpp\n./unsupported/test/cxx11_tensor_comparisons.cpp\n./unsupported/test/cxx11_tensor_intdiv.cpp\n./unsupported/test/autodiff.cpp\n./unsupported/test/cxx11_tensor_executor.cpp\n./unsupported/test/cxx11_tensor_reduction.cpp\n./unsupported/test/cxx11_tensor_device_sycl.cpp\n./unsupported/test/minres.cpp\n./unsupported/test/cxx11_tensor_striding.cpp\n./unsupported/test/cxx11_tensor_chipping.cpp\n./unsupported/test/cxx11_tensor_convolution_sycl.cpp\n./unsupported/test/openglsu
pport.cpp\n./unsupported/test/cxx11_tensor_ifft.cpp\n./unsupported/test/polynomialutils.cpp\n./unsupported/test/cxx11_tensor_block_access.cpp\n./unsupported/test/cxx11_tensor_block_eval.cpp\n./unsupported/test/cxx11_tensor_block_io.cpp\n./unsupported/test/cxx11_tensor_morphing.cpp\n./unsupported/test/cxx11_tensor_casts.cpp\n./unsupported/test/cxx11_tensor_shuffling_sycl.cpp\n./unsupported/test/cxx11_tensor_morphing_sycl.cpp\n./unsupported/test/forward_adolc.cpp\n./unsupported/test/cxx11_tensor_layout_swap.cpp\n./unsupported/test/cxx11_tensor_move.cpp\n./unsupported/test/EulerAngles.cpp\n./unsupported/test/cxx11_tensor_trace.cpp\n./unsupported/test/alignedvector3.cpp\n./unsupported/test/cxx11_tensor_lvalue.cpp\n./unsupported/test/cxx11_tensor_argmax.cpp\n./unsupported/test/cxx11_tensor_broadcast_sycl.cpp\n./unsupported/test/autodiff_scalar.cpp\n./unsupported/test/sparse_extra.cpp\n./unsupported/test/cxx11_tensor_of_strings.cpp\n./unsupported/test/cxx11_tensor_empty.cpp\n./unsupported/test/cxx11_tensor_patch.cpp\n./unsupported/test/cxx11_tensor_sycl.cpp\n./unsupported/test/cxx11_tensor_forced_eval_sycl.cpp\n./unsupported/test/cxx11_tensor_inflation_sycl.cpp\n./unsupported/test/BVH.cpp\n./unsupported/test/cxx11_tensor_generator.cpp\n./unsupported/test/cxx11_meta.cpp\n./unsupported/test/matrix_functions.h\n./unsupported/test/kronecker_product.cpp\n./unsupported/test/matrix_function.cpp\n./unsupported/test/cxx11_tensor_thread_pool.cpp\n./unsupported/test/cxx11_non_blocking_thread_pool.cpp\n./unsupported/test/cxx11_tensor_fft.cpp\n./unsupported/test/cxx11_tensor_assign.cpp\n./unsupported/test/cxx11_tensor_simple.cpp\n./unsupported/test/cxx11_tensor_of_complex.cpp\n./unsupported/test/cxx11_tensor_inflation.cpp\n./unsupported/test/cxx11_tensor_map.cpp\n./unsupported/test/cxx11_tensor_shuffling.cpp\n./unsupported/test/cxx11_tensor_padding.cpp\n./unsupported/test/cxx11_tensor_argmax_sycl.cpp\n./unsupported/test/matrix_square_root.cpp\n./unsupported/test/dgmres.cpp\n./unsuppor
ted/test/cxx11_tensor_custom_op_sycl.cpp\n./unsupported/test/cxx11_tensor_reduction_sycl.cpp\n./unsupported/test/cxx11_runqueue.cpp\n./unsupported/test/cxx11_tensor_const.cpp\n./unsupported/test/matrix_power.cpp\n./unsupported/test/cxx11_tensor_contraction.cpp\n./unsupported/test/cxx11_tensor_random.cpp\n./unsupported/test/cxx11_tensor_volume_patch_sycl.cpp\n./unsupported/test/cxx11_tensor_contract_sycl.cpp\n./unsupported/test/cxx11_tensor_math.cpp\n./unsupported/test/splines.cpp\n./unsupported/test/cxx11_tensor_ref.cpp\n./unsupported/test/cxx11_tensor_concatenation_sycl.cpp\n./unsupported/test/gmres.cpp\n./unsupported/test/cxx11_tensor_fixed_size.cpp\n./unsupported/test/cxx11_tensor_custom_op.cpp\n./unsupported/test/cxx11_tensor_generator_sycl.cpp\n./unsupported/test/cxx11_tensor_uint128.cpp\n./unsupported/test/cxx11_tensor_builtins_sycl.cpp\n./unsupported/test/polynomialsolver.cpp\n./unsupported/test/cxx11_tensor_concatenation.cpp\n./unsupported/test/cxx11_tensor_broadcasting.cpp\n./unsupported/test/cxx11_tensor_convolution.cpp\n./unsupported/test/cxx11_tensor_forced_eval.cpp\n./unsupported/test/levenberg_marquardt.cpp\n./unsupported/test/cxx11_tensor_reverse.cpp\n./unsupported/test/cxx11_tensor_notification.cpp\n./unsupported/test/cxx11_tensor_patch_sycl.cpp\n./unsupported/test/cxx11_tensor_image_patch.cpp\n./unsupported/test/cxx11_tensor_scan.cpp\n./unsupported/test/cxx11_tensor_padding_sycl.cpp\n./unsupported/test/cxx11_tensor_index_list.cpp\n./unsupported/test/cxx11_tensor_io.cpp\n./unsupported/test/cxx11_tensor_mixed_indices.cpp\n./unsupported/test/cxx11_tensor_striding_sycl.cpp\n./unsupported/test/cxx11_tensor_of_const_values.cpp\n./unsupported/test/cxx11_tensor_symmetry.cpp\n./unsupported/test/cxx11_tensor_custom_index.cpp\n./unsupported/test/cxx11_tensor_chipping_sycl.cpp\n./unsupported/test/cxx11_tensor_roundings.cpp\n./unsupported/test/matrix_exponential.cpp\n./unsupported/test/cxx11_eventcount.cpp\n./unsupported/test/special_functions.cpp\n./unsupported
/test/cxx11_tensor_dimension.cpp\n./unsupported/test/cxx11_tensor_layout_swap_sycl.cpp\n./lapack/eigenvalues.cpp\n./lapack/single.cpp\n./lapack/svd.cpp\n./lapack/complex_single.cpp\n./lapack/lu.cpp\n./lapack/double.cpp\n./lapack/complex_double.cpp\n./lapack/cholesky.cpp\n./lapack/lapack_common.h\n./blas/level2_impl.h\n./blas/PackedTriangularMatrixVector.h\n./blas/level3_impl.h\n./blas/complex_double.cpp\n./blas/common.h\n./blas/GeneralRank1Update.h\n./blas/double.cpp\n./blas/complex_single.cpp\n./blas/Rank2Update.h\n./blas/level1_impl.h\n./blas/level2_real_impl.h\n./blas/level1_real_impl.h\n./blas/single.cpp\n./blas/PackedSelfadjointProduct.h\n./blas/BandTriangularSolver.h\n./blas/level2_cplx_impl.h\n./blas/PackedTriangularSolverVector.h\n./blas/level1_cplx_impl.h\n./bench/analyze-blocking-sizes.cpp\n./bench/BenchTimer.h\n./bench/spbench/spbenchsolver.h\n./bench/spbench/spbenchstyle.h\n./bench/benchFFT.cpp\n./bench/eig33.cpp\n./bench/benchmark-blocking-sizes.cpp\n./demos/opengl/quaternion_demo.cpp\n./demos/opengl/camera.h\n./demos/opengl/gpuhelper.cpp\n./demos/opengl/gpuhelper.h\n./demos/opengl/icosphere.cpp\n./demos/opengl/quaternion_demo.h\n./demos/opengl/trackball.h\n./demos/opengl/icosphere.h\n./demos/opengl/camera.cpp\n./demos/opengl/trackball.cpp\n./demos/mix_eigen_and_c/binary_library.h\n./demos/mix_eigen_and_c/binary_library.cpp\n./demos/mandelbrot/mandelbrot.cpp\n./demos/mandelbrot/mandelbrot.h\n\nMozilla Public License Version 2.0\n==================================\n\n1. Definitions\n--------------\n\n1.1. \"Contributor\"\n    means each individual or legal entity that creates, contributes to\n    the creation of, or owns Covered Software.\n\n1.2. \"Contributor Version\"\n    means the combination of the Contributions of others (if any) used\n    by a Contributor and that particular Contributor's Contribution.\n\n1.3. \"Contribution\"\n    means Covered Software of a particular Contributor.\n\n1.4. 
\"Covered Software\"\n    means Source Code Form to which the initial Contributor has attached\n    the notice in Exhibit A, the Executable Form of such Source Code\n    Form, and Modifications of such Source Code Form, in each case\n    including portions thereof.\n\n1.5. \"Incompatible With Secondary Licenses\"\n    means\n\n    (a) that the initial Contributor has attached the notice described\n        in Exhibit B to the Covered Software; or\n\n    (b) that the Covered Software was made available under the terms of\n        version 1.1 or earlier of the License, but not also under the\n        terms of a Secondary License.\n\n1.6. \"Executable Form\"\n    means any form of the work other than Source Code Form.\n\n1.7. \"Larger Work\"\n    means a work that combines Covered Software with other material, in \n    a separate file or files, that is not Covered Software.\n\n1.8. \"License\"\n    means this document.\n\n1.9. \"Licensable\"\n    means having the right to grant, to the maximum extent possible,\n    whether at the time of the initial grant or subsequently, any and\n    all of the rights conveyed by this License.\n\n1.10. \"Modifications\"\n    means any of the following:\n\n    (a) any file in Source Code Form that results from an addition to,\n        deletion from, or modification of the contents of Covered\n        Software; or\n\n    (b) any new file in Source Code Form that contains any Covered\n        Software.\n\n1.11. \"Patent Claims\" of a Contributor\n    means any patent claim(s), including without limitation, method,\n    process, and apparatus claims, in any patent Licensable by such\n    Contributor that would be infringed, but for the grant of the\n    License, by the making, using, selling, offering for sale, having\n    made, import, or transfer of either its Contributions or its\n    Contributor Version.\n\n1.12. 
\"Secondary License\"\n    means either the GNU General Public License, Version 2.0, the GNU\n    Lesser General Public License, Version 2.1, the GNU Affero General\n    Public License, Version 3.0, or any later versions of those\n    licenses.\n\n1.13. \"Source Code Form\"\n    means the form of the work preferred for making modifications.\n\n1.14. \"You\" (or \"Your\")\n    means an individual or a legal entity exercising rights under this\n    License. For legal entities, \"You\" includes any entity that\n    controls, is controlled by, or is under common control with You. For\n    purposes of this definition, \"control\" means (a) the power, direct\n    or indirect, to cause the direction or management of such entity,\n    whether by contract or otherwise, or (b) ownership of more than\n    fifty percent (50%) of the outstanding shares or beneficial\n    ownership of such entity.\n\n2. License Grants and Conditions\n--------------------------------\n\n2.1. Grants\n\nEach Contributor hereby grants You a world-wide, royalty-free,\nnon-exclusive license:\n\n(a) under intellectual property rights (other than patent or trademark)\n    Licensable by such Contributor to use, reproduce, make available,\n    modify, display, perform, distribute, and otherwise exploit its\n    Contributions, either on an unmodified basis, with Modifications, or\n    as part of a Larger Work; and\n\n(b) under Patent Claims of such Contributor to make, use, sell, offer\n    for sale, have made, import, and otherwise transfer either its\n    Contributions or its Contributor Version.\n\n2.2. Effective Date\n\nThe licenses granted in Section 2.1 with respect to any Contribution\nbecome effective for each Contribution on the date the Contributor first\ndistributes such Contribution.\n\n2.3. Limitations on Grant Scope\n\nThe licenses granted in this Section 2 are the only rights granted under\nthis License. 
No additional rights or licenses will be implied from the\ndistribution or licensing of Covered Software under this License.\nNotwithstanding Section 2.1(b) above, no patent license is granted by a\nContributor:\n\n(a) for any code that a Contributor has removed from Covered Software;\n    or\n\n(b) for infringements caused by: (i) Your and any other third party's\n    modifications of Covered Software, or (ii) the combination of its\n    Contributions with other software (except as part of its Contributor\n    Version); or\n\n(c) under Patent Claims infringed by Covered Software in the absence of\n    its Contributions.\n\nThis License does not grant any rights in the trademarks, service marks,\nor logos of any Contributor (except as may be necessary to comply with\nthe notice requirements in Section 3.4).\n\n2.4. Subsequent Licenses\n\nNo Contributor makes additional grants as a result of Your choice to\ndistribute the Covered Software under a subsequent version of this\nLicense (see Section 10.2) or under the terms of a Secondary License (if\npermitted under the terms of Section 3.3).\n\n2.5. Representation\n\nEach Contributor represents that the Contributor believes its\nContributions are its original creation(s) or it has sufficient rights\nto grant the rights to its Contributions conveyed by this License.\n\n2.6. Fair Use\n\nThis License is not intended to limit any rights You have under\napplicable copyright doctrines of fair use, fair dealing, or other\nequivalents.\n\n2.7. Conditions\n\nSections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted\nin Section 2.1.\n\n3. Responsibilities\n-------------------\n\n3.1. Distribution of Source Form\n\nAll distribution of Covered Software in Source Code Form, including any\nModifications that You create or to which You contribute, must be under\nthe terms of this License. 
You must inform recipients that the Source\nCode Form of the Covered Software is governed by the terms of this\nLicense, and how they can obtain a copy of this License. You may not\nattempt to alter or restrict the recipients' rights in the Source Code\nForm.\n\n3.2. Distribution of Executable Form\n\nIf You distribute Covered Software in Executable Form then:\n\n(a) such Covered Software must also be made available in Source Code\n    Form, as described in Section 3.1, and You must inform recipients of\n    the Executable Form how they can obtain a copy of such Source Code\n    Form by reasonable means in a timely manner, at a charge no more\n    than the cost of distribution to the recipient; and\n\n(b) You may distribute such Executable Form under the terms of this\n    License, or sublicense it under different terms, provided that the\n    license for the Executable Form does not attempt to limit or alter\n    the recipients' rights in the Source Code Form under this License.\n\n3.3. Distribution of a Larger Work\n\nYou may create and distribute a Larger Work under terms of Your choice,\nprovided that You also comply with the requirements of this License for\nthe Covered Software. If the Larger Work is a combination of Covered\nSoftware with a work governed by one or more Secondary Licenses, and the\nCovered Software is not Incompatible With Secondary Licenses, this\nLicense permits You to additionally distribute such Covered Software\nunder the terms of such Secondary License(s), so that the recipient of\nthe Larger Work may, at their option, further distribute the Covered\nSoftware under the terms of either this License or such Secondary\nLicense(s).\n\n3.4. 
Notices\n\nYou may not remove or alter the substance of any license notices\n(including copyright notices, patent notices, disclaimers of warranty,\nor limitations of liability) contained within the Source Code Form of\nthe Covered Software, except that You may alter any license notices to\nthe extent required to remedy known factual inaccuracies.\n\n3.5. Application of Additional Terms\n\nYou may choose to offer, and to charge a fee for, warranty, support,\nindemnity or liability obligations to one or more recipients of Covered\nSoftware. However, You may do so only on Your own behalf, and not on\nbehalf of any Contributor. You must make it absolutely clear that any\nsuch warranty, support, indemnity, or liability obligation is offered by\nYou alone, and You hereby agree to indemnify every Contributor for any\nliability incurred by such Contributor as a result of warranty, support,\nindemnity or liability terms You offer. You may include additional\ndisclaimers of warranty and limitations of liability specific to any\njurisdiction.\n\n4. Inability to Comply Due to Statute or Regulation\n---------------------------------------------------\n\nIf it is impossible for You to comply with any of the terms of this\nLicense with respect to some or all of the Covered Software due to\nstatute, judicial order, or regulation then You must: (a) comply with\nthe terms of this License to the maximum extent possible; and (b)\ndescribe the limitations and the code they affect. Such description must\nbe placed in a text file included with all distributions of the Covered\nSoftware under this License. Except to the extent prohibited by statute\nor regulation, such description must be sufficiently detailed for a\nrecipient of ordinary skill to be able to understand it.\n\n5. Termination\n--------------\n\n5.1. The rights granted under this License will terminate automatically\nif You fail to comply with any of its terms. 
However, if You become\ncompliant, then the rights granted under this License from a particular\nContributor are reinstated (a) provisionally, unless and until such\nContributor explicitly and finally terminates Your grants, and (b) on an\nongoing basis, if such Contributor fails to notify You of the\nnon-compliance by some reasonable means prior to 60 days after You have\ncome back into compliance. Moreover, Your grants from a particular\nContributor are reinstated on an ongoing basis if such Contributor\nnotifies You of the non-compliance by some reasonable means, this is the\nfirst time You have received notice of non-compliance with this License\nfrom such Contributor, and You become compliant prior to 30 days after\nYour receipt of the notice.\n\n5.2. If You initiate litigation against any entity by asserting a patent\ninfringement claim (excluding declaratory judgment actions,\ncounter-claims, and cross-claims) alleging that a Contributor Version\ndirectly or indirectly infringes any patent, then the rights granted to\nYou by any and all Contributors for the Covered Software under Section\n2.1 of this License shall terminate.\n\n5.3. In the event of termination under Sections 5.1 or 5.2 above, all\nend user license agreements (excluding distributors and resellers) which\nhave been validly granted by You or Your distributors under this License\nprior to termination shall survive termination.\n\n************************************************************************\n*                                                                      *\n*  6. 
Disclaimer of Warranty                                           *\n*  -------------------------                                           *\n*                                                                      *\n*  Covered Software is provided under this License on an \"as is\"       *\n*  basis, without warranty of any kind, either expressed, implied, or  *\n*  statutory, including, without limitation, warranties that the       *\n*  Covered Software is free of defects, merchantable, fit for a        *\n*  particular purpose or non-infringing. The entire risk as to the     *\n*  quality and performance of the Covered Software is with You.        *\n*  Should any Covered Software prove defective in any respect, You     *\n*  (not any Contributor) assume the cost of any necessary servicing,   *\n*  repair, or correction. This disclaimer of warranty constitutes an   *\n*  essential part of this License. No use of any Covered Software is   *\n*  authorized under this License except under this disclaimer.         *\n*                                                                      *\n************************************************************************\n\n************************************************************************\n*                                                                      *\n*  7. 
Limitation of Liability                                          *\n*  --------------------------                                          *\n*                                                                      *\n*  Under no circumstances and under no legal theory, whether tort      *\n*  (including negligence), contract, or otherwise, shall any           *\n*  Contributor, or anyone who distributes Covered Software as          *\n*  permitted above, be liable to You for any direct, indirect,         *\n*  special, incidental, or consequential damages of any character      *\n*  including, without limitation, damages for lost profits, loss of    *\n*  goodwill, work stoppage, computer failure or malfunction, or any    *\n*  and all other commercial damages or losses, even if such party      *\n*  shall have been informed of the possibility of such damages. This   *\n*  limitation of liability shall not apply to liability for death or   *\n*  personal injury resulting from such party's negligence to the       *\n*  extent applicable law prohibits such limitation. Some               *\n*  jurisdictions do not allow the exclusion or limitation of           *\n*  incidental or consequential damages, so this exclusion and          *\n*  limitation may not apply to You.                                    *\n*                                                                      *\n************************************************************************\n\n8. Litigation\n-------------\n\nAny litigation relating to this License may be brought only in the\ncourts of a jurisdiction where the defendant maintains its principal\nplace of business and such litigation shall be governed by laws of that\njurisdiction, without reference to its conflict-of-law provisions.\nNothing in this Section shall prevent a party's ability to bring\ncross-claims or counter-claims.\n\n9. 
Miscellaneous\n----------------\n\nThis License represents the complete agreement concerning the subject\nmatter hereof. If any provision of this License is held to be\nunenforceable, such provision shall be reformed only to the extent\nnecessary to make it enforceable. Any law or regulation which provides\nthat the language of a contract shall be construed against the drafter\nshall not be used to construe this License against a Contributor.\n\n10. Versions of the License\n---------------------------\n\n10.1. New Versions\n\nMozilla Foundation is the license steward. Except as provided in Section\n10.3, no one other than the license steward has the right to modify or\npublish new versions of this License. Each version will be given a\ndistinguishing version number.\n\n10.2. Effect of New Versions\n\nYou may distribute the Covered Software under the terms of the version\nof the License under which You originally received the Covered Software,\nor under the terms of any subsequent version published by the license\nsteward.\n\n10.3. Modified Versions\n\nIf you create software not governed by this License, and you want to\ncreate a new license for such software, you may create and use a\nmodified version of this License if you rename the license and remove\nany references to the name of the license steward (except to note that\nsuch modified license differs from this License).\n\n10.4. Distributing Source Code Form that is Incompatible With Secondary\nLicenses\n\nIf You choose to distribute Source Code Form that is Incompatible With\nSecondary Licenses under the terms of this version of the License, the\nnotice described in Exhibit B of this License must be attached.\n\nExhibit A - Source Code Form License Notice\n-------------------------------------------\n\n  This Source Code Form is subject to the terms of the Mozilla Public\n  License, v. 2.0. 
If a copy of the MPL was not distributed with this\n  file, You can obtain one at http://mozilla.org/MPL/2.0/.\n\nIf it is not possible or desirable to put the notice in a particular\nfile, then You may include the notice in a location (such as a LICENSE\nfile in a relevant directory) where a recipient would be likely to look\nfor such a notice.\n\nYou may add additional accurate notices of copyright ownership.\n\nExhibit B - \"Incompatible With Secondary Licenses\" Notice\n---------------------------------------------------------\n\n  This Source Code Form is \"Incompatible With Secondary Licenses\", as\n  defined by the Mozilla Public License, v. 2.0.\n\n----------------------------------------------------------------------\nFollowing applies to:\n./doc/UsingIntelMKL.dox\n./doc/UsingIntelMKL.dox\n./Eigen/src/Eigenvalues/ComplexSchur_MKL.h\n./Eigen/src/Eigenvalues/ComplexSchur_MKL.h\n./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h\n./Eigen/src/Eigenvalues/SelfAdjointEigenSolver_MKL.h\n./Eigen/src/Eigenvalues/RealSchur_MKL.h\n./Eigen/src/Eigenvalues/RealSchur_MKL.h\n./Eigen/src/LU/arch/Inverse_SSE.h\n./Eigen/src/LU/arch/Inverse_SSE.h\n./Eigen/src/LU/PartialPivLU_MKL.h\n./Eigen/src/LU/PartialPivLU_MKL.h\n./Eigen/src/QR/HouseholderQR_MKL.h\n./Eigen/src/QR/HouseholderQR_MKL.h\n./Eigen/src/QR/ColPivHouseholderQR_MKL.h\n./Eigen/src/QR/ColPivHouseholderQR_MKL.h\n./Eigen/src/SVD/JacobiSVD_MKL.h\n./Eigen/src/SVD/JacobiSVD_MKL.h\n./Eigen/src/PardisoSupport/PardisoSupport.h\n./Eigen/src/PardisoSupport/PardisoSupport.h\n./Eigen/src/Core/Assign_MKL.h\n./Eigen/src/Core/Assign_MKL.h\n./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h\n./Eigen/src/Core/products/SelfadjointMatrixVector_MKL.h\n./Eigen/src/Core/products/GeneralMatrixVector_MKL.h\n./Eigen/src/Core/products/GeneralMatrixVector_MKL.h\n./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h\n./Eigen/src/Core/products/SelfadjointMatrixMatrix_MKL.h\n./Eigen/src/Core/products/TriangularMatrixMatrix_MKL.h\n./Eigen/
src/Core/products/TriangularMatrixMatrix_MKL.h\n./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h\n./Eigen/src/Core/products/GeneralMatrixMatrix_MKL.h\n./Eigen/src/Core/products/TriangularMatrixVector_MKL.h\n./Eigen/src/Core/products/TriangularMatrixVector_MKL.h\n./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h\n./Eigen/src/Core/products/GeneralMatrixMatrixTriangular_MKL.h\n./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h\n./Eigen/src/Core/products/TriangularSolverMatrix_MKL.h\n./Eigen/src/Core/util/MKL_support.h\n./Eigen/src/Core/util/MKL_support.h\n./Eigen/src/Cholesky/LLT_MKL.h\n./Eigen/src/Cholesky/LLT_MKL.h\n\n/*\n Copyright (c) 2011, Intel Corporation. All rights reserved.\n\n Redistribution and use in source and binary forms, with or without\n modification, are permitted provided that the following conditions\n are met:\n\n * Redistributions of source code must retain the above copyright\n   notice, this list of conditions and the following disclaimer.\n * Redistributions in binary form must reproduce the above copyright\n   notice, this list of conditions and the following disclaimer in the\n   documentation and/or other materials provided with the\n   distribution.\n * Neither the name of Intel Corporation nor the\n   names of its contributors may be used to endorse or promote\n   products derived from this software without specific prior written\n   permission.\n\n THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n \"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\n LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\n A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\n OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\n SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\n LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\n DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\n THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\n OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n */\n\n----------------------------------------------------------------------\nFollowing applies to:\n./unsupported/Eigen/src/LevenbergMarquardt/LevenbergMarquardt.h\n./unsupported/Eigen/src/LevenbergMarquardt/LMcovar.h\n./unsupported/Eigen/src/LevenbergMarquardt/LMonestep.h\n./unsupported/Eigen/src/LevenbergMarquardt/LMpar.h\n./unsupported/Eigen/src/LevenbergMarquardt/LMqrsolv.h\n\nMinpack Copyright Notice (1999) University of Chicago.  All rights\nreserved\n\nRedistribution and use in source and binary forms, with or\nwithout modification, are permitted provided that the\nfollowing conditions are met:\n\n1. Redistributions of source code must retain the above\ncopyright notice, this list of conditions and the following\ndisclaimer.\n\n2. Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following\ndisclaimer in the documentation and/or other materials\nprovided with the distribution.\n\n3. The end-user documentation included with the\nredistribution, if any, must include the following\nacknowledgment:\n\n   \"This product includes software developed by the\n   University of Chicago, as Operator of Argonne National\n   Laboratory.\"\n\nAlternately, this acknowledgment may appear in the software\nitself, if and wherever such third-party acknowledgments\nnormally appear.\n\n4. WARRANTY DISCLAIMER. THE SOFTWARE IS SUPPLIED \"AS IS\"\nWITHOUT WARRANTY OF ANY KIND. 
THE COPYRIGHT HOLDER, THE\nUNITED STATES, THE UNITED STATES DEPARTMENT OF ENERGY, AND\nTHEIR EMPLOYEES: (1) DISCLAIM ANY WARRANTIES, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO ANY IMPLIED WARRANTIES\nOF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE\nOR NON-INFRINGEMENT, (2) DO NOT ASSUME ANY LEGAL LIABILITY\nOR RESPONSIBILITY FOR THE ACCURACY, COMPLETENESS, OR\nUSEFULNESS OF THE SOFTWARE, (3) DO NOT REPRESENT THAT USE OF\nTHE SOFTWARE WOULD NOT INFRINGE PRIVATELY OWNED RIGHTS, (4)\nDO NOT WARRANT THAT THE SOFTWARE WILL FUNCTION\nUNINTERRUPTED, THAT IT IS ERROR-FREE OR THAT ANY ERRORS WILL\nBE CORRECTED.\n\n5. LIMITATION OF LIABILITY. IN NO EVENT WILL THE COPYRIGHT\nHOLDER, THE UNITED STATES, THE UNITED STATES DEPARTMENT OF\nENERGY, OR THEIR EMPLOYEES: BE LIABLE FOR ANY INDIRECT,\nINCIDENTAL, CONSEQUENTIAL, SPECIAL OR PUNITIVE DAMAGES OF\nANY KIND OR NATURE, INCLUDING BUT NOT LIMITED TO LOSS OF\nPROFITS OR LOSS OF DATA, FOR ANY REASON WHATSOEVER, WHETHER\nSUCH LIABILITY IS ASSERTED ON THE BASIS OF CONTRACT, TORT\n(INCLUDING NEGLIGENCE OR STRICT LIABILITY), OR OTHERWISE,\nEVEN IF ANY OF SAID PARTIES HAS BEEN WARNED OF THE\nPOSSIBILITY OF SUCH LOSS OR DAMAGES.\n\n\nCopyright (c) 1992-2013 The University of Tennessee and The University\n                        of Tennessee Research Foundation.  All rights\n                        reserved.\nCopyright (c) 2000-2013 The University of California Berkeley. All\n                        rights reserved.\nCopyright (c) 2006-2013 The University of Colorado Denver.  
All rights\n                        reserved.\n\nFollowing applies to:\n./lapack/*.c\n\n$COPYRIGHT$\n\nAdditional copyrights may follow\n\n$HEADER$\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n- Redistributions of source code must retain the above copyright\n  notice, this list of conditions and the following disclaimer.\n\n- Redistributions in binary form must reproduce the above copyright\n  notice, this list of conditions and the following disclaimer listed\n  in this license in the documentation and/or other materials\n  provided with the distribution.\n\n- Neither the name of the copyright holders nor the names of its\n  contributors may be used to endorse or promote products derived from\n  this software without specific prior written permission.\n\nThe copyright holders provide no reassurances that the source code\nprovided does not infringe any patent, copyright, or any other\nintellectual property rights of third parties.  The copyright holders\ndisclaim any liability to any recipient for claims brought against\nrecipient by any third party for infringement of that parties\nintellectual property rights.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n----------------------------------------------------------------------\nFollowing applies to:\n\n./cmake/FindComputeCpp.cmake\n\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for Farmhash:\n// Copyright (c) 2014 Google, Inc.\n//\n// Permission is hereby granted, free of charge, to any person obtaining a copy\n// of this software and associated documentation files (the \"Software\"), to deal\n// in the Software without restriction, including without limitation the rights\n// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n// copies of the Software, and to permit persons to whom the Software is\n// furnished to do so, subject to the following conditions:\n//\n// The above copyright notice and this permission notice shall be included in\n// all copies or substantial portions of the Software.\n//\n// THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n// THE SOFTWARE.\n\n--------------------------------------------------------------------------------\nLicense for Flatbuffers:\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2014 Google Inc.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for highwayhash:\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n\n--------------------------------------------------------------------------------\nLicense for libjpeg-turbo:\nFor a summary of these license terms, see LICENSE.md.\n\nlibjpeg-turbo license\n---------------------\n    This license covers the TurboJPEG API library and associated programs.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n- Redistributions of source code must retain the above copyright notice,\n  this list of conditions and the following disclaimer.\n- Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n- Neither the name of the libjpeg-turbo Project nor the names of its\n  contributors may be used to endorse or promote products derived from this\n  software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS 
AND CONTRIBUTORS \"AS IS\",\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\nARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE\nLIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\nSUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\nINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\nCONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\nARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGE.\n\n\nlibjpeg license, Independent JPEG Group\n---------------------------------------\n    This license applies to the libjpeg API library and associated programs\n    (any code inherited from libjpeg, and any modifications to that code.)\n\nThe authors make NO WARRANTY or representation, either express or implied,\nwith respect to this software, its quality, accuracy, merchantability, or\nfitness for a particular purpose.  This software is provided \"AS IS\", and you,\nits user, assume the entire risk as to its quality and accuracy.\n\nThis software is copyright (C) 1991-2016, Thomas G. 
Lane, Guido Vollbeding.\nAll Rights Reserved except as specified below.\n\nPermission is hereby granted to use, copy, modify, and distribute this\nsoftware (or portions thereof) for any purpose, without fee, subject to these\nconditions:\n(1) If any part of the source code for this software is distributed, then this\nREADME file must be included, with this copyright and no-warranty notice\nunaltered; and any additions, deletions, or changes to the original files\nmust be clearly indicated in accompanying documentation.\n(2) If only executable code is distributed, then the accompanying\ndocumentation must state that \"this software is based in part on the work of\nthe Independent JPEG Group\".\n(3) Permission for use of this software is granted only if the user accepts\nfull responsibility for any undesirable consequences; the authors accept\nNO LIABILITY for damages of any kind.\n\nThese conditions apply to any software derived from or based on the IJG code,\nnot just to the unmodified library.  If you use our work, you ought to\nacknowledge us.\n\nPermission is NOT granted for the use of any IJG author's name or company name\nin advertising or publicity relating to this software or products derived from\nit.  This software may be referred to only as \"the Independent JPEG Group's\nsoftware\".\n\nWe specifically permit and encourage the use of this software as the basis of\ncommercial products, provided that all warranty or liability claims are\nassumed by the product vendor.\n\n\nThe Unix configuration script \"configure\" was produced with GNU Autoconf.\nIt is copyright by the Free Software Foundation but is freely distributable.\nThe same holds for its supporting scripts (config.guess, config.sub,\nltmain.sh).  
Another support script, install-sh, is copyright by X Consortium\nbut is also freely distributable.\n\nThe IJG distribution formerly included code to read and write GIF files.\nTo avoid entanglement with the Unisys LZW patent (now expired), GIF reading\nsupport has been removed altogether, and the GIF writer has been simplified\nto produce \"uncompressed GIFs\".  This technique does not use the LZW\nalgorithm; the resulting GIF files are larger than usual, but are readable\nby all standard GIF decoders.\n\nWe are required to state that\n    \"The Graphics Interchange Format(c) is the Copyright property of\n    CompuServe Incorporated.  GIF(sm) is a Service Mark property of\n    CompuServe Incorporated.\"\n\n\nzlib License\n------------\n    This license is a subset of the other two, and it covers the libjpeg-turbo\n    SIMD extensions.\n\nThis software is provided 'as-is', without any express or implied\nwarranty.  In no event will the authors be held liable for any damages\narising from the use of this software.\n\nPermission is granted to anyone to use this software for any purpose,\nincluding commercial applications, and to alter it and redistribute it\nfreely, subject to the following restrictions:\n\n1. The origin of this software must not be misrepresented; you must not\n   claim that you wrote the original software. If you use this software\n   in a product, an acknowledgment in the product documentation would be\n   appreciated but is not required.\n2. Altered source versions must be plainly marked as such, and must not be\n   misrepresented as being the original software.\n3. This notice may not be removed or altered from any source distribution.\n\n--------------------------------------------------------------------------------\nLicense for fft2d:\nCopyright(C) 1997,2001 Takuya OOURA (email: ooura@kurims.kyoto-u.ac.jp).\nYou may use, copy, modify this code for any purpose and \nwithout fee. 
You may distribute this ORIGINAL package.\n\n--------------------------------------------------------------------------------\nLicense for giflib:\nThe GIFLIB distribution is Copyright (c) 1997  Eric S. Raymond\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n\n--------------------------------------------------------------------------------\nLicense for llvm-project:\nCopied from llvm-project/llvm/LICENSE.TXT:\n==============================================================================\nThe LLVM Project is under the Apache License v2.0 with LLVM Exceptions:\n==============================================================================\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n    1. 
Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n    2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n    3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n    4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n    5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n    6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n    7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n    8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n    9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n    END OF TERMS AND CONDITIONS\n\n    APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n    Copyright [yyyy] [name of copyright owner]\n\n    Licensed under the Apache License, Version 2.0 (the \"License\");\n    you may not use this file except in compliance with the License.\n    You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n    Unless required by applicable law or agreed to in writing, software\n    distributed under the License is distributed on an \"AS IS\" BASIS,\n    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    See the License for the specific language governing permissions and\n    limitations under the License.\n\n\n---- LLVM Exceptions to the Apache 2.0 License ----\n\nAs an exception, if, as a result of your compiling your source code, portions\nof this Software are embedded into an Object form of such source code, you\nmay redistribute such embedded portions in such Object form without complying\nwith the conditions of Sections 4(a), 4(b) and 4(d) of the License.\n\nIn addition, if you combine or link compiled forms of this Software with\nsoftware that is licensed under the GPLv2 (\"Combined Software\") and if a\ncourt of competent jurisdiction determines that the patent provision (Section\n3), the indemnity provision (Section 9) or other Section of the License\nconflicts with the conditions of the GPLv2, you may retroactively and\nprospectively choose to deem waived or otherwise exclude such Section(s) of\nthe License, but only in their entirety and only with respect to the Combined\nSoftware.\n\n==============================================================================\nSoftware from third parties included in the 
LLVM Project:\n==============================================================================\nThe LLVM Project contains third party software which is under different license\nterms. All such code will be identified clearly using at least one of two\nmechanisms:\n1) It will be in a separate directory tree with its own `LICENSE.txt` or\n   `LICENSE` file at the top containing the specific license and restrictions\n   which apply to that software, or\n2) It will contain specific license and restriction terms at the top of every\n   file.\n\n==============================================================================\nLegacy LLVM License (https://llvm.org/docs/DeveloperPolicy.html#legacy):\n==============================================================================\nUniversity of Illinois/NCSA\nOpen Source License\n\nCopyright (c) 2003-2019 University of Illinois at Urbana-Champaign.\nAll rights reserved.\n\nDeveloped by:\n\n    LLVM Team\n\n    University of Illinois at Urbana-Champaign\n\n    http://llvm.org\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the \"Software\"), to deal with\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies\nof the Software, and to permit persons to whom the Software is furnished to do\nso, subject to the following conditions:\n\n    * Redistributions of source code must retain the above copyright notice,\n      this list of conditions and the following disclaimers.\n\n    * Redistributions in binary form must reproduce the above copyright notice,\n      this list of conditions and the following disclaimers in the\n      documentation and/or other materials provided with the distribution.\n\n    * Neither the names of the LLVM Team, University of Illinois at\n      Urbana-Champaign, nor the names of its contributors may be used to\n      
endorse or promote products derived from this Software without specific\n      prior written permission.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS\nFOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE\nCONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE\nSOFTWARE.\n\n==============================================================================\n==============================================================================                                                                                \nCopied from llvm-project/llvm/utils/unittest/googletest/LICENSE.TXT and\nllvm-project/llvm/utils/unittest/googlemock/LICENSE.txt:\n\nCopyright 2008, Google Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n    * Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n    * Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following disclaimer\nin the documentation and/or other materials provided with the\ndistribution.\n    * Neither the name of Google Inc. nor the names of its\ncontributors may be used to endorse or promote products derived from\nthis software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n==============================================================================\n==============================================================================                                                                                \nCopied from llvm-project/llvm/lib/Support/COPYRIGHT.regex:\n$OpenBSD: COPYRIGHT,v 1.3 2003/06/02 20:18:36 millert Exp $\n\nCopyright 1992, 1993, 1994 Henry Spencer.  All rights reserved.\nThis software is not subject to any license of the American Telephone\nand Telegraph Company or of the Regents of the University of California.\n\nPermission is granted to anyone to use this software for any purpose on\nany computer system, and to alter it and redistribute it, subject\nto the following restrictions:\n\n1. The author is not responsible for the consequences of use of this\n   software, no matter how awful, even if they arise from flaws in it.\n\n2. The origin of this software must not be misrepresented, either by\n   explicit claim or by omission.  Since few users ever read sources,\n   credits must appear in the documentation.\n\n3. Altered versions must be plainly marked as such, and must not be\n   misrepresented as being the original software.  Since few users\n   ever read sources, credits must appear in the documentation.\n\n4. 
This notice may not be removed or altered.\n\n=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=\n/*-\n * Copyright (c) 1994\n *\tThe Regents of the University of California.  All rights reserved.\n *\n * Redistribution and use in source and binary forms, with or without\n * modification, are permitted provided that the following conditions\n * are met:\n * 1. Redistributions of source code must retain the above copyright\n *    notice, this list of conditions and the following disclaimer.\n * 2. Redistributions in binary form must reproduce the above copyright\n *    notice, this list of conditions and the following disclaimer in the\n *    documentation and/or other materials provided with the distribution.\n * 3. Neither the name of the University nor the names of its contributors\n *    may be used to endorse or promote products derived from this software\n *    without specific prior written permission.\n *\n * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND\n * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n * ARE DISCLAIMED.  
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE\n * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\n * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS\n * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)\n * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT\n * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY\n * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF\n * SUCH DAMAGE.\n *\n *\t@(#)COPYRIGHT\t8.1 (Berkeley) 3/16/94\n */\n\n==============================================================================\n==============================================================================                                                                                \nCopied from llvm-project/llvm/include/llvm/Support/LICENSE.TXT:\n\nLLVM System Interface Library\n-------------------------------------------------------------------------------\nThe LLVM System Interface Library is licensed under the Illinois Open Source\nLicense and has the following additional copyright:\n\nCopyright (C) 2004 eXtensible Systems, Inc.\n\n==============================================================================\n==============================================================================                                                                                \nCopied from llvm-project/llvm/test/YAMLParser/LICENSE.txt:\n\nCopyright (c) 2006 Kirill Simonov\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of\nthis software and associated documentation files (the \"Software\"), to deal in\nthe Software without restriction, including without limitation the rights to\nuse, copy, modify, merge, publish, distribute, sublicense, and/or sell copies\nof the Software, and to permit persons to whom the Software is furnished to do\nso, subject to the following conditions:\n\nThe above copyright notice 
and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n--------------------------------------------------------------------------------\nLicense for mkl_dnn:\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"{}\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright {yyyy} {name of copyright owner}\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n------------------------------------------------------------------------\n    \nThe below applies to src/cpu/xbyak/*.\n   \n\nCopyright (c) 2007 MITSUNARI Shigeo\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\nRedistributions of source code must retain the above copyright notice, this\nlist of conditions and the following disclaimer.\nRedistributions in binary form must reproduce the above copyright notice,\nthis list of conditions and the following disclaimer in the documentation\nand/or other materials provided with the distribution.\nNeither the name of the copyright owner nor the names of its contributors may\nbe used to endorse or promote products derived from this software without\nspecific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF 
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\nARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE\nLIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\nSUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\nINTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\nCONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\nARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF\nTHE POSSIBILITY OF SUCH DAMAGE.\n\nソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た\nす場合に限り、再頒布および使用が許可されます。\n\nソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項\nを含めること。\nバイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作\n権表示、本条件一覧、および下記免責条項を含めること。\n書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進\nに、著作権者の名前またはコントリビューターの名前を使用してはならない。\n本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ\nれており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性\nに関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。\n著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを\n問わず、かつ責任の根拠が契約であるか厳格責任であるか（過失その他の）不法行為で\nあるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、\n本ソフトウェアの使用によって発生した（代替品または代用サービスの調達、使用の\n喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない）直接\n損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、\n一切責任を負わないものとします。\n\n--------------------------------------------------------------------------------\nLicense for nsync:\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. 
Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for TensorFlow:\nCopyright 2019 The TensorFlow Authors.  All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. 
For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. 
For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n\n--------------------------------------------------------------------------------\nLicense for the FFT components of ducc0:\nCopyright (C) 2010-2022 Max-Planck-Society\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without modification,\nare permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n* Redistributions in binary form must reproduce the above copyright notice, this\n  list of conditions and the following disclaimer in the documentation and/or\n  other materials provided with the distribution.\n* Neither the name of the copyright holder nor the names of its contributors may\n  be used to endorse or promote products derived from this software without\n  specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 
IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR\nANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES\n(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;\nLOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON\nANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS\nSOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n--------------------------------------------------------------------------------\nLicense for pybind11:\nCopyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>, All rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n1. Redistributions of source code must retain the above copyright notice, this\n   list of conditions and the following disclaimer.\n\n2. Redistributions in binary form must reproduce the above copyright notice,\n   this list of conditions and the following disclaimer in the documentation\n   and/or other materials provided with the distribution.\n\n3. Neither the name of the copyright holder nor the names of its contributors\n   may be used to endorse or promote products derived from this software\n   without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" AND\nANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED\nWARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\nPlease also refer to the file .github/CONTRIBUTING.md, which clarifies licensing of\nexternal contributions to this project including patches, pull requests, etc.\n\n--------------------------------------------------------------------------------\nLicense for snappy:\nCopyright 2011, Google Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are\nmet:\n\n    * Redistributions of source code must retain the above copyright\nnotice, this list of conditions and the following disclaimer.\n    * Redistributions in binary form must reproduce the above\ncopyright notice, this list of conditions and the following disclaimer\nin the documentation and/or other materials provided with the\ndistribution.\n    * Neither the name of Google Inc. nor the names of its\ncontributors may be used to endorse or promote products derived from\nthis software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS\n\"AS IS\" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT\nLIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR\nA PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT\nOWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT\nLIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,\nDATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY\nTHEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT\n(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n\n===\n\nSome of the benchmark data in util/zippy/testdata is licensed differently:\n\n - fireworks.jpeg is Copyright 2013 Steinar H. Gunderson, and\n   is licensed under the Creative Commons Attribution 3.0 license\n   (CC-BY-3.0). See https://creativecommons.org/licenses/by/3.0/\n   for more information.\n\n - kppkn.gtb is taken from the Gaviota chess tablebase set, and\n   is licensed under the MIT License. See\n   https://sites.google.com/site/gaviotachessengine/Home/endgame-tablebases-1\n   for more information.\n\n - paper-100k.pdf is an excerpt (bytes 92160 to 194560) from the paper\n   “Combinatorial Modeling of Chromatin Features Quantitatively Predicts DNA\n   Replication Timing in _Drosophila_” by Federico Comoglio and Renato Paro,\n   which is licensed under the CC-BY license. See\n   http://www.ploscompbiol.org/static/license for more information.\n\n - alice29.txt, asyoulik.txt, plrabn12.txt and lcet10.txt are from Project\n   Gutenberg. 
The first three have expired copyrights and are in the public\n   domain; the latter does not have expired copyright, but is still in the\n   public domain according to the license information\n   (http://www.gutenberg.org/ebooks/53).\n\n--------------------------------------------------------------------------------\nLicense for upb:\n\nCopyright (c) 2009-2011, Google Inc.\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n    * Redistributions of source code must retain the above copyright\n      notice, this list of conditions and the following disclaimer.\n    * Redistributions in binary form must reproduce the above copyright\n      notice, this list of conditions and the following disclaimer in the\n      documentation and/or other materials provided with the distribution.\n    * Neither the name of Google Inc. nor the names of any other\n      contributors may be used to endorse or promote products\n      derived from this software without specific prior written permission.\n\nTHIS SOFTWARE IS PROVIDED BY GOOGLE INC. ``AS IS'' AND ANY EXPRESS OR IMPLIED\nWARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF\nMERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO\nEVENT SHALL GOOGLE INC. 
BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,\nSPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,\nPROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR\nBUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER\nIN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\nARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\nPOSSIBILITY OF SUCH DAMAGE.\n\n--------------------------------------------------------------------------------\nLicense for zlib:\n(extracted from README, except for match.S)\n\nCopyright notice:\n\n (C) 1995-2013 Jean-loup Gailly and Mark Adler\n\n  This software is provided 'as-is', without any express or implied\n  warranty.  In no event will the authors be held liable for any damages\n  arising from the use of this software.\n\n  Permission is granted to anyone to use this software for any purpose,\n  including commercial applications, and to alter it and redistribute it\n  freely, subject to the following restrictions:\n\n  1. The origin of this software must not be misrepresented; you must not\n     claim that you wrote the original software. If you use this software\n     in a product, an acknowledgment in the product documentation would be\n     appreciated but is not required.\n  2. Altered source versions must be plainly marked as such, and must not be\n     misrepresented as being the original software.\n  3. This notice may not be removed or altered from any source distribution.\n\n  Jean-loup Gailly        Mark Adler\n  jloup@gzip.org          madler@alumni.caltech.edu\n\nIf you use the zlib library in a product, we would appreciate *not* receiving\nlengthy legal documents to sign.  The sources are provided for free but without\nwarranty of any kind.  
The library has been entirely written by Jean-loup\nGailly and Mark Adler; it does not include third-party code.\n\nIf you redistribute modified sources, we would appreciate that you include in\nthe file ChangeLog history information documenting your changes.  Please read\nthe FAQ for more information on the distribution of modified source versions.\n\n(extracted from match.S, for match.S only)\n\nCopyright (C) 1998, 2007 Brian Raiter <breadbox@muppetlabs.com>\n\nThis software is provided 'as-is', without any express or implied\nwarranty.  In no event will the author be held liable for any damages\narising from the use of this software.\n\nPermission is granted to anyone to use this software for any purpose,\nincluding commercial applications, and to alter it and redistribute it\nfreely, subject to the following restrictions:\n\n1. The origin of this software must not be misrepresented; you must not\n  claim that you wrote the original software. If you use this software\n  in a product, an acknowledgment in the product documentation would be\n  appreciated but is not required.\n2. Altered source versions must be plainly marked as such, and must not be\n  misrepresented as being the original software.\n3. This notice may not be removed or altered from any source distribution.\n"
  },
  {
    "path": "build_jaxlib/build/build.py",
    "content": "#!/usr/bin/python\n#\n# Copyright 2018 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n#\n# Helper script for building JAX's libjax easily.\n\n\nimport argparse\nimport collections\nimport hashlib\nimport os\nimport platform\nimport re\nimport shutil\nimport stat\nimport subprocess\nimport sys\nimport textwrap\nimport urllib\n\n# pylint: disable=g-import-not-at-top\nif hasattr(urllib, \"urlretrieve\"):\n  urlretrieve = urllib.urlretrieve\nelse:\n  import urllib.request\n  urlretrieve = urllib.request.urlretrieve\n\nif hasattr(shutil, \"which\"):\n  which = shutil.which\nelse:\n  from distutils.spawn import find_executable as which\n# pylint: enable=g-import-not-at-top\n\n\ndef is_windows():\n  return sys.platform.startswith(\"win32\")\n\n\ndef shell(cmd):\n  try:\n    output = subprocess.check_output(cmd)\n  except subprocess.CalledProcessError as e:\n    print(e.output)\n    raise\n  return output.decode(\"UTF-8\").strip()\n\n\n# Python\n\ndef get_python_bin_path(python_bin_path_flag):\n  \"\"\"Returns the path to the Python interpreter to use.\"\"\"\n  path = python_bin_path_flag or sys.executable\n  return path.replace(os.sep, \"/\")\n\n\ndef get_python_version(python_bin_path):\n  version_output = shell(\n    [python_bin_path, \"-c\",\n     (\"import sys; print(\\\"{}.{}\\\".format(sys.version_info[0], \"\n      \"sys.version_info[1]))\")])\n  major, minor = map(int, version_output.split(\".\"))\n  return major, 
minor\n\ndef check_python_version(python_version):\n  if python_version < (3, 7):\n    print(\"ERROR: JAX requires Python 3.7 or newer, found \", python_version)\n    sys.exit(-1)\n\n\ndef check_numpy_version(python_bin_path):\n  version = shell(\n      [python_bin_path, \"-c\", \"import numpy as np; print(np.__version__)\"])\n  numpy_version = tuple(map(int, version.split(\".\")[:2]))\n  if numpy_version < (1, 20):\n    print(\"ERROR: JAX requires NumPy 1.20 or newer, found \" + version + \".\")\n    sys.exit(-1)\n  return version\n\n# Bazel\n\nBAZEL_BASE_URI = \"https://github.com/bazelbuild/bazel/releases/download/5.1.1/\"\nBazelPackage = collections.namedtuple(\"BazelPackage\",\n                                      [\"base_uri\", \"file\", \"sha256\"])\nbazel_packages = {\n    (\"Linux\", \"x86_64\"):\n        BazelPackage(\n            base_uri=None,\n            file=\"bazel-5.1.1-linux-x86_64\",\n            sha256=\n            \"5e126060d9169b462a18e97435356c3b3712d20fdbef9ac7609016838a90e7d3\"),\n    (\"Linux\", \"aarch64\"):\n        BazelPackage(\n            base_uri=None,\n            file=\"bazel-5.1.1-linux-arm64\",\n            sha256=\n            \"a590a28608772e779efc0c29bb678cd2a150deb27a9f8c557cc1d2b131a779ef\"),\n    (\"Darwin\", \"x86_64\"):\n        BazelPackage(\n            base_uri=None,\n            file=\"bazel-5.1.1-darwin-x86_64\",\n            sha256=\n            \"91d8958fffd3077c32466a03300b7eba3b680588688f11d378ccbf2ae9000753\"),\n    (\"Darwin\", \"arm64\"):\n        BazelPackage(\n            base_uri=None,\n            file=\"bazel-5.1.1-darwin-arm64\",\n            sha256=\n            \"4fad9d066436ccca022578192be9fcc330d833799833c549683949939b3ce717\"),\n    (\"Windows\", \"AMD64\"):\n        BazelPackage(\n            base_uri=None,\n            file=\"bazel-5.1.1-windows-x86_64.exe\",\n            sha256=\n            \"03061f1e9aac1966155ca402dcd1075c6493dfe85df72aa2cf3e12fcaa258d90\"),\n}\n\n\ndef 
download_and_verify_bazel():\n  \"\"\"Downloads a bazel binary from Github, verifying its SHA256 hash.\"\"\"\n  package = bazel_packages.get((platform.system(), platform.machine()))\n  if package is None:\n    return None\n\n  if not os.access(package.file, os.X_OK):\n    uri = (package.base_uri or BAZEL_BASE_URI) + package.file\n    sys.stdout.write(f\"Downloading bazel from: {uri}\\n\")\n\n    def progress(block_count, block_size, total_size):\n      if total_size <= 0:\n        total_size = 170**6\n      progress = (block_count * block_size) / total_size\n      num_chars = 40\n      progress_chars = int(num_chars * progress)\n      sys.stdout.write(\"{} [{}{}] {}%\\r\".format(\n          package.file, \"#\" * progress_chars,\n          \".\" * (num_chars - progress_chars), int(progress * 100.0)))\n\n    tmp_path, _ = urlretrieve(uri, None,\n                              progress if sys.stdout.isatty() else None)\n    sys.stdout.write(\"\\n\")\n\n    # Verify that the downloaded Bazel binary has the expected SHA256.\n    with open(tmp_path, \"rb\") as downloaded_file:\n      contents = downloaded_file.read()\n\n    digest = hashlib.sha256(contents).hexdigest()\n    if digest != package.sha256:\n      print(\n          \"Checksum mismatch for downloaded bazel binary (expected {}; got {}).\"\n          .format(package.sha256, digest))\n      sys.exit(-1)\n\n    # Write the file as the bazel file name.\n    with open(package.file, \"wb\") as out_file:\n      out_file.write(contents)\n\n    # Mark the file as executable.\n    st = os.stat(package.file)\n    os.chmod(package.file,\n             st.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)\n\n  return os.path.join(\".\", package.file)\n\n\ndef get_bazel_paths(bazel_path_flag):\n  \"\"\"Yields a sequence of guesses about bazel path. Some of sequence elements\n  can be None. 
The resulting iterator is lazy and potentially has side\n  effects.\"\"\"\n  yield bazel_path_flag\n  yield which(\"bazel\")\n  yield download_and_verify_bazel()\n\n\ndef get_bazel_path(bazel_path_flag):\n  \"\"\"Returns the path to a Bazel binary, downloading Bazel if not found. Also,\n  checks that Bazel's version is at least 5.1.1.\n\n  A manual version check is needed only for really old bazel versions.\n  Newer bazel releases perform their own version check against .bazelversion\n  (see for details\n  https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).\n  \"\"\"\n  for path in filter(None, get_bazel_paths(bazel_path_flag)):\n    version = get_bazel_version(path)\n    if version is not None and version >= (5, 1, 1):\n      return path, \".\".join(map(str, version))\n\n  print(\"Cannot find or download a suitable version of bazel.\"\n        \"Please install bazel >= 5.1.1.\")\n  sys.exit(-1)\n\n\ndef get_bazel_version(bazel_path):\n  try:\n    version_output = shell([bazel_path, \"--version\"])\n  except subprocess.CalledProcessError:\n    return None\n  match = re.search(r\"bazel *([0-9\\\\.]+)\", version_output)\n  if match is None:\n    return None\n  return tuple(int(x) for x in match.group(1).split(\".\"))\n\n\ndef write_bazelrc(*, python_bin_path, remote_build,\n                  cuda_toolkit_path, cudnn_install_path,\n                  cuda_version, cudnn_version, rocm_toolkit_path,\n                  cpu, cuda_compute_capabilities,\n                  rocm_amdgpu_targets, bazel_options, target_cpu_features,\n                  wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,\n                  enable_tpu, enable_remote_tpu, enable_rocm,\n                  enable_plugin_device):\n  tf_cuda_paths = []\n\n  with open(\"../.jax_configure.bazelrc\", \"w\") as f:\n    if not remote_build and python_bin_path:\n      f.write(textwrap.dedent(\"\"\"\\\n        build --strategy=Genrule=standalone\n        build --repo_env 
PYTHON_BIN_PATH=\"{python_bin_path}\"\n        build --action_env=PYENV_ROOT\n        build --python_path=\"{python_bin_path}\"\n        \"\"\").format(python_bin_path=python_bin_path))\n\n    if cuda_toolkit_path:\n      tf_cuda_paths.append(cuda_toolkit_path)\n      f.write(\"build --action_env CUDA_TOOLKIT_PATH=\\\"{cuda_toolkit_path}\\\"\\n\"\n              .format(cuda_toolkit_path=cuda_toolkit_path))\n    if cudnn_install_path:\n      # see https://github.com/tensorflow/tensorflow/issues/51040\n      if cudnn_install_path not in tf_cuda_paths:\n        tf_cuda_paths.append(cudnn_install_path)\n      f.write(\"build --action_env CUDNN_INSTALL_PATH=\\\"{cudnn_install_path}\\\"\\n\"\n              .format(cudnn_install_path=cudnn_install_path))\n    if len(tf_cuda_paths):\n      f.write(\"build --action_env TF_CUDA_PATHS=\\\"{tf_cuda_paths}\\\"\\n\"\n              .format(tf_cuda_paths=\",\".join(tf_cuda_paths)))\n    if cuda_version:\n      f.write(\"build --action_env TF_CUDA_VERSION=\\\"{cuda_version}\\\"\\n\"\n              .format(cuda_version=cuda_version))\n    if cudnn_version:\n      f.write(\"build --action_env TF_CUDNN_VERSION=\\\"{cudnn_version}\\\"\\n\"\n              .format(cudnn_version=cudnn_version))\n    if cuda_compute_capabilities:\n      f.write(\n        f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES=\"{cuda_compute_capabilities}\"\\n')\n    if rocm_toolkit_path:\n      f.write(\"build --action_env ROCM_PATH=\\\"{rocm_toolkit_path}\\\"\\n\"\n              .format(rocm_toolkit_path=rocm_toolkit_path))\n    if rocm_amdgpu_targets:\n      f.write(\n        f'build:rocm --action_env TF_ROCM_AMDGPU_TARGETS=\"{rocm_amdgpu_targets}\"\\n')\n    if cpu is not None:\n      f.write(\"build --distinct_host_configuration=true\\n\")\n      f.write(f\"build --cpu={cpu}\\n\")\n    else:\n      f.write(\"build --distinct_host_configuration=false\\n\")\n\n    for o in bazel_options:\n      f.write(f\"build {o}\\n\")\n    if target_cpu_features == 
\"release\":\n      if wheel_cpu == \"x86_64\":\n        f.write(\"build --config=avx_windows\\n\" if is_windows()\n                else \"build --config=avx_posix\\n\")\n    elif target_cpu_features == \"native\":\n      if is_windows():\n        print(\"--target_cpu_features=native is not supported on Windows; ignoring.\")\n      else:\n        f.write(\"build --config=native_arch_posix\\n\")\n\n    if enable_mkl_dnn:\n      f.write(\"build --config=mkl_open_source_only\\n\")\n    if enable_cuda:\n      f.write(\"build --config=cuda\\n\")\n      if not enable_nccl:\n        f.write(\"build --config=nonccl\\n\")\n      else:\n        from cupy.cuda import nccl\n        nccl_version = str(nccl.get_version())\n        nccl_version = f\"{nccl_version[0]}.{int(nccl_version[1:-2])}.{int(nccl_version[-2:])}\"\n        f.write(f'build --action_env TF_NCCL_VERSION=\"{nccl_version}\"\\n')\n\n    if enable_tpu:\n      f.write(\"build --config=tpu\\n\")\n    if enable_remote_tpu:\n      f.write(\"build --//build:enable_remote_tpu=true\\n\")\n    if enable_rocm:\n      f.write(\"build --config=rocm\\n\")\n      if not enable_nccl:\n        f.write(\"build --config=nonccl\\n\")\n    if enable_plugin_device:\n      f.write(\"build --config=plugin_device\\n\")\n\nBANNER = r\"\"\"\n     _   _  __  __\n    | | / \\ \\ \\/ /\n _  | |/ _ \\ \\  /\n| |_| / ___ \\/  \\\n \\___/_/   \\/_/\\_\\\n\n\"\"\"\n\nEPILOG = \"\"\"\n\nFrom the 'build' directory in the JAX repository, run\n    python build.py\nor\n    python3 build.py\nto download and build JAX's XLA (jaxlib) dependency.\n\"\"\"\n\n\ndef _parse_string_as_bool(s):\n  \"\"\"Parses a string as a boolean argument.\"\"\"\n  lower = s.lower()\n  if lower == \"true\":\n    return True\n  elif lower == \"false\":\n    return False\n  else:\n    raise ValueError(f\"Expected either 'true' or 'false'; got {s}\")\n\n\ndef add_boolean_argument(parser, name, default=False, help_str=None):\n  \"\"\"Creates a boolean flag.\"\"\"\n  group = 
parser.add_mutually_exclusive_group()\n  group.add_argument(\n      \"--\" + name,\n      nargs=\"?\",\n      default=default,\n      const=True,\n      type=_parse_string_as_bool,\n      help=help_str)\n  group.add_argument(\"--no\" + name, dest=name, action=\"store_false\")\n\n\ndef main():\n  cwd = os.getcwd()\n  parser = argparse.ArgumentParser(\n      description=\"Builds jaxlib from source.\", epilog=EPILOG)\n  parser.add_argument(\n      \"--bazel_path\",\n      help=\"Path to the Bazel binary to use. The default is to find bazel via \"\n      \"the PATH; if none is found, downloads a fresh copy of bazel from \"\n      \"GitHub.\")\n  parser.add_argument(\n      \"--python_bin_path\",\n      help=\"Path to Python binary to use. The default is the Python \"\n      \"interpreter used to run the build script.\")\n  parser.add_argument(\n      \"--target_cpu_features\",\n      choices=[\"release\", \"native\", \"default\"],\n      default=\"release\",\n      help=\"What CPU features should we target? 'release' enables CPU \"\n           \"features that should be enabled for a release build, which on \"\n           \"x86-64 architectures enables AVX. 'native' enables \"\n           \"-march=native, which generates code targeted to use all \"\n           \"features of the current machine. 'default' means don't opt-in \"\n           \"to any architectural features and use whatever the C compiler \"\n           \"generates by default.\")\n  add_boolean_argument(\n      parser,\n      \"enable_mkl_dnn\",\n      default=True,\n      help_str=\"Should we build with MKL-DNN enabled?\")\n  add_boolean_argument(\n      parser,\n      \"enable_cuda\",\n      help_str=\"Should we build with CUDA enabled? 
Requires CUDA and CuDNN.\")\n  add_boolean_argument(\n      parser,\n      \"enable_tpu\",\n      help_str=\"Should we build with Cloud TPU VM support enabled?\")\n  add_boolean_argument(\n      parser,\n      \"enable_remote_tpu\",\n      help_str=\"Should we build with remote Cloud TPU support enabled?\")\n  add_boolean_argument(\n      parser,\n      \"enable_rocm\",\n      help_str=\"Should we build with ROCm enabled?\")\n  add_boolean_argument(\n      parser,\n      \"enable_nccl\",\n      default=True,\n      help_str=\"Should we build with NCCL enabled? Has no effect for non-CUDA \"\n               \"builds.\")\n  add_boolean_argument(\n      parser,\n      \"enable_plugin_device\",\n      default=False,\n      help_str=\"Should we build with a plugin device enable?\")\n  add_boolean_argument(\n      parser,\n      \"remote_build\",\n      default=False,\n      help_str=\"Should we build with RBE (Remote Build Environment)?\")\n  parser.add_argument(\n      \"--cuda_path\",\n      default=None,\n      help=\"Path to the CUDA toolkit.\")\n  parser.add_argument(\n      \"--cudnn_path\",\n      default=None,\n      help=\"Path to CUDNN libraries.\")\n  parser.add_argument(\n      \"--cuda_version\",\n      default=None,\n      help=\"CUDA toolkit version, e.g., 11.1\")\n  parser.add_argument(\n      \"--cudnn_version\",\n      default=None,\n      help=\"CUDNN version, e.g., 8\")\n  # Caution: if changing the default list of CUDA capabilities, you should also\n  # update the list in .bazelrc, which is used for wheel builds.\n  parser.add_argument(\n      \"--cuda_compute_capabilities\",\n      default=None,\n      help=\"A comma-separated list of CUDA compute capabilities to support.\")\n  parser.add_argument(\n      \"--rocm_amdgpu_targets\",\n      default=\"gfx900,gfx906,gfx908,gfx90a,gfx1030\",\n      help=\"A comma-separated list of ROCm amdgpu targets to support.\")\n  parser.add_argument(\n      \"--rocm_path\",\n      default=None,\n      help=\"Path to 
the ROCm toolkit.\")\n  parser.add_argument(\n      \"--bazel_startup_options\",\n      action=\"append\", default=[],\n      help=\"Additional startup options to pass to bazel.\")\n  parser.add_argument(\n      \"--bazel_options\",\n      action=\"append\", default=[],\n      help=\"Additional options to pass to bazel.\")\n  parser.add_argument(\n      \"--output_path\",\n      default=os.path.join(cwd, \"dist\"),\n      help=\"Directory to which the jaxlib wheel should be written\")\n  parser.add_argument(\n      \"--target_cpu\",\n      default=None,\n      help=\"CPU platform to target. Default is the same as the host machine. \"\n           \"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.\")\n  add_boolean_argument(\n      parser,\n      \"configure_only\",\n      default=False,\n      help_str=\"If true, writes a .bazelrc file but does not build jaxlib.\")\n  parser.add_argument(\n      \"--dev_install\",\n      action=\"store_true\",\n      help=\"Do not build wheel. 
Use dev install\")\n\n  args = parser.parse_args()\n\n  if is_windows() and args.enable_cuda:\n    if args.cuda_version is None:\n      parser.error(\"--cuda_version is needed for Windows CUDA build.\")\n    if args.cudnn_version is None:\n      parser.error(\"--cudnn_version is needed for Windows CUDA build.\")\n\n  if args.enable_cuda and args.enable_rocm:\n    parser.error(\"--enable_cuda and --enable_rocm cannot be enabled at the same time.\")\n\n  print(BANNER)\n\n  output_path = os.path.abspath(args.output_path)\n  os.chdir(os.path.dirname(__file__ or args.prog) or '.')\n\n  host_cpu = platform.machine()\n  wheel_cpus = {\n      \"darwin_arm64\": \"arm64\",\n      \"darwin_x86_64\": \"x86_64\",\n      \"ppc\": \"ppc64le\",\n      \"aarch64\": \"aarch64\",\n  }\n  # TODO(phawkins): support other bazel cpu overrides.\n  wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None\n               else host_cpu)\n\n  # Find a working Bazel.\n  bazel_path, bazel_version = get_bazel_path(args.bazel_path)\n  print(f\"Bazel binary path: {bazel_path}\")\n  print(f\"Bazel version: {bazel_version}\")\n\n  python_bin_path = get_python_bin_path(args.python_bin_path)\n  print(f\"Python binary path: {python_bin_path}\")\n  python_version = get_python_version(python_bin_path)\n  print(\"Python version: {}\".format(\".\".join(map(str, python_version))))\n  check_python_version(python_version)\n\n  numpy_version = check_numpy_version(python_bin_path)\n  print(f\"NumPy version: {numpy_version}\")\n\n  print(\"MKL-DNN enabled: {}\".format(\"yes\" if args.enable_mkl_dnn else \"no\"))\n  print(f\"Target CPU: {wheel_cpu}\")\n  print(f\"Target CPU features: {args.target_cpu_features}\")\n\n  cuda_toolkit_path = args.cuda_path\n  cudnn_install_path = args.cudnn_path\n  rocm_toolkit_path = args.rocm_path\n  print(\"CUDA enabled: {}\".format(\"yes\" if args.enable_cuda else \"no\"))\n  if args.enable_cuda:\n    if cuda_toolkit_path:\n      print(f\"CUDA toolkit path: 
{cuda_toolkit_path}\")\n    if cudnn_install_path:\n      print(f\"CUDNN library path: {cudnn_install_path}\")\n    if args.cuda_compute_capabilities is not None:\n      print(f\"CUDA compute capabilities: {args.cuda_compute_capabilities}\")\n    if args.cuda_version:\n      print(f\"CUDA version: {args.cuda_version}\")\n    if args.cudnn_version:\n      print(f\"CUDNN version: {args.cudnn_version}\")\n    print(\"NCCL enabled: {}\".format(\"yes\" if args.enable_nccl else \"no\"))\n\n  print(\"TPU enabled: {}\".format(\"yes\" if args.enable_tpu else \"no\"))\n  print(\"Remote TPU enabled: {}\".format(\"yes\" if args.enable_remote_tpu else \"no\"))\n\n  print(\"ROCm enabled: {}\".format(\"yes\" if args.enable_rocm else \"no\"))\n  if args.enable_rocm:\n    if rocm_toolkit_path:\n      print(f\"ROCm toolkit path: {rocm_toolkit_path}\")\n    print(f\"ROCm amdgpu targets: {args.rocm_amdgpu_targets}\")\n\n  print(\"Plugin device enabled: {}\".format(\"yes\" if args.enable_plugin_device else \"no\"))\n\n  write_bazelrc(\n      python_bin_path=python_bin_path,\n      remote_build=args.remote_build,\n      cuda_toolkit_path=cuda_toolkit_path,\n      cudnn_install_path=cudnn_install_path,\n      cuda_version=args.cuda_version,\n      cudnn_version=args.cudnn_version,\n      rocm_toolkit_path=rocm_toolkit_path,\n      cpu=args.target_cpu,\n      cuda_compute_capabilities=args.cuda_compute_capabilities,\n      rocm_amdgpu_targets=args.rocm_amdgpu_targets,\n      bazel_options=args.bazel_options,\n      target_cpu_features=args.target_cpu_features,\n      wheel_cpu=wheel_cpu,\n      enable_mkl_dnn=args.enable_mkl_dnn,\n      enable_cuda=args.enable_cuda,\n      enable_nccl=args.enable_nccl,\n      enable_tpu=args.enable_tpu,\n      enable_remote_tpu=args.enable_remote_tpu,\n      enable_rocm=args.enable_rocm,\n      enable_plugin_device=args.enable_plugin_device,\n  )\n\n  if args.configure_only:\n    return\n\n  print(\"\\nBuilding XLA and installing it in the jaxlib source 
tree...\")\n\n\n  command = ([bazel_path] + args.bazel_startup_options +\n    [\"run\", \"--verbose_failures=true\"] +\n    [\":build_wheel\", \"--\",\n    f\"--output_path={output_path}\",\n    f\"--cpu={wheel_cpu}\"])\n  if args.dev_install:\n    command += [\"--dev_install\"]\n  print(\" \".join(command))\n  shell(command)\n  shell([bazel_path] + args.bazel_startup_options + [\"shutdown\"])\n\n\nif __name__ == \"__main__\":\n  main()\n"
  },
  {
    "path": "build_jaxlib/build/build_wheel.py",
    "content": "# Copyright 2020 The JAX Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Script that builds a jaxlib wheel, intended to be run via bazel run as part\n# of the jaxlib build process.\n\n# Most users should not run this script directly; use build.py instead.\n\nimport argparse\nimport datetime\nimport functools\nimport glob\nimport os\nimport pathlib\nimport platform\nimport re\nimport shutil\nimport subprocess\nimport sys\nimport tempfile\n\nfrom bazel_tools.tools.python.runfiles import runfiles\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n  \"--sources_path\",\n  default=None,\n  help=\"Path in which the wheel's sources should be prepared. Optional. If \"\n       \"omitted, a temporary directory will be used.\")\nparser.add_argument(\n  \"--output_path\",\n  default=None,\n  required=True,\n  help=\"Path to which the output wheel should be written. Required.\")\nparser.add_argument(\n  \"--cpu\",\n  default=None,\n  required=True,\n  help=\"Target CPU architecture. Required.\")\nparser.add_argument(\n  \"--dev_install\",\n  action=\"store_true\",\n  help=\"Do not build wheel. 
Use dev install\")\nargs = parser.parse_args()\n\nr = runfiles.Create()\n\n\ndef _is_mac():\n  return platform.system() == \"Darwin\"\n\n\ndef _is_windows():\n  return sys.platform.startswith(\"win32\")\n\n\npyext = \"pyd\" if _is_windows() else \"so\"\n\n\ndef exists(src_file):\n  return r.Rlocation(src_file) is not None\n\n\ndef copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):\n  if from_runfiles:\n    src_file = r.Rlocation(src_file)\n\n  src_filename = os.path.basename(src_file)\n  dst_file = os.path.join(dst_dir, dst_filename or src_filename)\n  if _is_windows():\n    shutil.copyfile(src_file, dst_file)\n  else:\n    shutil.copy(src_file, dst_file)\n\ndef dev_install(sources_path, output_path):\n  sys.stderr.write(\"Dev Install:\\n\")\n  sys.stderr.write(f'Run \"pip install -e .\" once in {output_path}\\n')\n  os.system(f\"rm -rf {output_path}/*\")\n  os.system(f\"cp -r {sources_path}/* {output_path}\")\n  return\n\n_XLA_EXTENSION_STUBS = [\n    \"__init__.pyi\",\n    \"jax_jit.pyi\",\n    \"ops.pyi\",\n    \"outfeed_receiver.pyi\",\n    \"pmap_lib.pyi\",\n    \"profiler.pyi\",\n    \"pytree.pyi\",\n    \"transfer_guard_lib.pyi\",\n]\n_OPTIONAL_XLA_EXTENSION_STUBS = [\n]\n\n\ndef patch_copy_xla_extension_stubs(dst_dir):\n  # This file is required by PEP-561. 
It marks jaxlib as package containing\n  # type stubs.\n  with open(os.path.join(dst_dir, \"py.typed\"), \"w\"):\n    pass\n  xla_extension_dir = os.path.join(dst_dir, \"xla_extension\")\n  os.makedirs(xla_extension_dir)\n  for stub_name in _XLA_EXTENSION_STUBS:\n    stub_path = r.Rlocation(\n        \"org_tensorflow/tensorflow/compiler/xla/python/xla_extension/\" + stub_name)\n    stub_path = str(stub_path)  # Make pytype accept os.path.exists(stub_path).\n    if stub_name in _OPTIONAL_XLA_EXTENSION_STUBS and not os.path.exists(stub_path):\n      continue\n    with open(stub_path) as f:\n      src = f.read()\n    src = src.replace(\n        \"from tensorflow.compiler.xla.python import xla_extension\",\n        \"from .. import xla_extension\"\n    )\n    with open(os.path.join(xla_extension_dir, stub_name), \"w\") as f:\n      f.write(src)\n\n\ndef patch_copy_tpu_client_py(dst_dir):\n  with open(r.Rlocation(\"org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.py\")) as f:\n    src = f.read()\n    src = src.replace(\"from tensorflow.compiler.xla.python import xla_extension as _xla\",\n                      \"from . import xla_extension as _xla\")\n    src = src.replace(\"from tensorflow.compiler.xla.python import xla_client\",\n                      \"from . import xla_client\")\n    src = src.replace(\n        \"from tensorflow.compiler.xla.python.tpu_driver.client import tpu_client_extension as _tpu_client\",\n        \"from . 
import tpu_client_extension as _tpu_client\")\n    with open(os.path.join(dst_dir, \"tpu_client.py\"), \"w\") as f:\n      f.write(src)\n\ndef verify_mac_libraries_dont_reference_chkstack():\n  \"\"\"Verifies that xla_extension.so doesn't depend on ____chkstk_darwin.\n\n  We don't entirely know why this happens, but in some build environments\n  we seem to target the wrong Mac OS version.\n  https://github.com/google/jax/issues/3867\n\n  This check makes sure we don't release wheels that have this dependency.\n  \"\"\"\n  if not _is_mac():\n    return\n  nm = subprocess.run(\n    [\"nm\", \"-g\",\n     r.Rlocation(\"org_tensorflow/tensorflow/compiler/xla/python/xla_extension.so\")\n    ],\n    capture_output=True, text=True,\n    check=False)\n  if nm.returncode != 0:\n    raise RuntimeError(f\"nm process failed: {nm.stdout} {nm.stderr}\")\n  if \"____chkstk_darwin\" in nm.stdout:\n    raise RuntimeError(\n      \"Mac wheel incorrectly depends on symbol ____chkstk_darwin, which \"\n      \"means that it isn't compatible with older MacOS versions.\")\n\n\ndef prepare_wheel(sources_path):\n  \"\"\"Assembles a source tree for the wheel in `sources_path`.\"\"\"\n  jaxlib_dir = os.path.join(sources_path, \"jaxlib\")\n  os.makedirs(jaxlib_dir)\n  copy_to_jaxlib = functools.partial(copy_file, dst_dir=jaxlib_dir)\n\n  verify_mac_libraries_dont_reference_chkstack()\n  copy_file(\"__main__/build/LICENSE.txt\", dst_dir=sources_path)\n  copy_file(\"__main__/jaxlib/README.md\", dst_dir=sources_path)\n  copy_file(\"__main__/jaxlib/setup.py\", dst_dir=sources_path)\n  copy_file(\"__main__/jaxlib/setup.cfg\", dst_dir=sources_path)\n  copy_to_jaxlib(\"__main__/jaxlib/init.py\", dst_filename=\"__init__.py\")\n  copy_to_jaxlib(f\"__main__/jaxlib/cpu_feature_guard.{pyext}\")\n  copy_to_jaxlib(\"__main__/jaxlib/lapack.py\")\n  copy_to_jaxlib(f\"__main__/jaxlib/_lapack.{pyext}\")\n  copy_to_jaxlib(\"__main__/jaxlib/mhlo_helpers.py\")\n  
copy_to_jaxlib(f\"__main__/jaxlib/_ducc_fft.{pyext}\")\n  copy_to_jaxlib(\"__main__/jaxlib/ducc_fft.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/gpu_prng.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/gpu_linalg.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/gpu_solver.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/gpu_sparse.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/version.py\")\n  copy_to_jaxlib(\"__main__/jaxlib/xla_client.py\")\n  copy_to_jaxlib(f\"__main__/jaxlib/xla_extension.{pyext}\")\n\n  cuda_dir = os.path.join(jaxlib_dir, \"cuda\")\n  if exists(f\"__main__/jaxlib/cuda/_cusolver.{pyext}\"):\n    libdevice_dir = os.path.join(cuda_dir, \"nvvm\", \"libdevice\")\n    os.makedirs(libdevice_dir)\n    copy_file(\"local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc\", dst_dir=libdevice_dir)\n    copy_file(f\"__main__/jaxlib/cuda/_cusolver.{pyext}\", dst_dir=cuda_dir)\n    copy_file(f\"__main__/jaxlib/cuda/_cublas.{pyext}\", dst_dir=cuda_dir)\n    copy_file(f\"__main__/jaxlib/cuda/_cuda_linalg.{pyext}\", dst_dir=cuda_dir)\n    copy_file(f\"__main__/jaxlib/cuda/_cuda_prng.{pyext}\", dst_dir=cuda_dir)\n  rocm_dir = os.path.join(jaxlib_dir, \"rocm\")\n  if exists(f\"__main__/jaxlib/rocm/_hipsolver.{pyext}\"):\n    os.makedirs(rocm_dir)\n    copy_file(f\"__main__/jaxlib/rocm/_hipsolver.{pyext}\", dst_dir=rocm_dir)\n    copy_file(f\"__main__/jaxlib/rocm/_hipblas.{pyext}\", dst_dir=rocm_dir)\n    copy_file(f\"__main__/jaxlib/rocm/_hip_linalg.{pyext}\", dst_dir=rocm_dir)\n    copy_file(f\"__main__/jaxlib/rocm/_hip_prng.{pyext}\", dst_dir=rocm_dir)\n  if exists(f\"__main__/jaxlib/cuda/_cusparse.{pyext}\"):\n    copy_file(f\"__main__/jaxlib/cuda/_cusparse.{pyext}\", dst_dir=cuda_dir)\n  if exists(f\"__main__/jaxlib/rocm/_hipsparse.{pyext}\"):\n    copy_file(f\"__main__/jaxlib/rocm/_hipsparse.{pyext}\", dst_dir=rocm_dir)\n\n\n  mlir_dir = os.path.join(jaxlib_dir, \"mlir\")\n  mlir_dialects_dir = os.path.join(jaxlib_dir, \"mlir\", \"dialects\")\n  mlir_libs_dir = 
os.path.join(jaxlib_dir, \"mlir\", \"_mlir_libs\")\n  os.makedirs(mlir_dir)\n  os.makedirs(mlir_dialects_dir)\n  os.makedirs(mlir_libs_dir)\n  copy_file(\"__main__/jaxlib/mlir/ir.py\", dst_dir=mlir_dir)\n  copy_file(\"__main__/jaxlib/mlir/passmanager.py\", dst_dir=mlir_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_builtin_ops_ext.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_builtin_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_chlo_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_mhlo_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_ods_common.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_func_ops_ext.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_func_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_ml_program_ops_ext.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_ml_program_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/_sparse_tensor_ops_gen.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/sparse_tensor.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/builtin.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/chlo.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/mhlo.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/func.py\", dst_dir=mlir_dialects_dir)\n  copy_file(\"__main__/jaxlib/mlir/dialects/ml_program.py\", dst_dir=mlir_dialects_dir)\n\n  copy_file(\"__main__/jaxlib/mlir/_mlir_libs/__init__.py\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_mlir.{pyext}\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_chlo.{pyext}\", 
dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_mlirHlo.{pyext}\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}\", dst_dir=mlir_libs_dir)\n  copy_file(f\"__main__/jaxlib/mlir/_mlir_libs/_site_initialize_0.{pyext}\", dst_dir=mlir_libs_dir)\n  if _is_windows():\n    copy_file(\"__main__/jaxlib/mlir/_mlir_libs/jaxlib_mlir_capi.dll\", dst_dir=mlir_libs_dir)\n  elif _is_mac():\n    copy_file(\"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.dylib\", dst_dir=mlir_libs_dir)\n  else:\n    copy_file(\"__main__/jaxlib/mlir/_mlir_libs/libjaxlib_mlir_capi.so\", dst_dir=mlir_libs_dir)\n  patch_copy_xla_extension_stubs(jaxlib_dir)\n\n  if exists(\"org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so\"):\n    copy_to_jaxlib(\"org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so\")\n    patch_copy_tpu_client_py(jaxlib_dir)\n\n\ndef edit_jaxlib_version(sources_path):\n  version_regex = re.compile(r'__version__ = \\\"(.*)\\\"')\n\n  version_file = pathlib.Path(sources_path) / \"jaxlib\" / \"version.py\"\n  content = version_file.read_text()\n\n  version_num = version_regex.search(content).group(1)\n\n  datestring = datetime.datetime.now().strftime('%Y%m%d')\n  nightly_version = f'{version_num}.dev{datestring}'\n\n  content = content.replace(f'__version__ = \"{version_num}\"',\n                            f'__version__ = \"{nightly_version}\"')\n  version_file.write_text(content)\n\n\ndef build_wheel(sources_path, output_path, cpu):\n  \"\"\"Builds a wheel in `output_path` using the source tree in `sources_path`.\"\"\"\n  platform_name, cpu_name = {\n    (\"Linux\", \"x86_64\"): (\"manylinux2014\", \"x86_64\"),\n    (\"Linux\", 
\"aarch64\"): (\"manylinux2014\", \"aarch64\"),\n    (\"Linux\", \"ppc64le\"): (\"manylinux2014\", \"ppc64le\"),\n    (\"Darwin\", \"x86_64\"): (\"macosx_10_14\", \"x86_64\"),\n    (\"Darwin\", \"arm64\"): (\"macosx_11_0\", \"arm64\"),\n    (\"Windows\", \"AMD64\"): (\"win\", \"amd64\"),\n  }[(platform.system(), cpu)]\n  python_tag_arg = (f\"--python-tag=cp{sys.version_info.major}\"\n                    f\"{sys.version_info.minor}\")\n  platform_tag_arg = f\"--plat-name={platform_name}_{cpu_name}\"\n  cwd = os.getcwd()\n  if os.environ.get('JAXLIB_NIGHTLY'):\n    edit_jaxlib_version(sources_path)\n  os.chdir(sources_path)\n  subprocess.run([sys.executable, \"setup.py\", \"bdist_wheel\",\n                 python_tag_arg, platform_tag_arg], check=True)\n  os.chdir(cwd)\n  for wheel in glob.glob(os.path.join(sources_path, \"dist\", \"*.whl\")):\n    output_file = os.path.join(output_path, os.path.basename(wheel))\n    sys.stderr.write(f\"Output wheel: {output_file}\\n\\n\")\n    sys.stderr.write(\"To install the newly-built jaxlib wheel, run:\\n\")\n    sys.stderr.write(f\"  pip install {output_file}\\n\\n\")\n    shutil.copy(wheel, output_path)\n\n\ntmpdir = None\nsources_path = args.sources_path\nif sources_path is None:\n  tmpdir = tempfile.TemporaryDirectory(prefix=\"jaxlib\")\n  sources_path = tmpdir.name\n\ntry:\n  os.makedirs(args.output_path, exist_ok=True)\n  prepare_wheel(sources_path)\n  if args.dev_install:\n    dev_install(sources_path, args.output_path)\n  else:\n    build_wheel(sources_path, args.output_path, args.cpu)\nfinally:\n  if tmpdir:\n    tmpdir.cleanup()\n"
  },
  {
    "path": "build_jaxlib/release/README.md",
    "content": "# How to Release JaxLib and generate a PyPI Index\n\n1.  Upload jaxlib wheels as assets under a release tag.\n```shell\nGITHUB_TOKEN=[ADMIN_TOKEN] python wheel_upload.py --tag [TAG] --path [PATH_TO_WHEELS]\n```\n\n2. Generate a html index page and commit it to the master branch of Alpa doc repository.\n```shell\nGITHUB_TOKEN=[ADMIN_TOKEN] python generate_pypi_index.py --tag [TAG]\n```\nAll wheel assets under `[TAG]` will be included in a html index page appeared in the doc repo.\n\nPlease make sure the TAG is aligned in Step 1 and Step 2.\n"
  },
  {
    "path": "build_jaxlib/release/generate_pypi_index.py",
    "content": "\"\"\"Generate and upload a PyPI index page given a tag.\"\"\"\nimport os\nimport logging\nimport argparse\nimport subprocess\nfrom datetime import datetime\n\nimport github3\nimport github3.session as session\nimport requests\n\n\ndef py_str(cstr):\n    return cstr.decode(\"utf-8\")\n\n\ndef url_is_valid(url):\n    \"\"\"Check if a given URL is valid, i.e. it returns 200 OK when requested.\"\"\"\n    r = requests.get(url)\n\n    if r.status_code != 200:\n        print(\"Warning: HTTP code %s for url %s\" % (r.status_code, url))\n\n    return r.status_code == 200\n\n\ndef list_wheels(repo, tag):\n    gh = github3.GitHub(token=os.environ[\"GITHUB_TOKEN\"],\n                        session=session.GitHubSession(default_connect_timeout=100, default_read_timeout=100))\n    repo = gh.repository(*repo.split(\"/\"))\n    wheels = []\n    all_tags = [release.tag_name for release in repo.releases()]\n    if tag not in all_tags:\n        raise RuntimeError(\"The tag provided does not exist.\")\n    release = repo.release_from_tag(tag)\n    for asset in release.assets():\n        print(f\"Validating {asset.name} with url: {asset.browser_download_url}\")\n        if asset.name.endswith(\".whl\") and url_is_valid(asset.browser_download_url):\n            wheels.append(asset)\n    return wheels\n\n\ndef update_wheel_page(keep_list, site_repo, tag, dry_run=False):\n    \"\"\"Update the wheel page\"\"\"\n    new_html = \"\"\n    for asset in keep_list:\n        new_html += '<a href=\"%s\">%s</a><br>\\n' % (\n            asset.browser_download_url,\n            asset.name,\n        )\n\n    def run_cmd(cmd):\n        proc = subprocess.Popen(\n            cmd, cwd=site_repo, stdout=subprocess.PIPE, stderr=subprocess.STDOUT\n        )\n        (out, _) = proc.communicate()\n        if proc.returncode != 0:\n            msg = \"git error: %s\" % cmd\n            msg += py_str(out)\n            raise RuntimeError(msg)\n\n    run_cmd([\"git\", \"fetch\"])\n    
run_cmd([\"git\", \"checkout\", \"-B\", \"master\", \"origin/master\"])\n    wheel_html_path = os.path.join(site_repo, \"wheels.html\")\n    if not os.path.exists(wheel_html_path) or open(wheel_html_path, \"r\").read() != new_html:\n        print(f\"Wheel page changed, update {wheel_html_path}..\")\n        if not dry_run:\n            open(wheel_html_path, \"w\").write(new_html)\n            run_cmd([\"git\", \"add\", \"wheels.html\"])\n            run_cmd([\"git\", \"commit\", \"-am\",\n                     f\"wheel update at {datetime.now()} from tag {tag}\"])\n            run_cmd([\"git\", \"push\", \"origin\", \"master\"])\n\n\ndef delete_assets(remove_list, dry_run):\n    for asset in remove_list:\n        if not dry_run:\n            asset.delete()\n    if remove_list:\n        print(\"Finish deleting %d removed assets\" % len(remove_list))\n\n\ndef main():\n    logging.basicConfig(level=logging.WARNING)\n    parser = argparse.ArgumentParser(\n        description=\"Generate a wheel page given a release tag, assuming the wheels have been uploaded.\"\n    )\n    parser.add_argument(\"--dry-run\", action=\"store_true\")\n    parser.add_argument(\"--site-path\", type=str, default=\"alpa-projects.github.io\")\n    parser.add_argument(\"--repo\", type=str, default=\"alpa-projects/alpa\")\n    parser.add_argument(\"--tag\", type=str)\n\n    if \"GITHUB_TOKEN\" not in os.environ:\n        raise RuntimeError(\"need GITHUB_TOKEN\")\n    args = parser.parse_args()\n    wheels = list_wheels(args.repo, args.tag)\n    update_wheel_page(wheels, args.site_path, args.tag, args.dry_run)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "build_jaxlib/release/wheel_upload.py",
    "content": "\"\"\"Update the wheels page, prune old nightly builds if necessary (source from tlcpack).\"\"\"\nimport github3\nimport github3.session as session\nimport os\nimport logging\nimport argparse\n\n\ndef upload(args, path):\n    # gh = github3.login(token=os.environ[\"GITHUB_TOKEN\"])\n    gh = github3.GitHub(token=os.environ[\"GITHUB_TOKEN\"],\n                        session=session.GitHubSession(default_connect_timeout=100, default_read_timeout=100))\n    repo = gh.repository(*args.repo.split(\"/\"))\n    release = repo.release_from_tag(args.tag)\n    name = os.path.basename(path)\n    content_bytes = open(path, \"rb\").read()\n\n    for asset in release.assets():\n        if asset.name == name:\n            if not args.dry_run:\n                asset.delete()\n                print(f\"Remove duplicated file {name}\")\n    print(f\"Start to upload {path} to {args.repo}, this can take a while...\")\n    if not args.dry_run:\n        release.upload_asset(\"application/octet-stream\", name, content_bytes)\n    print(f\"Finish uploading {path}\")\n\n\ndef main():\n    logging.basicConfig(level=logging.WARNING)\n    parser = argparse.ArgumentParser(description=\"Upload wheel as an asset of a tag.\")\n    parser.add_argument(\"--tag\", type=str)\n    parser.add_argument(\"--repo\", type=str, default=\"alpa-projects/alpa\")\n    parser.add_argument(\"--dry-run\", action=\"store_true\")\n    parser.add_argument(\"--path\", type=str)\n\n    if \"GITHUB_TOKEN\" not in os.environ:\n        raise RuntimeError(\"need GITHUB_TOKEN\")\n    args = parser.parse_args()\n    if os.path.isdir(args.path):\n        for name in os.listdir(args.path):\n            if name.endswith(\".whl\"):\n                upload(args, os.path.join(args.path, name))\n    else:\n        upload(args, args.path)\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "build_jaxlib/update_build_scripts.patch",
    "content": "diff --git a/build_jaxlib/build/build.py b/build_jaxlib/build/build.py\nindex d8e90202..5cbcc33d 100755\n--- a/build_jaxlib/build/build.py\n+++ b/build_jaxlib/build/build.py\n@@ -283,6 +283,11 @@ def write_bazelrc(*, python_bin_path, remote_build,\n       f.write(\"build --config=cuda\\n\")\n       if not enable_nccl:\n         f.write(\"build --config=nonccl\\n\")\n+      else:\n+        from cupy.cuda import nccl\n+        nccl_version = str(nccl.get_version())\n+        nccl_version = f\"{nccl_version[0]}.{int(nccl_version[1:-2])}.{int(nccl_version[-2:])}\"\n+        f.write(f'build --action_env TF_NCCL_VERSION=\"{nccl_version}\"\\n')\n     if enable_tpu:\n       f.write(\"build --config=tpu\\n\")\n     if enable_remote_tpu:\n@@ -292,6 +297,7 @@ def write_bazelrc(*, python_bin_path, remote_build,\n       if not enable_nccl:\n         f.write(\"build --config=nonccl\\n\")\n \n+\n BANNER = r\"\"\"\n      _   _  __  __\n     | | / \\ \\ \\/ /\n@@ -443,6 +449,10 @@ def main():\n       \"configure_only\",\n       default=False,\n       help_str=\"If true, writes a .bazelrc file but does not build jaxlib.\")\n+  parser.add_argument(\n+      \"--dev_install\",\n+      action=\"store_true\",\n+      help=\"Do not build wheel. Use dev install\")\n   args = parser.parse_args()\n \n   if is_windows() and args.enable_cuda:\n@@ -546,6 +556,8 @@ def main():\n     [\":build_wheel\", \"--\",\n     f\"--output_path={output_path}\",\n     f\"--cpu={wheel_cpu}\"])\n+  if args.dev_install:\n+    command += [\"--dev_install\"]\n   print(\" \".join(command))\n   shell(command)\n   shell([bazel_path, \"shutdown\"])\ndiff --git a/build_jaxlib/build/build_wheel.py b/build_jaxlib/build/build_wheel.py\nindex 31df6256..d118da2c 100644\n--- a/build_jaxlib/build/build_wheel.py\n+++ b/build_jaxlib/build/build_wheel.py\n@@ -48,6 +48,10 @@ parser.add_argument(\n   default=None,\n   required=True,\n   help=\"Target CPU architecture. 
Required.\")\n+parser.add_argument(\n+  \"--dev_install\",\n+  action=\"store_true\",\n+  help=\"Do not build wheel. Use dev install\")\n args = parser.parse_args()\n \n r = runfiles.Create()\n@@ -79,6 +83,12 @@ def copy_file(src_file, dst_dir, dst_filename=None, from_runfiles=True):\n   else:\n     shutil.copy(src_file, dst_file)\n \n+def dev_install(sources_path, output_path):\n+  sys.stderr.write(\"Dev Install:\\n\")\n+  sys.stderr.write(f'Run \"pip install -e .\" once in {output_path}\\n')\n+  os.system(f\"rm -rf {output_path}/*\")\n+  os.system(f\"cp -r {sources_path}/* {output_path}\")\n+  return\n \n _XLA_EXTENSION_STUBS = [\n     \"__init__.pyi\",\n@@ -300,7 +310,10 @@ if sources_path is None:\n try:\n   os.makedirs(args.output_path, exist_ok=True)\n   prepare_wheel(sources_path)\n-  build_wheel(sources_path, args.output_path, args.cpu)\n+  if args.dev_install:\n+    dev_install(sources_path, args.output_path)\n+  else:\n+    build_wheel(sources_path, args.output_path, args.cpu)\n finally:\n   if tmpdir:\n     tmpdir.cleanup()\n"
  },
  {
    "path": "docker/README.md",
    "content": "# Alpa Docker\nThis directory contains Alpa's docker infrastructure. Alpa uses docker to provide environment to build and release Python wheels and to perform unit tests.\nMost docker files in this directory depend on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker/).\n\nBelow we provide instructions on\n- How to build Alpa-modified jaxlib in a docker container\n- How to run Alpa in a docker container\n\nMore docker examples can be found in the directory of [Alpa CI/CD](../.github/workflows).\n\n## Build Jaxlib-alpa wheels using Docker\nWe provide a Docker image to build the Alpa-modified jaxlib wheels inside a container.\n\n\n### Steps\nFirst, figure out the CUDA and Python versions you want to use to build jaxlib. Current we support the following versions:\n- CUDA: 11.1, 11.2, 11.3\n- Python: 3.7, 3.8, 3.9\n\nSuppose we want to build the jaxlib-alpa with CUDA 11.1 and Python 3.8.\n#### Build the docker image\n```python\n# create a folder to save the output wheels\ncd alpa/docker && mkdir -p dist\n\n# build the image using the chosen CUDA version\ndocker build -t build-jaxlib-image -f build_jaxlib.Dockerfile . --build-arg JAX_CUDA_VERSION=11.1\n```\n\n#### Build the wheels inside a container\n```bash\n# create a subfolder for the specific wheel version.\nmkdir -p dist/cuda111\n\n# build the wheel in a container using the selected Python and CUDA versions\ndocker run --tmpfs /build:exec --rm -v $(pwd)/dist:/dist build-jaxlib-image 3.8 cuda 11.1 main\n\n# Move the output wheel\nmv -f dist/*.whl dist/cuda111/\n```\nCheck out the wheel under the folder ``alpa/build/dist/cuda111/``.\n\n## Run Alpa in a docker container\nYou can run Alpa inside a docker container. 
Below are steps on how to run Alpa in a docker container in the interactive mode.\n\nFirst, build a docker image based on the provided dockerfile:\n```bash\ndocker build -t run-alpa-image -f run_alpa.Dockerfile .\n```\n\nFor cloud provider with InfiniBand (such as CoreWeave) we need to include additional dependencies:\n ```bash\ndocker build -t run-alpa-image -f run_alpa_infiniband.Dockerfile .\n```\n\nSecond, build a container from the image and enter the container's interactive shell:\n```bash\ndocker run --gpus all --rm --shm-size=10.24gb -it run-alpa-image\n```\n\nThird, check alpa installation is correct:\n```bash\nconda activate alpa\n# Start ray:\nray start --head\n# Test Alpa can run correctly:\npython -m alpa.test_install\n```\n\nAlternatively, you can skip the interactive shell, and pass commands or job scripts via the `docker run` command to the container.\n"
  },
  {
    "path": "docker/build_alpa.Dockerfile",
    "content": "FROM quay.io/pypa/manylinux2014_x86_64\n\nWORKDIR /\nSHELL [\"/bin/bash\", \"-c\"]\nRUN yum-config-manager --add-repo http://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo\nRUN yum --enablerepo=epel -y install cuda-11-1\n\nCOPY scripts/build_alpa.sh /build_alpa.sh\nRUN chmod +x /build_alpa.sh\n\nWORKDIR /build\nENV TEST_TMPDIR /build\nENTRYPOINT [\"/build_alpa.sh\"]\n"
  },
  {
    "path": "docker/build_doc.Dockerfile",
    "content": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\", \"-c\"]\nRUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list\nRUN apt-get update\nRUN apt-get install -y coinor-cbc glpk-utils python3-virtualenv\n\nRUN virtualenv --python=python3.8 python3.8-env\nRUN source python3.8-env/bin/activate && pip install --upgrade pip \\\n    && pip install numpy==1.20 setuptools wheel six auditwheel \\\n    sphinx sphinx-rtd-theme sphinx-gallery matplotlib\nCOPY scripts/build_doc.sh /build_doc.sh\nRUN chmod +x build_doc.sh\nENTRYPOINT [\"/build_doc.sh\"]\n"
  },
  {
    "path": "docker/build_jaxlib.Dockerfile",
    "content": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\", \"-c\"]\nRUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub\nRUN sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub\nRUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list\nRUN apt-get update\nRUN apt-get install -y python3-virtualenv\n\nRUN virtualenv --python=python3.7 python3.7-env\nRUN virtualenv --python=python3.8 python3.8-env\nRUN virtualenv --python=python3.9 python3.9-env\n\n# We pin numpy to the minimum permitted version to avoid compatibility issues.\nRUN source python3.7-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel\nRUN source python3.8-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel\nRUN source python3.9-env/bin/activate && pip install --upgrade pip && pip install numpy==1.20 setuptools wheel six auditwheel\n\n# Change the CUDA version if it doesn't match the installed version in the base image\n# which is 10.0\nARG JAX_CUDA_VERSION=11.1\nCOPY scripts/install_cuda.sh /install_cuda.sh\nRUN chmod +x /install_cuda.sh\nRUN /bin/bash -c 'if [[ ! \"$CUDA_VERSION\" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \\\n  /install_cuda.sh $JAX_CUDA_VERSION; \\\n  fi'\n\n\nWORKDIR /\nCOPY scripts/build_jaxlib_docker_entrypoint.sh /build_jaxlib_docker_entrypoint.sh\nRUN chmod +x /build_jaxlib_docker_entrypoint.sh\n\nWORKDIR /build\nENV TEST_TMPDIR /build\nENTRYPOINT [\"/build_jaxlib_docker_entrypoint.sh\"]\n"
  },
  {
    "path": "docker/coreweave/README.md",
    "content": "# Run Alpa in k8s cloud with InfiniBand (CoreWeave)\nTo run Alpa in specialized GPU cloud like [CoreWeave](https://coreweave.com/), we will need a few pieces in addition to [default run Alpa in Docker](../README.md):\n\n1. InfiniBand dependencies in Alpa docker image\n2. K8s deployment YAML file to declare Ray cluster resources\n3. Run NCCL with InfiniBand related environment variables such as `NCCL_IB_HCA`\n\nWe will go through each step to show you how to deploy Ray cluster in k8s cloud and run Alpa with InfiniBand.\n\nNote most of the content is re-usable for generic k8s and InfiniBand deployment where CoreWeave is the concrete cloud provider we used as verification.\n\n## Build Alpa docker image\n\nFirst, build a docker image based on the provided dockerfile:\n```bash\ndocker build -t run-alpa-image -f run_alpa_infiniband.Dockerfile .\n```\n\nThis docker file added InfiniBand dependencies in addition to the [default run_alpa.Dockerfile](../run_alpa.Dockerfile).\n\n## Tag and push your docker image\nThen tag and push your Alpa docker image to a public repository in docker.com.\n```bash\ndocker tag {image_hash} {your_docker}/{image}:{version}\n```\n```bash\ndocker push {your_docker}/{image}:{version}\n```\n\n## Write cluster.yaml file\nThen write your deployment script to use the Alpa docker image you just built in a k8s cloud.\nThe k8s deployment process can be summarized as the following steps in a nutshell:\n\n1. Define service/headnode/worker roles in the k8s deployment for the Ray cluster.\n2. Make physical resource requirements to the k8s cloud regarding GPU/CPU/RAM/InfiniBand/number of replicas.\n3. Pull the Alpa docker image you built with Ray.\n4. 
For each container, activate Alpa conda environment and run `ray start` to establish Ray runtime across the cluster.\n\n[Example end to end working YAML file](cluster.yaml)\n\nChange the `TODO` in sample YAML file to match your desired namespace, docker image and resource requirements.\n\n## Deploy to k8s\n\nThen we can use simple idempotent commands to start and terminate your Ray cluster to run Alpa.\n```bash\nkubectl apply -f cluster.yaml\n```\n\n```bash\nkubectl delete -f cluster.yaml\n```\n\n## Example end-to-end workflow\n\nOnce your cluster is started, you should be able to monitor all pods like this:\n```\n❯ k get pods\nNAME                                    READY   STATUS    RESTARTS   AGE\ndeployment-ray-head-d9dc9cf7f-pkqvz     1/1     Running   0          2m25s\ndeployment-ray-worker-d66d65c7b-25659   1/1     Running   0          2m24s\ndeployment-ray-worker-d66d65c7b-6sbpz   1/1     Running   0          2m24s\ndeployment-ray-worker-d66d65c7b-8smzr   1/1     Running   0          2m24s\n```\n\nYou can ssh into the headnode for interactive development and job submission.\n```bash\nkubectl exec --stdin --tty deployment-ray-head-d9dc9cf7f-pkqvz -- /bin/bash -i -l\n```\n\nThen activate alpa conda environment:\n```bash\nconda activate alpa\n\n```\n\nAnd verify your Ray cluster is running as expected.\n```\n(alpa) ray@deployment-ray-head-d9dc9cf7f-pkqvz:~$ ray status\n======== Autoscaler status: 2022-12-29 10:05:41.200229 ========\nNode status\n---------------------------------------------------------------\nHealthy:\n 1 node_a4328576d9fee799a5e6853acba0a6c1e1d8cb7fbabed6a6bab3649a\n 1 node_475ed937e3506d7f47ac1abc508e0eb7cde2a270d86a23fad3b9d0b2\n 1 node_347bc30b1fe0cc5f5730a6f803018fe2f3b6597226be69580995b436\n 1 node_8725d199fd3ef007abb673be6307a233a6f90f1001d8cd29aa873789\nPending:\n (no pending nodes)\nRecent failures:\n (no failures)\n\nResources\n---------------------------------------------------------------\nUsage:\n 0.0/128.0 CPU\n 0.0/32.0 GPU\n 
0.0/4.0 accelerator_type:A100\n 0.00/197.961 GiB memory\n 0.00/86.199 GiB object_store_memory\n ```\n\n ## Environment variables for NCCL\n\n In order to enable InfiniBand for NCCL communication, you will need a few additional env vars, such as `NCCL_IB_HCA=ibp`. You can see the full list of configurations in [NCCL user guide](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html)\n\n ## Run Alpa's NCCL test\n\nAlpa uses cupy / ray collective / xla to orchestrate NCCL communcation.\nYou should be able to run the NCCL test [profile_communication](https://github.com/alpa-projects/alpa/blob/5660516ad3a29e5760673e599fc84aa604589a82/benchmark/cupy/profile_communication.py) in\n\n```bash\npython profile_communication.py --ib\n```\n\nOptionally add `--debug` to show NCCL logs to ensure InfiniBand is indeed used instead of Ethernet, as their AllReduce performance difference is expected to be very significant.\n\nSample output from a 4 node 8x80GB A100s NVLink cluster:\n\n```\nAllReduce: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]]\tBytes: 2.00000 GB\tTime: 0.04278 s\tBandwidth: 90.59 GB/s\nAllReduce: [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]\tBytes: 2.00000 GB\tTime: 0.03842 s\tBandwidth: 97.59 GB/s\nAllReduce: [[0, 3]]\tBytes: 2.00000 GB\tTime: 0.01006 s\tBandwidth: 198.82 GB/s\nAllReduce: [[0, 4], [1, 5], [2, 6], [3, 7]]\tBytes: 2.00000 GB\tTime: 0.00994 s\tBandwidth: 201.30 GB/s\nAllReduce: [[0, 2, 4, 6], [1, 3, 5, 7]]\tBytes: 2.00000 GB\tTime: 0.01404 s\tBandwidth: 213.71 GB/s\nAllReduce: [[0, 1, 2, 3], [4, 5, 6, 7]]\tBytes: 2.00000 GB\tTime: 0.01406 s\tBandwidth: 213.31 GB/s\nAllReduce: [[0, 1, 2, 3, 4, 5, 6, 7]]\tBytes: 2.00000 GB\tTime: 0.01623 s\tBandwidth: 215.60 GB/s\n\nSendRecv: [[0, 1]]\tBytes: 2.00000 GB\tTime: 0.00814 s\tBandwidth: 245.59 GB/s\nSendRecv: [[0, 31]]\tBytes: 2.00000 GB\tTime: 0.15949 s\tBandwidth: 12.54 GB/s\nSendRecv: [[0, 1], [2, 3]]\tBytes: 
2.00000 GB\tTime: 0.00815 s\tBandwidth: 490.84 GB/s\nSendRecv: [[0, 28], [1, 29]]\tBytes: 2.00000 GB\tTime: 0.17521 s\tBandwidth: 22.83 GB/s\nSendRecv: [[0, 30], [1, 31]]\tBytes: 2.00000 GB\tTime: 0.17519 s\tBandwidth: 22.83 GB/s\nSendRecv: [[0, 28], [1, 29], [2, 30], [3, 31]]\tBytes: 2.00000 GB\tTime: 0.17526 s\tBandwidth: 45.65 GB/s\nSendRecv: [[0, 24], [1, 25], [2, 26], [3, 27]]\tBytes: 2.00000 GB\tTime: 0.17486 s\tBandwidth: 45.75 GB/s\nSendRecv: [[0, 24], [1, 25], [2, 26], [3, 27], [4, 28], [5, 29], [6, 30], [7, 31]]\tBytes: 2.00000 GB\tTime: 0.17491 s\tBandwidth: 91.48 GB/s\n```"
  },
  {
    "path": "docker/coreweave/cluster.yaml",
    "content": "apiVersion: v1\nkind: Service\nmetadata:\n  namespace: tenant-jiaohpc-jd  # TODO: Change to your namespace\n  name: service-ray-cluster\n  labels:\n    app: ray-cluster\nspec:\n  ports:\n  - name: dashboard\n    protocol: TCP\n    port: 8265\n    targetPort: 8265\n  - name: gcs-server\n    protocol: TCP\n    port: 6380\n    targetPort: 6380\n  selector:\n    app: ray-cluster\n    component: ray-head\n---\napiVersion: apps/v1\nkind: Deployment\nmetadata:\n  namespace: tenant-jiaohpc-jd  # TODO: Change to your namespace\n  name: deployment-ray-head\n  labels:\n    app: ray-cluster\n    ray-node: head\nspec:\n  # Do not change this - Ray currently only supports one head node per cluster.\n  replicas: 1\n  selector:\n    matchLabels:\n      component: ray-head\n      type: ray\n      app: ray-cluster\n  template:\n    metadata:\n      labels:\n        component: ray-head\n        type: ray\n        app: ray-cluster\n    spec:\n      # If the head node goes down, the entire cluster (including all worker\n      # nodes) will go down as well. If you want Kubernetes to bring up a new\n      # head node in this case, set this to \"Always,\" else set it to \"Never.\"\n      restartPolicy: Always\n\n      # This volume allocates shared memory for Ray to use for its plasma\n      # object store. 
If you do not provide this, Ray will fall back to\n      # /tmp, which causes slowdowns if it is not a shared memory volume.\n      volumes:\n      - name: dshm\n        emptyDir:\n          medium: Memory\n      containers:\n        - name: ray-head\n          image: jiaodong/alpa:v1  # TODO: Change to your Alpa docker image\n          imagePullPolicy: IfNotPresent\n          # --login is required to have access to conda to activate alpa env\n          command: [\"/bin/bash\", \"-l\", \"-c\", \"--\"]\n          args:\n            - \"conda activate alpa && ray start --head --port=6380 --num-cpus=$MY_CPU_REQUEST --dashboard-host=0.0.0.0 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 --min-worker-port=10002 --max-worker-port=19999 --redis-password='' --block\"\n          # This volume allocates shared memory for Ray to use for its plasma\n          # object store. If you do not provide this, Ray will fall back to\n          # /tmp, which causes slowdowns if it is not a shared memory volume.\n          volumeMounts:\n            - mountPath: /dev/shm\n              name: dshm\n          env:\n            # This is used in the ray start command so that Ray can spawn the\n            # correct number of processes. 
Omitting this may lead to degraded\n            # performance.\n            - name: MY_CPU_REQUEST\n              valueFrom:\n                resourceFieldRef:\n                  resource: requests.cpu\n          resources:\n            limits:\n              cpu: 32\n              memory: 64Gi\n              nvidia.com/gpu: 8\n              rdma/ib: 1\n      # Refer to CoreWeave's documentation for more details about GPU node types and placement\n      # https://docs.coreweave.com/coreweave-kubernetes/node-types\n      affinity:\n        nodeAffinity:\n          requiredDuringSchedulingIgnoredDuringExecution:\n            nodeSelectorTerms:\n            - matchExpressions:\n              - key: gpu.nvidia.com/class\n                operator: In\n                values:\n                  - A100_NVLINK_80GB\n---\napiVersion: apps/v1\nkind: Deployment\nmetadata:\n  namespace: tenant-jiaohpc-jd  # TODO: Change to your namespace\n  name: deployment-ray-worker\n  labels:\n    app: ray-cluster\nspec:\n  # Change this to scale the number of worker nodes started in the Ray cluster.\n  replicas: 3\n  selector:\n    matchLabels:\n      component: ray-worker\n      type: ray\n      app: ray-cluster\n  template:\n    metadata:\n      labels:\n        component: ray-worker\n        type: ray\n        app: ray-cluster\n    spec:\n      restartPolicy: Always\n      volumes:\n      - name: dshm\n        emptyDir:\n          medium: Memory\n      containers:\n      - name: ray-worker\n        image: jiaodong/alpa:v1  # TODO: Change to your Alpa docker image\n        imagePullPolicy: IfNotPresent\n        # --login is required to have access to conda to activate alpa env\n        command: [\"/bin/bash\", \"-l\", \"-c\", \"--\"]\n        args:\n          - \"conda activate alpa && ray start --num-cpus=$MY_CPU_REQUEST --address=service-ray-cluster:6380 --object-manager-port=8076 --node-manager-port=8077 --dashboard-agent-grpc-port=8078 --dashboard-agent-listen-port=8079 
--min-worker-port=10002 --max-worker-port=19999 --block\"\n        # This volume allocates shared memory for Ray to use for its plasma\n        # object store. If you do not provide this, Ray will fall back to\n        # /tmp, which causes slowdowns if it is not a shared memory volume.\n        volumeMounts:\n          - mountPath: /dev/shm\n            name: dshm\n        env:\n          # This is used in the ray start command so that Ray can spawn the\n          # correct number of processes. Omitting this may lead to degraded\n          # performance.\n          - name: MY_CPU_REQUEST\n            valueFrom:\n              resourceFieldRef:\n                resource: requests.cpu\n        resources:\n          limits:\n            cpu: 32\n            memory: 64Gi\n            nvidia.com/gpu: 8\n            rdma/ib: 1\n      # Refer to CoreWeave's documentation for more details about GPU node types and placement\n      # https://docs.coreweave.com/coreweave-kubernetes/node-types\n      affinity:\n        nodeAffinity:\n          requiredDuringSchedulingIgnoredDuringExecution:\n            nodeSelectorTerms:\n            - matchExpressions:\n              - key: gpu.nvidia.com/class\n                operator: In\n                values:\n                  - A100_NVLINK_80GB"
  },
  {
    "path": "docker/coreweave/run_alpa_infiniband.Dockerfile",
    "content": "# base docker image\nFROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04\n\n# init workdir\nRUN mkdir -p /build\nWORKDIR /build\n\n# InfiniBand (IB) dependencies adopted from CoreWeave's GitHub\n# https://github.com/coreweave/nccl-tests\nARG DEBIAN_FRONTEND=noninteractive\nRUN apt-get -qq update && \\\n    apt-get -qq install -y --allow-change-held-packages --no-install-recommends \\\n    build-essential libtool autoconf automake autotools-dev unzip \\\n    ca-certificates \\\n    wget curl openssh-server vim environment-modules \\\n    iputils-ping net-tools \\\n    libnuma1 libsubunit0 libpci-dev \\\n    libpmix-dev \\\n    datacenter-gpu-manager\n\n# Mellanox OFED (latest)\nRUN wget -qO - https://www.mellanox.com/downloads/ofed/RPM-GPG-KEY-Mellanox | apt-key add -\nRUN cd /etc/apt/sources.list.d/ && wget https://linux.mellanox.com/public/repo/mlnx_ofed/latest/ubuntu18.04/mellanox_mlnx_ofed.list\n\nRUN apt-get -qq update \\\n    && apt-get -qq install -y --no-install-recommends \\\n    ibverbs-utils libibverbs-dev libibumad3 libibumad-dev librdmacm-dev rdmacm-utils infiniband-diags ibverbs-utils \\\n    && rm -rf /var/lib/apt/lists/*\n\n# HPC-X (2.12)\nENV HPCX_VERSION=2.12\nRUN cd /tmp && \\\n    wget -q -O - http://blobstore.s3.ord1.coreweave.com/drivers/hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64.tbz | tar xjf - && \\\n    mv hpcx-v${HPCX_VERSION}-gcc-MLNX_OFED_LINUX-5-ubuntu20.04-cuda11-gdrcopy2-nccl${HPCX_VERSION}-x86_64 /opt/hpcx\n\n# GDRCopy userspace components (2.3)\nRUN cd /tmp && \\\n    wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/gdrcopy-tests_2.3-1_amd64.cuda11_4.Ubuntu20_04.deb && \\\n    wget -q https://developer.download.nvidia.com/compute/redist/gdrcopy/CUDA%2011.4/x86/Ubuntu20.04/libgdrapi_2.3-1_amd64.Ubuntu20_04.deb && \\\n    dpkg -i *.deb && \\\n    rm *.deb\n\n# Begin auto-generated paths\nENV HPCX_DIR=/opt/hpcx\nENV 
HPCX_UCX_DIR=/opt/hpcx/ucx\nENV HPCX_UCC_DIR=/opt/hpcx/ucc\nENV HPCX_SHARP_DIR=/opt/hpcx/sharp\nENV HPCX_NCCL_RDMA_SHARP_PLUGIN_DIR=/opt/hpcx/nccl_rdma_sharp_plugin\nENV HPCX_HCOLL_DIR=/opt/hpcx/hcoll\nENV HPCX_MPI_DIR=/opt/hpcx/ompi\nENV HPCX_OSHMEM_DIR=/opt/hpcx/ompi\nENV HPCX_MPI_TESTS_DIR=/opt/hpcx/ompi/tests\nENV HPCX_OSU_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8\nENV HPCX_OSU_CUDA_DIR=/opt/hpcx/ompi/tests/osu-micro-benchmarks-5.8-cuda\nENV HPCX_IPM_DIR=/opt/hpcx/ompi/tests/ipm-2.0.6\nENV HPCX_CLUSTERKIT_DIR=/opt/hpcx/clusterkit\nENV OMPI_HOME=/opt/hpcx/ompi\nENV MPI_HOME=/opt/hpcx/ompi\nENV OSHMEM_HOME=/opt/hpcx/ompi\nENV OPAL_PREFIX=/opt/hpcx/ompi\nENV PATH=/opt/hpcx/clusterkit/bin:/opt/hpcx/hcoll/bin:/opt/hpcx/ucc/bin:/opt/hpcx/ucx/bin:/opt/hpcx/ompi/bin:$PATH\nENV LD_LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ucc/lib/ucc:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib/ucx:/opt/hpcx/ucx/lib:/opt/hpcx/sharp/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:$LD_LIBRARY_PATH\nENV LIBRARY_PATH=/opt/hpcx/nccl_rdma_sharp_plugin/lib:/opt/hpcx/ompi/lib:/opt/hpcx/sharp/lib:/opt/hpcx/ucc/lib:/opt/hpcx/ucx/lib:/opt/hpcx/hcoll/lib:/opt/hpcx/ompi/lib:/usr/local/cuda/lib64/stubs\nENV OLD_CPATH=\nENV CPATH=/opt/hpcx/ompi/include:/opt/hpcx/ucc/include:/opt/hpcx/ucx/include:/opt/hpcx/sharp/include:/opt/hpcx/hcoll/include:\nENV PKG_CONFIG_PATH=/opt/hpcx/hcoll/lib/pkgconfig:/opt/hpcx/sharp/lib/pkgconfig:/opt/hpcx/ucx/lib/pkgconfig:/opt/hpcx/ompi/lib/pkgconfig:\n# End of auto-generated paths\n\n# install common tool & conda\nRUN apt update && \\\n    apt install wget -y && \\\n    apt install git -y && \\\n    apt install vim -y && \\\n    wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \\\n    /bin/bash ~/anaconda.sh -b -p /opt/conda && \\\n    rm ~/anaconda.sh && \\\n    mkdir -p /opt/conda/envs/alpa && \\\n    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \\\n    echo \". 
/opt/conda/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate base\" >> ~/.bashrc\n\n# install conda alpa env\nRUN . /opt/conda/etc/profile.d/conda.sh && \\\n    conda create --name alpa python=3.8 -y && \\\n    conda activate alpa && \\\n    apt install coinor-cbc -y && \\\n    pip3 install --upgrade pip && \\\n    pip3 install cupy-cuda113 && \\\n    pip3 install alpa && \\\n    pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html\n\n# Execute in Alpa conda env\nENV PATH /opt/conda/envs/alpa/bin:$PATH"
  },
  {
    "path": "docker/run_alpa.Dockerfile",
    "content": "# base docker image\nFROM nvidia/cuda:11.3.0-cudnn8-devel-ubuntu20.04\n\n# init workdir\nRUN mkdir -p /build\nWORKDIR /build\n\n# install common tool & conda\nRUN apt update && \\\n    apt install wget -y && \\\n    apt install git -y && \\\n    apt install vim -y && \\\n    wget --quiet https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh -O ~/anaconda.sh && \\\n    /bin/bash ~/anaconda.sh -b -p /opt/conda && \\\n    rm ~/anaconda.sh && \\\n    mkdir -p /opt/conda/envs/alpa && \\\n    ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \\\n    echo \". /opt/conda/etc/profile.d/conda.sh\" >> ~/.bashrc && \\\n    echo \"conda activate base\" >> ~/.bashrc\n\n# install conda alpa env\nRUN . /opt/conda/etc/profile.d/conda.sh && \\\n    conda create --name alpa python=3.8 -y && \\\n    conda activate alpa && \\\n    apt install coinor-cbc -y && \\\n    pip3 install --upgrade pip && \\\n    pip3 install cupy-cuda113 && \\\n    pip3 install alpa && \\\n    pip3 install jaxlib==0.3.22+cuda113.cudnn820 -f https://alpa-projects.github.io/wheels.html\n"
  },
  {
    "path": "docker/scripts/build_alpa.sh",
    "content": "#!/bin/bash\nset -xev\nif [ ! -d \"/dist\" ]\nthen\n  echo \"/dist must be mounted to produce output\"\n  exit 1\nfi\n\nusage() {\n  echo \"usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]\"\n  exit 1\n}\n\nif [[ $# -lt 2 ]]\nthen\n  usage\nfi\n\nexport PY_VERSION=$1\n\nif [ $PY_VERSION = \"3.7\" ]; then\n  #alias python=\"/opt/python/cp37-cp37m/bin/python\"\n  ln -fs /opt/python/cp37-cp37m/bin/python /usr/bin/python3\n  python3 -m ensurepip --upgrade\n  python3 -m pip install cmake auditwheel pybind11\n  ln -fs /opt/python/cp37-cp37m/bin/pybind11-config /usr/bin/pybind11-config\nelif [ $PY_VERSION = \"3.8\" ]; then\n  #alias python=\"/opt/python/cp38-cp38/bin/python\"\n  ln -fs /opt/python/cp38-cp38/bin/python /usr/bin/python3\n  python3 -m ensurepip --upgrade\n  python3 -m pip install cmake auditwheel pybind11\n  ln -fs /opt/python/cp38-cp38/bin//pybind11-config /usr/bin/pybind11-config\nelif [ $PY_VERSION = \"3.9\" ]; then\n  #alias python=\"/opt/python/cp39-cp39/bin/python\"\n  ln -fs /opt/python/cp39-cp39/bin/python /usr/bin/python3\n  python3 -m ensurepip --upgrade\n  python3 -m pip install cmake auditwheel pybind11\n  ln -fs /opt/python/cp39-cp39/bin/pybind11-config /usr/bin/pybind11-config\nelse\n  echo \"Unsupported Python version: $PY_VERSION\"\n  exit 1\nfi\n\nALPA_BRANCH=\"$2\"\n\n# switch to the merge commit\ngit clone https://github.com/alpa-projects/alpa.git\ncd alpa\ngit fetch origin +${ALPA_BRANCH}\ngit checkout -qf FETCH_HEAD\n\n# build the alpa wheel and source distribution\npython3 update_version.py --git-describe\npython3 setup.py bdist_wheel sdist\n\n#if ! python3 -m auditwheel show dist/alpa-*.whl  | egrep 'platform tag: \"(manylinux2014_x86_64|manylinux_2_17_x86_64)\"' > /dev/null; then\n#  # Print output for debugging\n#  python3 -m auditwheel show dist/alpa-*.whl\n#  echo \"alpa wheel is not manylinux2014 compliant\"\n#  exit 1\n#fi\n\n#rename 'linux' manylinux2014 dist/*.whl\ncp -r dist/*whl /dist/\n"
  },
  {
    "path": "docker/scripts/build_doc.sh",
    "content": "#!/bin/bash\n\nset -xev\n\nif [ ! -d \"/alpa-dist\" ]\nthen\n  echo \"/alpa-dist must be mounted to produce output\"\n  exit 1\nfi\n\nsource /python3.8-env/bin/activate\npip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.5+cuda111.cudnn805-cp38-none-manylinux2010_x86_64.whl\npip install jax==0.3.5\n\ngit clone https://github.com/alpa-projects/alpa.git\ncd alpa\npip install cupy-cuda111\npython -m cupyx.tools.install_library --library nccl --cuda 11.1\npip install -e .[doc]\ncd /alpa/docs\nmake html\ncp -r _build/html/* /alpa-dist/docs/\n"
  },
  {
    "path": "docker/scripts/build_jaxlib_docker_entrypoint.sh",
    "content": "#!/bin/bash\n# Adapted from https://github.com/alpa-projects/jax-alpa/blob/main/build/build_wheel_docker_entrypoint.sh\nset -xev\nif [ ! -d \"/dist\" ]\nthen\n  echo \"/dist must be mounted to produce output\"\n  exit 1\nfi\n\nexport CC=/dt8/usr/bin/gcc\nexport GCC_HOST_COMPILER_PATH=/dt8/usr/bin/gcc\nexport CUDA_PATH=/usr/local/cuda\nexport LD_LIBRARY_PATH=$CUDA_PATH/lib64:$LD_LIBRARY_PATH\n\nusage() {\n  echo \"usage: ${0##*/} [3.7|3.8|3.9] [cuda|nocuda] [11.1|11.2|11.3] [alpa branch name] [tensorflow-alpa branch name]\"\n  exit 1\n}\n\nif [[ $# -lt 3 ]]\nthen\n  usage\nfi\n\nPY_VERSION=\"$1\"\necho \"Python version $PY_VERSION\"\n\n# clone alpa together with its submodules\ngit clone --recursive https://github.com/alpa-projects/alpa.git\n\n# switch alpa branch\nif [[ $# -eq 4 ]]\nthen\n  ALPA_BRANCH=\"$4\"\n  echo \"Switch to alpa branch $ALPA_BRANCH\"\n  cd /build/alpa\n  git fetch origin +${ALPA_BRANCH}\n  git checkout -qf FETCH_HEAD\n  git submodule update --recursive\nfi\n\n# switch tensorflow-alpa branch, this will overwrite the above\nif [[ $# -eq 5 ]]\nthen\n  TF_BRANCH=\"$5\"\n  echo \"Switch to tensorflow-alpa branch $TF_BRANCH\"\n  cd /build/alpa/third_party/tensorflow-alpa\n  git fetch origin +${TF_BRANCH}\n  git checkout -qf FETCH_HEAD\nfi\n\nmkdir /build/tmp\nmkdir /build/root\nexport TMPDIR=/build/tmp\n\n# Builds and activates a specific Python version.\nsource /python${PY_VERSION}-env/bin/activate\n\n# Workaround for https://github.com/bazelbuild/bazel/issues/9254\nexport BAZEL_LINKLIBS=\"-lstdc++\"\nexport JAX_CUDA_VERSION=$3\nexport CUPY_VERSION=${JAX_CUDA_VERSION//.}\n\nif [ $JAX_CUDA_VERSION = \"11.0\" ]; then\n  export JAX_CUDNN_VERSION=\"805\"\nelif [ $JAX_CUDA_VERSION = \"11.1\" ]; then\n  export JAX_CUDNN_VERSION=\"805\"\nelif [ $JAX_CUDA_VERSION = \"11.2\" ]; then\n  export JAX_CUDNN_VERSION=\"810\"\nelif [ $JAX_CUDA_VERSION = \"11.3\" ]; then\n  export JAX_CUDNN_VERSION=\"820\"\nelif [ $JAX_CUDA_VERSION = \"11.4\" ]; 
then\n  export JAX_CUDNN_VERSION=\"822\"\nelse\n  echo \"Unknown CUDNN version for CUDA version: $JAX_CUDA_VERSION\"\n  exit 1\nfi\n\n\n# install cupy\npip install cupy-cuda${JAX_CUDA_VERSION//.}\npython -m cupyx.tools.install_library --library nccl --cuda $JAX_CUDA_VERSION\n\n# start building\ncd /build/alpa/build_jaxlib\ncase $2 in\n  cuda)\n    python build/build.py --enable_cuda --bazel_startup_options=\"--output_user_root=/build/root\" --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa \n    ;;\n  nocuda)\n    python build/build.py --enable_tpu --bazel_startup_options=\"--output_user_root=/build/root\" --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa\n    ;;\n  *)\n    usage\nesac\n\nif ! python -m auditwheel show dist/jaxlib-*.whl | egrep 'platform tag: \"(manylinux2014_x86_64|manylinux_2_17_x86_64)\"' > /dev/null; then\n  # Print output for debugging\n  python -m auditwheel show dist/jaxlib-*.whl\n  echo \"jaxlib wheel is not manylinux2014 compliant\"\n  exit 1\nfi\ncp -r dist/* /dist\n"
  },
  {
    "path": "docker/scripts/install_cuda.sh",
    "content": "#!/bin/bash\nset -xe\n\nCUDA_VERSION=$1\n\nLIBCUDNN=libcudnn7\nif [ $CUDA_VERSION = \"10.0\" ]; then\n  CUBLAS=libcublas10\n  CUBLAS_DEV=libcublas-dev\nelif [ $CUDA_VERSION = \"10.1\" ]; then\n  # Have to pin to libcublas10=10.2.1.243-1 due to bug in TF, see\n  # https://github.com/tensorflow/tensorflow/issues/9489#issuecomment-562394257\n  CUBLAS=libcublas10=10.2.1.243-1\n  CUBLAS_DEV=libcublas-dev=10.2.1.243-1\nelif [ $CUDA_VERSION = \"10.2\" ]; then\n  CUBLAS=libcublas10\n  CUBLAS_DEV=libcublas-dev\n  CUDNN_VERSION=7.6.5.32\nelif [ $CUDA_VERSION = \"11.0\" ]; then\n  CUBLAS=libcublas-11-0\n  CUBLAS_DEV=libcublas-dev-11-0\n  CUDNN_VERSION=8.0.5.39\n  LIBCUDNN=libcudnn8\nelif [ $CUDA_VERSION = \"11.1\" ]; then\n  CUBLAS=libcublas-11-1\n  CUBLAS_DEV=libcublas-dev-11-1\n  CUDNN_VERSION=8.0.5.39\n  LIBCUDNN=libcudnn8\nelif [ $CUDA_VERSION = \"11.2\" ]; then\n  CUBLAS=libcublas-11-2\n  CUBLAS_DEV=libcublas-dev-11-2\n  CUDNN_VERSION=8.1.0.77\n  LIBCUDNN=libcudnn8\nelif [ $CUDA_VERSION = \"11.3\" ]; then\n  CUBLAS=libcublas-11-3\n  CUBLAS_DEV=libcublas-dev-11-3\n  CUDNN_VERSION=8.2.0.53\n  LIBCUDNN=libcudnn8\nelif [ $CUDA_VERSION = \"11.4\" ]; then\n  CUBLAS=libcublas-11-4\n  CUBLAS_DEV=libcublas-dev-11-4\n  CUDNN_VERSION=8.2.2.26\n  LIBCUDNN=libcudnn8\nelse\n  echo \"Unsupported CUDA version: $CUDA_VERSION\"\n  exit 1\nfi\n\necho \"Installing cuda version: $CUDA_VERSION\"\necho \"cudnn version: $CUDNN_VERSION\"\n\napt-key adv --keyserver keyserver.ubuntu.com --recv-keys A4B469963BF863CC\napt-get update\napt-get remove -y --allow-change-held-packages -f cuda-license-10-0 libnccl-dev libcudnn7 libcudnn8 libnccl2\napt-get install -y --no-install-recommends --allow-downgrades \\\n  $CUBLAS \\\n  $CUBLAS_DEV \\\n  cuda-nvml-dev-$CUDA_VERSION \\\n  cuda-command-line-tools-$CUDA_VERSION \\\n  cuda-libraries-dev-$CUDA_VERSION \\\n  cuda-minimal-build-$CUDA_VERSION \\\n  $LIBCUDNN=$CUDNN_VERSION-1+cuda$CUDA_VERSION \\\n  
$LIBCUDNN-dev=$CUDNN_VERSION-1+cuda$CUDA_VERSION\nrm -f /usr/local/cuda\nln -s /usr/local/cuda-$CUDA_VERSION /usr/local/cuda\n"
  },
  {
    "path": "docker/scripts/install_torch.sh",
    "content": "#!/bin/bash\nset -xe\n\ninstall_torch_deps() {\n    # NOTE: functorch is pinned to the last commit that works with PyTorch 1.12\n    pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx && \\\n    ([ -d \"functorch\" ] || git clone https://github.com/pytorch/functorch) && \\\n    pushd functorch && git checkout 76976db8412b60d322c680a5822116ba6f2f762a && python setup.py install && popd\n}\n\ninstall_torch_deps\n"
  },
  {
    "path": "docker/scripts/test_alpa_docker_entrypoint.sh",
    "content": "#!/bin/bash\nset -xev\nif [ ! -d \"/alpa-dist\" ]\nthen\n  echo \"/alpa-dist must be mounted to produce output\"\n  exit 1\nfi\n\nusage() {\n  echo \"usage: ${0##*/} [3.7|3.8|3.9] [alpa-branch]\"\n  exit 1\n}\n\nif [[ $# -lt 2 ]]\nthen\n  usage\nfi\n\nexport PY_VERSION=$1\nALPA_BRANCH=\"$2\"\n\n# Enter python env\nsource /python${PY_VERSION}-env/bin/activate\n# switch to the merge commit\ngit clone https://github.com/alpa-projects/alpa.git\ncd /build/alpa\ngit fetch origin +${ALPA_BRANCH}\ngit checkout -qf FETCH_HEAD\n\n# install jaxlib and jax\npip install /alpa-dist/jaxlib-alpa-ci/jaxlib-0.3.22+cuda111.cudnn805-cp38-cp38-manylinux2014_x86_64.whl\npip install jax==0.3.22\n\n# install cupy\npip install cupy-cuda111\npython -m cupyx.tools.install_library --library nccl --cuda 11.1\npip install -e .[dev]\nray start --head\ncd tests\npython run_all.py\n"
  },
  {
    "path": "docker/unittest.Dockerfile",
    "content": "FROM gcr.io/tensorflow-testing/nosla-cuda11.1-cudnn8-ubuntu18.04-manylinux2010-multipython\n\nWORKDIR /\nSHELL [\"/bin/bash\", \"-c\"]\nRUN rm -f /etc/apt/sources.list.d/jonathonf-ubuntu-python-3_6-xenial.list\n# Fetch latest pub key so apt-get works.\nRUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub\nRUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub\nRUN apt-get update\nRUN apt-get install -y python3-virtualenv\nRUN virtualenv --python=python3.7 python3.7-env\nRUN virtualenv --python=python3.8 python3.8-env\nRUN virtualenv --python=python3.9 python3.9-env\n\n# We pin numpy to the minimum permitted version to avoid compatibility issues.\nRUN source python3.7-env/bin/activate && pip install --upgrade pip \\\n  && pip install numpy==1.20 setuptools wheel six auditwheel \\\n  tqdm scipy numba pulp tensorstore prospector yapf coverage cmake \\\n  pybind11 ray[default] matplotlib transformers uvicorn fastapi\nRUN source python3.8-env/bin/activate && pip install --upgrade pip \\\n  && pip install numpy==1.20 setuptools wheel six auditwheel \\\n  tqdm scipy numba pulp tensorstore prospector yapf coverage cmake  \\\n  pybind11 ray[default] matplotlib transformers uvicorn fastapi\nRUN source python3.9-env/bin/activate && pip install --upgrade pip \\\n  && pip install numpy==1.20 setuptools wheel six auditwheel \\\n  tqdm scipy numba pulp tensorstore prospector yapf coverage cmake  \\\n  pybind11 ray[default] matplotlib transformers uvicorn fastapi\n\n# Install PyTorch dependencies\nWORKDIR /\nCOPY scripts/install_torch.sh /install_torch.sh\nRUN chmod +x /install_torch.sh\nRUN source python3.7-env/bin/activate && /install_torch.sh\nRUN source python3.8-env/bin/activate && /install_torch.sh\nRUN source python3.9-env/bin/activate && /install_torch.sh\n\n# We determine the CUDA version at `docker build ...` phase\nARG 
JAX_CUDA_VERSION=11.1\nCOPY scripts/install_cuda.sh /install_cuda.sh\nRUN chmod +x /install_cuda.sh\nRUN /bin/bash -c 'if [[ ! \"$CUDA_VERSION\" =~ ^$JAX_CUDA_VERSION.*$ ]]; then \\\n  /install_cuda.sh $JAX_CUDA_VERSION; \\\n  fi'\n\n# Install cupy\nRUN source python3.7-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.}\nRUN source python3.8-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.}\nRUN source python3.9-env/bin/activate && pip install cupy-cuda${JAX_CUDA_VERSION//.}\n\nWORKDIR /\nCOPY scripts/test_alpa_docker_entrypoint.sh /test_alpa_docker_entrypoint.sh\nRUN chmod +x /test_alpa_docker_entrypoint.sh\n\nWORKDIR /build\nENV TEST_TMPDIR /build\nENTRYPOINT [\"/test_alpa_docker_entrypoint.sh\"]\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line, and also\n# from the environment for the first two.\nSPHINXOPTS    ?=\nSPHINXBUILD   ?= sphinx-build\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\nclean:\n\trm -rf $(BUILDDIR)/*\n\trm -rf tutorials/\n"
  },
  {
    "path": "docs/README.md",
    "content": "# Alpa Documentation\n\n## Build the documentation website\n\n### Dependency\n```\npip3 install sphinx sphinx-rtd-theme sphinx-gallery matplotlib\n```\n\n### Build\n```\nmake html\n```\n\nThe build process will execute all tutorial scripts to generate the gallery.\nThis may cause failures if the build machine does not have the necessary environment.\nThis may also result in a very long build time.\nYou can set `ALPA_TUTORIAL_EXEC_PATTERN` to only execute the files that match the regular expression pattern.\nFor example, to build one specific file, do\n```\nexport ALPA_TUTORIAL_EXEC_PATTERN=filename.py\nmake html\n```\nTo skip execution of all tutorials, do\n```\nexport ALPA_TUTORIAL_EXEC_PATTERN=none\nmake html\n```\n\n### Clean\nTo remove all generated files:\n```\nmake clean\n```\n\n### Serve\nRun an HTTP server and visit http://localhost:8000 in your browser.\n```\npython3 -m http.server --directory _build/html\n```\n\n### Publish\nClone [alpa-projects.github.io](https://github.com/alpa-projects/alpa-projects.github.io) and make sure you have write access.\n\n```bash\nexport ALPA_SITE_PATH=~/efs/alpa-projects.github.io   # update this with your path\n./publish.py\n```\n\n## Add new documentation\nAlpa uses [Sphinx](https://www.sphinx-doc.org/en/master/index.html) to generate the static documentation website and uses [Sphinx-gallery](https://sphinx-gallery.github.io/stable/index.html) to generate gallery examples.\n\nYour new example should be created under `docs/gallery`.\n\n### Define the Order of Tutorials\nYou can define the order of tutorials with `subsection_order` and\n`within_subsection_order` in [`conf.py`](conf.py).\nBy default, the tutorials within one subsection are sorted by filename.\n"
  },
  {
    "path": "docs/architecture/alpa_compiler_walk_through.rst",
    "content": ".. _Alpa Compiler Walk-Through:\n\n==========================\nAlpa Compiler Walk-Through\n==========================\n\nThis document provides a walk-through of the compiler part of Alpa.\n\n.. note::\n  This document is based on the workflow as in `this commit <https://github.com/alpa-projects/alpa/tree/388594f>`__. While some specific details might not be the same as in the latest version, the general idea should be the same.\n\n\nStarting from an arbitrary JAX function (i.e., computational graph) of a neural network training step, Alpa’s overall workflow includes the following steps:\n\n1. **Layer construction:** Cluster different operators in the\n   computational graph into a sequential list of pipeline layers.\n2. **Stage construction:** Cluster the pipeline layers into pipeline\n   stages and assign each stage a subset of devices for pipeline\n   execution (i.e., inter-operator parallelism).\n3. **Auto sharding:** Figure out how to shard each operator within each pipeline stage on its corresponding devices with SPMD parallelism (i.e., intra-operator parallelism).\n\nLet’s start with the following code snippet:\n\n.. 
code:: python\n\n   class ManualPipelineMLPModel(nn.Module):\n       hidden_dim: int\n\n       @nn.compact\n       def __call__(self, x):\n           x = nn.Dense(features=self.hidden_dim * 4)(x)\n           x = nn.relu(x)\n           x = nn.Dense(features=self.hidden_dim)(x)\n           x = nn.relu(x)\n           # Use this boundary marker to separate the model into two stages.\n           alpa.mark_pipeline_boundary()\n           x = nn.Dense(features=self.hidden_dim * 4)(x)\n           x = nn.relu(x)\n           x = nn.Dense(features=self.hidden_dim)(x)\n           x = nn.relu(x)\n           return x\n\n   @alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,\n                                                   layer_option=\"manual\"))\n   def manual_pipeline_train_step(state, batch):\n\n       def loss_func(params):\n           out = state.apply_fn(params, batch[\"x\"])\n           loss = jnp.mean((out - batch[\"y\"])**2)\n           return loss\n       # Use `alpa.grad` here to slice the forward/backward stages and the\n       # gradient update stage\n       grads = alpa.grad(loss_func)(state.params)\n       new_state = state.apply_gradients(grads=grads)\n       return new_state\n\nCompared to original JAX/Flax, this code snippet additionally calls ``alpa.mark_pipeline_boundary``, ``alpa.parallelize``, and ``alpa.grad``. Below, we will show how Alpa uses these functions and decorators to compile the original single-device computational graph into a distributed version.\n\nLayer Construction\n==================\n\nThe first transformation we perform is in ``alpa.grad``\n(`link <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/api.py#L213>`__)\nfor layer construction. It is a thin wrapper of the original ``jax.grad`` in JAX,\nwhich additionally performs the following tasks:\n\n1. Process pipeline markers to form forward pipeline layers.\n2. Call the original ``jax.grad``. 
We directly use JAX's autograd to map\n   the forward layers to the backward layers.\n3. Mark all the gradients with a special marker so that we can perform\n   gradient accumulation for them.\n4. Mark all the operators after the gradient computation as the\n   gradient update phase.\n\nWe form the pipeline layers by inserting pipeline markers into the JAX\nautomatically or manually with user annotations.\n``layer_option=\"manual\"`` in the code example above indicates that we\nare inserting the markers manually.\n\nThe definition of pipeline markers can be found in\n`primitive_def.py <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/primitive_def.py>`__.\nWe define a new JAX primitive ``pipeline_p`` and an XLA custom call\n``pipeline_marker``. All these markers behave exactly the same as an\nidentity function that returns all the input\narguments.\n\nWe distinguish between ``start`` and ``end`` markers. The ``start``\nmarker captures all the inputs to a pipeline layer, and the ``end`` marker captures the outputs. To preserve the forward/backward\nstage mapping, we set the gradient of a ``start`` marker to be an ``end``\nmarker, and the gradient of an ``end`` to be a ``start``.\n\nA complete pipeline layer has the following structure:\n\n::\n\n   marked_inputs = pipeline_marker[type=\"start\"] layer_inputs\n   ...\n   layer_outputs = some_jax_operator marked_inputs\n   ...\n   marked_outputs = pipeline_marker[type=\"end\"] layer_outputs\n\nNote that all the inputs of the JAX operators within the pipeline layer\nshould take the marked inputs or the intermediate results within the\nlayer. All the outputs of the layer will be marked by the ``end``\nmarker.\n\nIn the manual case, we provide a simpler API that doesn’t require two\nmarkers for a stage and the users do not need to specify the input and\noutput variables. 
Instead, the users only need to call\n``alpa.mark_pipeline_boundary`` at the boundary of two pipeline layers.\nThe ``layer_level_jaxpr_transformation`` function\n(`link <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/layer_construction.py#L424-L432>`__)\nwill transform it to the above form.\n\n**Note:** Alpa can also perform rematerialization (i.e., gradient checkpointing) at these pipeline stage\nboundaries. See these functions:\n`link <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/layer_construction.py#L475-L547>`__.\n\nStage Construction\n==================\n\nThe transformed function with layer markers is then transformed by\n``@alpa.parallelize``. The most important option of\n``@alpa.parallelize`` is ``method``, which specifies which type of\nparallelism to use. Here we set it to ``alpa.PipeshardParallel``,\nindicating that we are using both pipeline parallelism (inter-operator\nparallelism) and SPMD-shard parallelism (intra-operator parallelism).\n\n``@alpa.parallelize`` transforms the original function to a\n``ParallelizedFunc``. ``ParallelizedFunc`` is a Python class that\nbehaves like the original function but with some additional methods.\n``ParallelizedFunc`` flattens the input arguments, and will compile the\nJAX function according to the ``method``. In our case, it eventually\ncalls ``compile_pipeshard_executable()``\n`here <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/compile_executable.py#L42-L50>`__,\nwhich transforms the input as follows:\n\n1. ``compile_pipeshard_executable`` first traces the original function\n   to JAXPR. Note that we trace the function with both full batch size\n   and the smaller micro-batch size for gradient accumulation. Then we\n   call into ``compile_pipeshard_executable_internal``.\n\n2. 
``split_compute_grad_and_apply_grad`` splits the ``apply_grad`` part\n   from the rest of the function. There is a special transformation for\n   the case where a single parameter ``x`` is used in multiple pipeline\n   layers ``l1(x)``, ``l2(x)``, ... For example, in language models' tied-embedding layer, the embedding matrix is used by both the first\n   and the last stage. In this case, the backward pass of JAX will\n   generate some equations that are not captured by pipeline markers to\n   calculate the gradient to ``x``: ``grad_x = grad_l1_x + grad_l2_x``.\n   We move these kinds of equations to the ``apply_grad`` part and let\n   each layer perform gradient accumulation separately.\n\n3. ``compute_grad_to_accumulate_grad`` transforms the original\n   ``compute_grad`` JAXPR, which only computes gradients, into\n   an ``accumulate_grad`` JAXPR that performs gradient accumulation. More\n   specifically, the structure of ``accumulate_grad`` is shown in the following pseudo-code:\n\n   .. code:: python\n\n      def accumulate_grad(compute_grad_inputs, accumulated_grad):\n          grad = compute_grad(compute_grad_inputs)\n          accumulated_grad += grad\n          return accumulated_grad\n\n   Note that the ``+=`` above is only correct when the gradients can be\n   summed up. When the output is per input data (e.g., inference\n   output), we use ``concat`` instead of ``+=``. The analysis of which\n   operator to use is done in ``_get_full_batch_apply_grad`` by\n   comparing full-batch and micro-batch codes.\n\n4. ``slice_closed_jaxpr_by_full_pipeline_marks`` slices the\n   ``accumulate_grad`` JAXPR into many pipeline layers.\n\n5. ``mark_missing_vars_in_backward_computation_pipeline_marks``. When\n   JAX derives the backward JAXPR, the backward layer will directly use\n   the intermediate results of the forward layer instead of adding it\n   to the backward layer’s start pipeline marker. This function fixes\n   this issue. 
In addition, it removes all ``Literal`` in start markers\n   and all ``DropVar`` in end markers.\n\n6. ``cluster_layers_and_slice_mesh`` performs stage construction. It\n   clusters different pipeline layers into pipeline stages, slices the\n   compute cluster represented as a 2D device mesh into many submeshes,\n   and assigns each stage a submesh. Right now, a forward layer and its\n   corresponding backward layer will always be on the same submesh. See\n   the full automatic algorithm in `the Alpa paper <https://arxiv.org/abs/2201.12023>`__.\n\n7. ``process_apply_gradient`` splits the single ``apply_grad`` JAXPR into\n   #submeshes parts, each part processes the gradient updates and\n   optimizer states related to the variables on a specific submesh.\n\n8. ``create_donation_mapping`` and ``split_donate_invars``: Process\n   donated invars for each pipeline stage, and also add donation variables for gradient accumulation.\n\nAuto Sharding\n=============\n\nThen, in ``shard_each_stage`` we run the auto-sharding pass for each\npipeline stage. Because we include distributed compilation for\ndifferent stages to accelerate the compilation, the code is nested here.\nSpecifically, the following two functions are the two most important ones:\n\n1. In ``generate_sharded_xla_computations_arguments``\n   (`code <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/computation.py#L827>`__),\n   we concatenate the JAXPRs of all stages on a submesh (which typically\n   include forward/backward/update of a single stage) and compile it to\n   an ``HLOModule``.\n2. 
Then we call ``run_auto_sharding_pass``\n   (`code <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/shard_parallel/auto_sharding.py#L183>`__),\n   which eventually calls ``RunAutoShardingPass`` we wrote in XLA\n   (`code <https://github.com/alpa-projects/tensorflow-alpa/blob/445b4588a93c01a155053d6b77f4621b5f704a68/tensorflow/compiler/xla/service/spmd/alpa_compile.cc#L89-L90>`__).\n   This XLA function:\n\n   1. First run a subset of XLA passes before SPMD partitioner.\n   2. Then we run the Alpa ``AutoSharding`` pass\n      (`code <https://github.com/alpa-projects/tensorflow-alpa/blob/445b4588a93c01a155053d6b77f4621b5f704a68/tensorflow/compiler/xla/service/spmd/auto_sharding.cc>`__)\n      that automatically annotates the graph with GSPMD annotations.\n   3. Then run the ``SliceAutoShardedStages`` pass\n      (`code <https://github.com/alpa-projects/tensorflow-alpa/blob/445b4588a93c01a155053d6b77f4621b5f704a68/tensorflow/compiler/xla/service/spmd/slice_auto_sharded_stages.cc>`__)\n      that slices the concatenated stages back to individual stages, and\n      returns these stages back to Python.\n\nThe result of ``shard_each_stage`` will be a list of SPMD sharded\npipeline stages. Then the whole pipeline and sharding execution schedule\nwill be summarized and organized via a ``PipelineInstEmitter``\n(`code <https://github.com/alpa-projects/alpa/blob/388594f00d1ee0fe4dc0d51c2d8567da13226fdf/alpa/pipeline_parallel/compile_executable.py#L221-L233>`__).\nThe result ``pipeshard_config`` will be sent to the runtime to be\nexecuted.\n\n.. note::\n  To debug and visualize each step, you can debug via simply adding print instructions to the JAXPR in Python or the HLO in XLA.\n"
  },
  {
    "path": "docs/architecture/intra_op_solver.rst",
    "content": "=====================================\nCode Structure of the Intra-op Solver\n=====================================\n\nThe specific code of the intra-op solver (a.k.a. auto-sharding) is scattered\nin various files of the project.\nThis page contains some pointers to key components of the intra-op solver and\nhelps you navigate the complicated code base.\n\n.. note::\n\n  All the links below are based on alpa v0.2.2\n\n\nKey Pointers\n============\n\n- Main entrance:\n   - python entrance (``run_auto_sharding_pass``): https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L172\n   - c++ entrance: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2124\n\n- Where the possible sharding strategies are registered:\n   - for matmul: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding_dot_handler.cc#L327-L408\n   - for elementwise operators: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L967-L1016\n\n- Where the ILP solver is called:\n   - c++ side: https://github.com/alpa-projects/tensorflow-alpa/blob/cd865615b9b518bc507fbdc71dc44c7cc76618ac/tensorflow/compiler/xla/service/spmd/auto_sharding.cc#L2259\n   - python side: https://github.com/alpa-projects/alpa/blob/181de4f5577a72c9b30525ed3da09e5b2138cc2c/alpa/shard_parallel/auto_sharding.py#L588\n\n\nHow to Read and Learn the Code\n==============================\n.. 
_learn-intra-op-solver:\n\nRun some simple examples\n~~~~~~~~~~~~~~~~~~~~~~~~\nYou can run the unit tests under https://github.com/alpa-projects/alpa/tree/v0.2.2/tests/shard_parallel and set breakpoints in the python entrance ``run_auto_sharding_pass``.\nYou can start from the most basic ones in ``test_basic.py``.\n\nInspect the sharding strategy\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nYou can print the HLO before and after the ``run_auto_sharding_pass``.\n\n\nHow to Debug\n============\n- Set global environment variable ``ALPA_DEBUG_PRINT_AS_STRATEGY=1``. This will print the chosen sharding strategy for each instruction and edge costs in a prettier way.\n- Check batch dim analysis https://github.com/alpa-projects/tensorflow-alpa/blob/721260d122f096040762b2d226b37e8ab23f74b8/tensorflow/compiler/xla/service/spmd/auto_sharding_util.cc#L857\n"
  },
  {
    "path": "docs/architecture/overview.rst",
    "content": "=======================\nDesign and Architecture\n=======================\n\nThis document aims to describe the architecture of Alpa and explain several core concepts and compilation passes introduced by Alpa at a high level. It provides an overview of Alpa's architecture, including core terms and components introduced by Alpa. In :ref:`Alpa Compiler Walk-Through <Alpa Compiler Walk-Through>`, we further show the workflow of Alpa using an MLP example.\n\n\nYou are recommended to read the following materials as well:\n\n- `Alpa paper <https://arxiv.org/pdf/2201.12023.pdf>`_ (OSDI'22)\n- `Google AI blog <https://ai.googleblog.com/2022/05/alpa-automated-model-parallel-deep.html>`_\n- `Alpa talk slides <https://docs.google.com/presentation/d/1CQ4S1ff8yURk9XmL5lpQOoMMlsjw4m0zPS6zYDcyp7Y/edit?usp=sharing>`_\n\nOverview\n========\n\n:ref:`The figure below <architecture>` shows a high-level diagram of Alpa's architecture.\n\n.. _architecture:\n\n.. figure:: alpa-arch.png\n  :align: center\n  :width: 450px\n\n  Figure 1: Alpa architecture diagram.\n\nLike many existing machine learning compilers, Alpa parallelizes the ML computation in two steps: a compilation step, followed by a runtime step.\n\nIn the compilation step, Alpa takes a model description, in the form of a :ref:`computational graph<cg>`, and a :ref:`device cluster<device-cluster>` as inputs, and performs a few compilation passes and optimizations to generate\na model-parallel execution plan, which is *custom-made* for the model and cluster. 
Alpa then generates binary executables based on the training code and parallel execution plan, for each participating compute device in the cluster.\nIn the runtime step, Alpa orchestrates the parallel execution of these executables on the cluster.\n\nCompilation\n===========\n\nBefore we start introducing the compilation architecture, we bring in two important concepts introduced by Alpa.\nUnlike many existing distributed ML training systems, Alpa classifies existing ML parallelization approaches into two orthogonal categories:\n**intra-operator parallelism** and **inter-operator parallelism**. They are distinguished by whether the parallelism approach involves partitioning any computational operator of the model along one (or more) tensor axis.\nSome examples falling into the two categories are listed below:\n\n- **Intra-op parallelism**: data parallelism, Megatron-LM's tensor model parallelism, operator parallelism such as those in ToFu and FlexFlow, etc.\n- **Inter-op parallelism**: device placement, pipeline parallelism and their variants.\n\nFor a deeper dive into what these two classes of parallelism entail, please read the documentation about our rationale.\n\nThis new view of ML parallelization techniques is the core part that drives Alpa's design: Alpa unifies existing ML parallelization methods following this\nview by realizing them in a two-level hierarchy shown in :ref:`Figure 1<architecture>`. 
At the upper level, Alpa designs a set of algorithms and compilation passes, which we call\n**inter-op pass** to generate parallel execution plan corresponding to all inter-op parallelisms; at the lower level, Alpa designs another set of algorithms and\ncompilation passes, which we call **intra-op pass**, to generate the parallel execution plan mapping to all intra-op parallelisms.\n\nAlpa can guarantee the plan generated at each individual level is *locally optimal*.\nOnce the two-level plans are generated, Alpa runs a third pass, the **runtime orchestration pass**. In this pass, Alpa applies the plans on the input computational graph,\nperforms some post-processing, and finally compiles the original, single-node graph into parallel executables. It then sends the parallel executables to devices on the cluster.\n\n\nImportant concepts\n------------------\n\nUnderstanding the following concepts is necessary to understand what each pass is precisely doing during compilation.\n\n.. _cg:\n\nComputational graph\n###################\nLike many machine learning compiler systems, Alpa represents the model computation as a static computational graph.\nFor now, this computational graph is first extracted from the user code and expressed using the `JaxPR intermediate representation <https://jax.readthedocs.io/en/latest/jaxpr.html>`__,\nand then lowered to the `XLA HLO representation <https://www.tensorflow.org/xla/operation_semantics>`__.\n\n\n.. _device-cluster:\n\nDevice cluster\n##############\nAlpa runs on a cluster of compute devices, managed by Ray_. For example, a cluster of four AWS p3.16xlarge nodes, with 8 GPUs on each node, forms a 4x8 device cluster, illustrated\nin :ref:`Figure 2<cluster-mesh>` below. We also call this device cluster *the cluster mesh*.\n\n.. _cluster-mesh:\n\n.. 
figure:: cluster-mesh.png\n  :align: center\n  :width: 450px\n\n  Figure 2: an M x N cluster mesh.\n\nDevice mesh\n###########\n\nAlpa's :ref:`inter-op compilation pass<inter-op-pass>` will slice the cluster mesh into multiple groups of devices. Each group might contain a number of devices\nwith high communication bandwidth, such as `NVIDIA NVLink <https://www.nvidia.com/en-us/data-center/nvlink/>`__. We call each group of devices a device mesh.\n:ref:`Figure 2<cluster-mesh>` shows how a cluster mesh is sliced into 4 device meshes.\n\nWorker\n######\n\nEach device mesh might consist of partial or full devices from a single node or from multiple nodes. Alpa uses a worker to manage multiple devices from a node; hence a device mesh might contain multiple workers, each mapping to a process that manages multiple devices on a node.\nFor example, :ref:`Figure 3<mesh-worker>` shows a mesh consisting of 2 workers, where each worker manages 4 devices.\nThe workers are implemented as `Ray actors <https://github.com/alpa-projects/alpa/blob/main/alpa/device_mesh.py>`__.\n\n.. _mesh-worker:\n\n.. figure:: mesh-worker.png\n  :align: center\n  :width: 350px\n\n  Figure 3: A mesh consists of multiple workers managing devices.\n\nStage\n#####\nAlpa slices the input computational graph into multiple, adjacent subgraphs. We call each subgraph a stage.\n\nResharding\n##########\n# TODO\n\n\nCompilation Passes\n------------------\nWith the above concepts, we now explain what each compilation pass is exactly doing.\n\n.. 
_inter-op-pass:\n\nInter-op Pass\n#############\n\nInter-op pass slices the computational graph into multiple stages and the cluster mesh into multiple smaller device meshes; it then assigns each stage to a mesh.\nAlpa generates the slicing and assignment scheme optimally using a dynamic programming algorithm to minimize the inter-op parallel execution latency.\n\nIntra-op pass\n#############\nIntra-op pass looks at each <stage, mesh> pair generated by the inter-op pass, and generates the optimal intra-op parallelism execution plan for this stage to run on its assigned mesh.\n\n\nRuntime Orchestration pass\n##########################\nThe runtime orchestration pass looks at the pairs of stages and meshes generated by the inter-op pass, and the intra-op parallelism strategy generated for each <stage, mesh> pair by the intra-op pass.\nIt analyzes their data dependency, and tries to fulfill some requirements before runtime. These requirements include:\n\n- **Communication**: sending a tensor from a stage to its next stage. When the two stages have different intra-op parallelism execution plans, the tensor might be sharded differently on two meshes.\n  In that case, cross-mesh resharding is required. Alpa's runtime orchestration pass will try to generate the optimal scheme on how to communicate the tensors between two meshes.\n- **Scheduling**: Alpa's runtime will also compile and generate static scheduling instructions for pipelined execution of all stages, to minimize scheduling overheads at runtime.\n\n\nThese three compilation passes are implemented on top of XLA_ and GSPMD_.\nBesides the compilation passes for distributed execution, XLA_ and GSPMD_ additionally perform some other necessary optimizations to improve the single-device execution performance.\n\n.. _XLA: https://www.tensorflow.org/xla\n.. 
_GSPMD: https://arxiv.org/pdf/2105.04663.pdf\n\n\n\nRuntime\n=======\nAlpa implements a runtime_ to orchestrate the inter-op parallel execution of different stages on these meshes.\nFor each stage, Alpa uses the GSPMD runtime to parallelize its execution on its assigned device mesh, following the intra-op parallelism execution plan generated by the intra-op pass.\n\n.. _Ray: https://github.com/ray-project/ray\n.. _MLP: tutorial/getting_started\n.. _worker: https://github.com/alpa-projects/alpa/blob/main/alpa/device_mesh.py#L64\n.. _runtime: https://github.com/alpa-projects/alpa/blob/main/alpa/pipeline_parallel/decentralized_distributed_runtime.py\n"
  },
  {
    "path": "docs/architecture/parallelism-view-and-rationale.rst",
    "content": ".. _rationale:\n\nRationale\n=========\ntest\n"
  },
  {
    "path": "docs/benchmark/benchmark.rst",
    "content": "Performance Benchmark\n=====================\n\nThe figure below shows the scaling efficiency of Alpa on training models with billions of parameters on an AWS cluster.\nThe instructions to reproduce the benchmark results are in this `README.md <https://github.com/alpa-projects/alpa/blob/main/benchmark/alpa/README.md>`_.\nThe explanation of the results can be found in Section 8.1 of the `Alpa paper <https://arxiv.org/pdf/2201.12023.pdf>`_.\n\n.. figure:: bench-paper.png\n  :align: center\n\n.. raw:: html\n\n  <br></br>\n"
  },
  {
    "path": "docs/cluster_setup.md",
    "content": "# AWS Cluster Setup Guide\n\n1. Create a [placement group](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/placement-groups.html) on the AWS Management Console. Choose the `Cluster` placement strategy. This can make sure the interconnection bandwidth among different nodes in the cluster is high.\n2. Create a security group on the AWS Management Console (EC2 -> Network & Security -> Security Groups).\n3. Create an [EFS](https://console.aws.amazon.com/efs). This is used as an NFS for all nodes in the cluster. Please add the security group ID of the node you just started (can be found on the AWS Management Console) to the EFS to make sure your node can access the EFS. After that, you need to install the [efs-utils](https://docs.aws.amazon.com/efs/latest/ug/installing-other-distro.html) to mount the EFS on the node:\n   ```bash\n   git clone https://github.com/aws/efs-utils\n   cd efs-utils\n   ./build-deb.sh\n   sudo apt-get -y install ./build/amazon-efs-utils*deb\n   ```\n   You can try to mount the EFS on the node by:\n   ```bash\n   mkdir -p ~/efs\n   sudo mount -t efs {Your EFS file system ID}:/ ~/efs\n   sudo chmod 777 ~/efs\n   ```\n   If this takes forever, make sure you configure the security groups right.\n\n\nClone the git repos under `~/efs`.\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# Configuration file for the Sphinx documentation builder.\n#\n# This file only contains a selection of the most common options. For a full\n# list see the documentation:\n# https://www.sphinx-doc.org/en/master/usage/configuration.html\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n\nimport os\nimport sys\n\n# -- Project information -----------------------------------------------------\n\nproject = 'Alpa'\nauthor = 'Alpa Developers'\ncopyright = f'2022, {author}'\n\n\ndef git_describe_version():\n    \"\"\"Get git describe version.\"\"\"\n    ver_py = os.path.join(\"..\", \"update_version.py\")\n    libver = {\"__file__\": ver_py}\n    exec(compile(open(ver_py, \"rb\").read(), ver_py, \"exec\"), libver, libver)\n    gd_version, _ = libver[\"git_describe_version\"]()\n    return gd_version\n\n\nimport alpa\nversion = git_describe_version()\nrelease = version\n\n\n# -- General configuration ---------------------------------------------------\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx_gallery.gen_gallery',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.intersphinx'\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n\n# Explicitly define the order within a subsection.\n# The listed files are sorted according to the list.\n# The unlisted files are sorted by filenames.\n# The unlisted files always appear after listed files.\n\n# Note: we need to execute files that use distributed runtime before\n# files that uses local runtime. Because all tutorials run on a single\n# process, using local runtime will allocate all GPU memory on the driver\n# script and leave no GPU memory for workers.\nwithin_subsection_order = {\n    \"tutorials\": [\n        \"quickstart.py\",\n        \"pipeshard_parallelism.py\",\n        \"alpa_vs_pmap.py\",\n    ],\n}\n\nclass WithinSubsectionOrder:\n    def __init__(self, src_dir):\n        self.src_dir = src_dir.split(\"/\")[-1]\n\n    def __call__(self, filename):\n        # If the order is provided, use the provided order\n        if (\n            self.src_dir in within_subsection_order\n            and filename in within_subsection_order[self.src_dir]\n        ):\n            index = within_subsection_order[self.src_dir].index(filename)\n            assert index < 1e10\n            return \"\\0%010d\" % index\n\n        # Otherwise, sort by filename\n        return filename\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  
See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\nhtml_favicon = 'logo/alpa-logo.ico'\n\nhtml_context = {\n    'display_github': True,\n    'github_user': 'alpa-projects',\n    'github_repo': 'alpa',\n    'github_version': 'main',\n    \"conf_py_path\": \"/docs/\",\n}\n\nhtml_theme_options = {\n    'analytics_id': 'G-587CCSSRL2',\n    'analytics_anonymize_ip': False,\n}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\n# sphinx-gallery configuration\nsphinx_gallery_conf = {\n    'examples_dirs': ['gallery/tutorials'],\n    'gallery_dirs': ['tutorials'],\n    'within_subsection_order': WithinSubsectionOrder,\n    'backreferences_dir': 'gen_modules/backreferences',\n    \"filename_pattern\": os.environ.get(\"ALPA_TUTORIAL_EXEC_PATTERN\", r\".py\"),\n}\n\n# configuration for intersphinx: refer to the Python standard library.\nintersphinx_mapping = {\n    'python': ('https://docs.python.org/{.major}'.format(sys.version_info), None),\n    'matplotlib': ('https://matplotlib.org/', None),\n    'pandas': ('https://pandas.pydata.org/', None),\n}\n\n# -- Monkey patch -------------------------------------------------\n\n# Fix bugs in sphinx_gallery\nimport io\nfrom sphinx_gallery import gen_rst\nsetattr(gen_rst._LoggingTee, \"close\", lambda x: x.restore_std())\ndef raise_io_error(*args):\n    raise io.UnsupportedOperation()\nsetattr(gen_rst._LoggingTee, \"fileno\", raise_io_error)\n"
  },
  {
    "path": "docs/developer/developer_guide.rst",
    "content": "===============\nDeveloper Guide\n===============\n\nCode Organization\n=================\n\nThe code in alpa's repository is organized as follows:\n  - `alpa <https://github.com/alpa-projects/alpa/tree/main/alpa>`__: the python source code of Alpa\n  - `benchmark <https://github.com/alpa-projects/alpa/tree/main/benchmark>`__: benchmark scripts\n  - `build_jaxlib <https://github.com/alpa-projects/alpa/tree/main/build_jaxlib>`__: build scripts for Alpa's version of jaxlib\n  - `docs <https://github.com/alpa-projects/alpa/tree/main/docs>`__: documentation and tutorials\n  - `examples <https://github.com/alpa-projects/alpa/tree/main/examples>`__: public examples\n  - `playground <https://github.com/alpa-projects/alpa/tree/main/playground>`__: experimental scripts\n  - `tests <https://github.com/alpa-projects/alpa/tree/main/tests>`__: unit tests\n  - `third_party <https://github.com/alpa-projects/alpa/tree/main/third_party>`__: third party repos\n\nIn addition, Alpa maintains a tensorflow fork. This is because alpa modifies the XLA compiler, whose code\nis hosted in the tensorflow repo.\n\n- `tensorflow-alpa <https://github.com/alpa-projects/tensorflow-alpa>`__: The TensorFlow fork for Alpa.\n  The c++ source code of Alpa mainly resides in ``tensorflow/compiler/xla/service/spmd``.\n\n\nContribute to Alpa\n==================\nPlease submit a `pull request <https://github.com/alpa-projects/alpa/compare>`__ if you plan to contribute to Alpa.\n\nFormatting and Linting\n----------------------\nWe follow the `Google Python Style Guide <https://google.github.io/styleguide/pyguide.html>`__.\n\nInstall yapf and pylint via:\n\n.. code-block:: bash\n\n    pip install yapf==0.32.0 pylint==2.14.0\n\nUse the following script to format the code and check linting errors:\n\n.. code-block:: bash\n\n    ./format.sh\n\nUnit Testing\n------------\nEvery new feature should come with a unit test. 
See this `README.md <https://github.com/alpa-projects/alpa/tree/main/tests/README.md>`_ on how to run tests locally.\n\nUpdating submodule tensorflow-alpa\n----------------------------------\nAlpa repo stores a commit hash of the submodule tensorflow-alpa, so git knows which version of tensorflow-alpa should be used.\nHowever, commands like ``git pull`` do not update the submodule to the latest stored commit. You need to additionally use the command below.\n\n.. code-block:: bash\n\n    git submodule update --init --recursive\n\nContributing to submodule tensorflow-alpa\n-----------------------------------------\nIf you want to contribute code to tensorflow-alpa, you can follow the steps below:\n\n1. Contributors send a pull request to tensorflow-alpa.\n2. Maintainers review the pull request and merge it to tensorflow-alpa.\n3. Contributors send a pull request to alpa. The pull request should update the stored commit hash of the submodule and include other modifications to alpa if necessary.\n4. Maintainers review the pull request and merge it to alpa.\n"
  },
  {
    "path": "docs/gallery/tutorials/README.rst",
    "content": "Alpa Tutorials\n==============\n"
  },
  {
    "path": "docs/gallery/tutorials/advanced_api_usage.py_disable",
    "content": "\"\"\"\nAdvanced API Usage\n==================\n\nThis page will cover some more advanced examples of Alpa.\n\"\"\"\n\n###########################################\n# We first import libraries and create example model and train step functions.\n\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport ray\nimport optax\n\nimport alpa\nfrom alpa import global_config, parallelize\nfrom alpa.device_mesh import DeviceCluster\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer\nfrom alpa.model.model_util import TrainState\nfrom alpa.util import count_communication_primitives, get_ray_namespace_str\n\n# launch the cluster\nray.init()\ncluster = DeviceCluster()\nglobal_config.devices = cluster.get_physical_mesh()\n\n# define consts\nbatch_size = 64\nseq_len = 512\nhidden_size = 512\nnum_heads = 4\nn_layers = 4\n\n\n# Define model, train state and train step\nclass BertLayerModel(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layers = [\n            FlaxBertLayer(config=self.config, dtype=self.dtype)\n            for _ in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(self, x, attention_mask):\n        for i, layer in enumerate(self.layers):\n            layer_outputs = layer(x, attention_mask)\n            x = layer_outputs[0]\n        return x\n\n\ndef create_train_state(rngkey, model, inputs):\n    params = model.init(rngkey, *inputs)\n    tx = optax.adam(learning_rate=1e-2)\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              dynamic_scale=None)\n    return state\n\n\nrngkey = jax.random.PRNGKey(0)\nx = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\ny = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\nattention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32)\nbatch = {'x': x, 'y': y, 
\"attention_mask\": attention_mask}\nbert_config = BertConfig(hidden_size=hidden_size,\n                         intermediate_size=hidden_size * 4,\n                         num_attention_heads=num_heads,\n                         num_hidden_layers=n_layers)\nmodel = BertLayerModel(config=bert_config)\nstate = create_train_state(rngkey, model, [x, attention_mask])\n\n\ndef train_step(state, batch):\n\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"], batch[\"attention_mask\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = jax.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n\n# define test utils\ndef print_hlo_communication_stats(hlo_text):\n    (n_total, n_all_reduce, n_all_gather, n_reduce_scatter,\n     n_all_to_all) = count_communication_primitives(hlo_text)\n\n    print(f\"#total: {n_total}, #all-reduce: {n_all_reduce}, \"\n          f\"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, \"\n          f\"#all-to-all: {n_all_to_all}\")\n\n\ndef reset_state():\n    global state\n    state = create_train_state(rngkey, model, [x, attention_mask])\n\n\n###########################################\n# Auto-Sharding Options\n# ~~~~~~~~~~~~~~~~~~~~~\n#\n# AutoShardingOption is designed to control the inter-operator parallelism more precisely.\n#\n# Control specific collective primitive\n# -----------------------------------------\n#\n# Some primitive is not well-supported on specific platforms(e.g. may cause deadlock).\n# In case of that, they should be excluded in auto-sharding's optimization space.\n# We control this by some auto-sharding options.\n#\n# In some cases, an allreduce can be replaced by a reduce-scatter first,\n# and an all-gather later. 
The two has the same communication, but reduce-scatter\n# may readuce the peak memory.\n\nas_option = global_config.default_autosharding_option\nas_option_backup = as_option.backup()\n\nas_option.prefer_reduce_scatter = True\nexecutable = parallelize(train_step).get_executable(state, batch)\nprint_hlo_communication_stats(executable.get_hlo_text())\n\n# create new state to avoid jit\nas_option.prefer_reduce_scatter = False\nstate = create_train_state(rngkey, model, [x, attention_mask])\nexecutable = parallelize(train_step).get_executable(state, batch)\nprint_hlo_communication_stats(executable.get_hlo_text())\n\nas_option.restore(as_option_backup)\n\n###########################################\n# Force to use data parallel\n# --------------------------\n#\n# Alpa can forcibly generates data parallel solution, or map a specific\n# mesh dimension to the batch dimension.\n#\n# With force_batch_dim_to_mesh_dim, Alpa forcibly maps the given logical mesh\n# dimension (0 or 1) to batch dimension inferred in auto-sharding.\n# If the option's value is None, but the two dimensions of the logical mesh is\n# larger than 1, Alpa still forcibly maps the first logical mesh dimension to\n# batch dimension.\n#\n# With force_data_parallel, Alpa sets the first dimension larger than 1 to the force_batch_dim_to_mesh_dim value.\n\n# Default mesh shape: (num_host,num_device)=(1,4)\n\nas_option.force_batch_dim_to_mesh_dim = 0\nreset_state()\nexecutable = parallelize(train_step).get_executable(state, batch)\nprint_hlo_communication_stats(executable.get_hlo_text())\n# The above uses model parallel\n\nas_option.force_batch_dim_to_mesh_dim = 1\nreset_state()\nexecutable = parallelize(train_step).get_executable(state, batch)\nprint_hlo_communication_stats(executable.get_hlo_text())\n# The above uses data parallel\n\nas_option.restore(as_option_backup)\n\n###########################################\n# Specify inter-operator parallelism strategy\n# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n#\n# 
We can specify inter-operator parallelism config with global_config.\n# To start with, we first set parallel strategy to 3d parallel and use alpa's grad decorator:\n\nglobal_config.devices.shutdown()\nglobal_config.strategy = \"pipeshard_parallel\"\nglobal_config.devices = cluster.get_virtual_physical_mesh()\n\n\ndef train_step(state, batch):\n\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    # modify the grad decorator here\n    grads = alpa.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n\ndef profile_and_pp_pipeshard_stats(executable):\n    pipeshard_stats = executable.profile_all_executables()\n    print(\"All stages' stats in form of (time, memory)\")\n    for mesh_idx, mesh_stats in enumerate(pipeshard_stats):\n        output_str = \"\"\n        for stat in mesh_stats.values():\n            output_str += f\"({stat[0]:.3f}s,{stat[1]:.2f}GB),\"\n        print(f\"mesh {mesh_idx}:\" + output_str)\n\n\n###########################################\n# Specify layer clustering\n# ------------------------\n#\n# Layer cluster forms a number of JaxprEqns (atom in JAX IR) into the same layer.\n# We can also manually assign layers using the pipeline marker.\n\nfrom alpa import mark_pipeline, manual_layer_construction\n\n\nclass UnequalManualLayerBertLayerModel(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n    manual_pipeline_layer: bool = True\n\n    def setup(self):\n        self.layers = [\n            FlaxBertLayer(config=self.config, dtype=self.dtype)\n            for _ in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(self, x, attention_mask):\n        for i, layer in enumerate(self.layers):\n            # Add the pipeline start marker here\n            if i < 2:\n                mark_pipeline(name=str(i), mark_type='start')\n            layer_outputs = 
layer(x, attention_mask)\n            x = layer_outputs[0]\n            # Add the pipeline end marker here\n            if i == 0 or i == self.config.num_hidden_layers - 1:\n                mark_pipeline(name=str(i), mark_type='end')\n        return x\n\n\ndef train_step(state, batch):\n    # Add the manual layer construction decorator here\n    @manual_layer_construction(lift_markers=True)\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"], batch[\"attention_mask\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = alpa.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n\nmodel = UnequalManualLayerBertLayerModel(config=bert_config)\nstate = create_train_state(rngkey, model, [x, attention_mask])\n\nexecutable = parallelize(train_step).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\n\nexecutable.shutdown()\n\n###########################################\n# The code above creates a model with four bert layers, then splits them into\n# two alpa layers. With default setting, each layer maps to a pipeline stage and\n# each stage uses the same submesh. As we split between the first bert layer and\n# the other three layers, the memory consumption of the first stage is\n# approximately a third of the second's.\n#\n# In manual layer construction, each instruction in the forward computation\n# should be between a pipeline start marker and its corresponding pipeline end\n# marker. 
When using the manual pipeline marker, the loss function should be\n# decorated by the manual_layer_construction decorator.\n#\n# For simplicity, manual_layer_construction provides a lift_markers option.\n# If it is turned on, the first and last pipeline markers are automatically\n# moved to the first and last JaxprEqn.\n#\n# Specify stage construction\n# --------------------------\n#\n# Stage construction merges layers into stages and assigns devices to each stage\n# with a logical mesh shape. Here we manually give the stage construction plan\n# with options in global_config.\n\n\nclass EqualManualLayerBertLayerModel(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n    manual_pipeline_layer: bool = True\n\n    def setup(self):\n        self.layers = [\n            FlaxBertLayer(config=self.config, dtype=self.dtype)\n            for _ in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(self, x, attention_mask):\n        for i, layer in enumerate(self.layers):\n            # Add the pipeline start marker here\n            mark_pipeline(name=str(i), mark_type='start')\n            layer_outputs = layer(x, attention_mask)\n            x = layer_outputs[0]\n            # Add the pipeline end marker here\n            mark_pipeline(name=str(i), mark_type='end')\n        return x\n\n\nmodel = EqualManualLayerBertLayerModel(config=bert_config)\nstate = create_train_state(rngkey, model, [x, attention_mask])\n\nglobal_config_backup = global_config.backup()\n\n# turn on manual stage plan\nglobal_config.pipeline_stage_mode = \"manual_stage\"\n# Layer-stage mapping\nglobal_config.forward_stage_layer_ids = [[0], [1], [2, 3]]\n# Physical mesh shape of each stage\nglobal_config.sub_physical_mesh_shapes = [(1, 1), (1, 1), (1, 2)]\n# Logical mesh shape of each stage\nglobal_config.sub_logical_mesh_shapes = [(1, 1), (1, 1), (2, 1)]\n# auto sharding option of each stage\nglobal_config.submesh_autosharding_option_dicts = [{}, {}, {}]\nexecutable = 
parallelize(train_step).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\n\nexecutable.shutdown()\nglobal_config.restore(global_config_backup)\n\n###########################################\n# Rematerialization with layer construction\n# -----------------------------------------\n#\n# We provide a layer-based rematerialization.\n\nmodel = EqualManualLayerBertLayerModel(config=bert_config)\nstate = create_train_state(rngkey, model, [x, attention_mask])\n\n\ndef get_train_step(remat_layer):\n\n    def train_step(state, batch):\n\n        # Set remat_layer in manual layer construction decorator here.\n        # The same is true for automatic layer construction decorator.\n        @manual_layer_construction(lift_markers=True, remat_layer=remat_layer)\n        def loss_func(params):\n            out = state.apply_fn(params, batch[\"x\"], batch[\"attention_mask\"])\n            loss = jnp.mean((out - batch[\"y\"])**2)\n            return loss\n\n        grads = alpa.grad(loss_func)(state.params)\n        new_state = state.apply_gradients(grads=grads)\n        return new_state\n\n    return train_step\n\n\nprint(\">>>>> With remat\")\nexecutable = parallelize(get_train_step(True)).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\nexecutable.shutdown()\nreset_state()\nprint(\">>>>> Without remat\")\nexecutable = parallelize(get_train_step(False)).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\nexecutable.shutdown()\n\n###########################################\n# The peak memory is significantly smaller when remat_layer is turned on.\n#\n# Moreover, we can remat at a fine-grained level, then do parallel at a relatively\n# coarse-grained level. 
The example below remats at each Bert Layer, but does\n# inter-operator parallelization for every two Bert Layers.\n\nfrom alpa import automatic_remat, automatic_layer_construction\n\nmodel = BertLayerModel(config=bert_config)\n\n\ndef get_train_step(remat_layer):\n\n    def train_step(state, batch):\n\n        def loss_func(params):\n            out = state.apply_fn(params, batch[\"x\"], batch[\"attention_mask\"])\n            loss = jnp.mean((out - batch[\"y\"])**2)\n            return loss\n\n        # Split the forward into 4 parts for remat\n        if remat_layer:\n            loss_func = automatic_remat(loss_func, layer_num=4)\n        # Split the forward(remat-marked) into 2 parts for inter-operator parallel\n        loss_func = automatic_layer_construction(loss_func, layer_num=2)\n        grads = alpa.grad(loss_func)(state.params)\n        new_state = state.apply_gradients(grads=grads)\n        return new_state\n\n    return train_step\n\n\nprint(\">>>>> With remat\")\nstate = create_train_state(rngkey, model, [x, attention_mask])\nexecutable = parallelize(get_train_step(True)).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\nexecutable.shutdown()\nreset_state()\nprint(\">>>>> Without remat\")\nexecutable = parallelize(get_train_step(False)).get_executable(state, batch)\nprofile_and_pp_pipeshard_stats(executable)\nexecutable.shutdown()\n"
  },
  {
    "path": "docs/gallery/tutorials/alpa_vs_pmap.py",
    "content": "\"\"\"\nDifferences between alpa.parallelize, jax.pmap and jax.pjit\n===========================================================\n\nThe most common tool for parallelization or distributed computing in jax is\n`pmap <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`_.\nWith several lines of code change, we can use ``pmap`` for data parallel\ntraining. However, we cannot use ``pmap`` for model parallel training,\nwhich is required for training large models with billions of parameters.\n\nOn the contrary, ``alpa.parallelize`` supports both data parallelism and\nmodel parallelism in an automatic way. ``alpa.parallelize`` analyzes the\njax computational graph and picks the best strategy.\nIf data parallelism is more suitable, ``alpa.parallelize`` achieves the same\nperformance as ``pmap`` but with less code change.\nIf model parallelism is more suitable, ``alpa.parallelize`` achieves better performance\nand uses less memory than ``pmap``.\n\nIn this tutorial, we are going to compare ``alpa.parallelize`` and ``pmap`` on two\nworkloads. 
A more detailed comparison among ``alpa.parallelize``, ``pmap``, and ``xmap``\nis also attached at the end of the article.\n\"\"\"\n\n################################################################################\n# When data parallelism is preferred\n# ----------------------------------\n\n# TODO\n\n################################################################################\n# When model parallelism is preferred\n# -----------------------------------\n\n# TODO\n\n################################################################################\n# Comparing ``alpa.parallelize``, ``pmap``, ``xmap``, and ``pjit``\n# ----------------------------------------------------------------\n# Besides ``pmap``, jax also provides\n# `xmap <https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html>`_ and\n# `pjit <https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html>`_\n# for more advanced parallelization.\n# The table below compares the features of ``alpa.parallelize``, ``pmap``, ``xmap``\n# and ``pjit``. In summary, ``alpa.parallelize`` supports more parallelism\n# techniques in a more automatic way.\n#\n# ================  ================ ==================== ==================== =========\n# Transformation    Data Parallelism Operator Parallelism Pipeline Parallelism Automated\n# ================  ================ ==================== ==================== =========\n# alpa.parallelize  yes              yes                  yes                  yes\n# pmap              yes              no                   no                   no\n# xmap              yes              yes                  no                   no\n# pjit              yes              yes                  no                   no\n# ================  ================ ==================== ==================== =========\n#\n# .. 
note::\n#   Operator parallelism and pipeline parallelism are two forms of model parallelism.\n#   Operator parallelism partitions the work in a single operator and assigns them\n#   to different devices. Pipeline parallelism partitions the computational\n#   graphs and assigns different operators to different devices.\n"
  },
  {
    "path": "docs/gallery/tutorials/pipeshard_parallelism.py",
    "content": "\"\"\"\nDistributed Training with Both Shard and Pipeline Parallelism\n=============================================================\n\nAlpa can automatically parallelizes jax functions with both shard\nparallelism (a.k.a. intra-operator parallelism) and pipeline parallelism\n(a.k.a. inter-operator parallelism). Shard parallelism includes\ndata parallelism, operator parallelism, and their combinations.\nThe previous :ref:`quick start <alpa-quickstart>` tutorial focuses on\nusing Alpa for shard parallelism.\n\nIn this tutorial, we show how to use Alpa with both shard and pipeline parallelism.\nFirst, we show how to use Alpa to manually assign stages for pipeline parallelism.\nThen we show how to use Alpa to automate this process.\n\"\"\"\n\n################################################################################\n# Import Libraries and Initialize Environment\n# -------------------------------------------\n# First, import the required libraries.\n\nimport alpa\nfrom alpa.testing import assert_allclose\nimport copy\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport jax\nimport jax.numpy as jnp\nfrom jax import random\nimport optax\nimport ray\n\nalpa.util.disable_tqdm_globally()\n\n################################################################################\n# Connect to a Ray Cluster\n# ------------------------\n# Alpa uses a distributed framework `ray <https://docs.ray.io/>`_ to manage\n# the cluster and disributed workers. We initialize ray and alpa.\n\nray.init()\nalpa.init(cluster=\"ray\")\n\n# Alternatively, you can use the following command to connect to an existing\n# ray cluster.\n# ray.init(address=\"auto\")\n#\n# Note: `alpa.init(cluster=\"ray\")` uses the gpus resources of the whole ray\n# cluster. 
To configure Alpa to only use a subset of gpu resources, one can \n# specify the number of nodes and number of gpus per node.\n# For example, only run 2 gpus when 8 gpus are available \n# alpa.init('ray', devices_per_node=2, num_nodes=1)  \n\n################################################################################\n# Train an MLP on a Single Device\n# -------------------------------\n# In this tutorial, we use a toy dataset to train an MLP model.\n# Specifically, we use the model to fit the function: :math:`y = Wx + b`.\n# Note that now this model is being executed on CPU because we force the driver\n# process to use the CPU.\n\n\nclass MLPModel(nn.Module):\n    hidden_dim: int\n\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Dense(features=self.hidden_dim * 4)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=self.hidden_dim)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=self.hidden_dim * 4)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=self.hidden_dim)(x)\n        x = nn.relu(x)\n        return x\n\n\ndim = 2048\nbatch_size = 2048\n\n# Generate ground truth W and b\nrngkey = jax.random.PRNGKey(0)\nk1, k2 = random.split(rngkey)\nW = random.normal(k1, (dim, dim))\nb = random.normal(k2, (dim,))\n\n# Generate the training data\nksample, knoise = random.split(k1)\nx = random.normal(ksample, (batch_size, dim))\ny = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))\n\n# Initialize a train state, which includes the model parameters and optimizer\n# state.\nmodel = MLPModel(hidden_dim=dim)\nparams = model.init(rngkey, x)\ntx = optax.adam(learning_rate=1e-3)\nstate = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n\n# Define the training step\ndef train_step(state, batch):\n\n    def loss_func(params):\n        out = model.apply(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = jax.grad(loss_func)(state.params)\n    new_state = 
state.apply_gradients(grads=grads)\n    return new_state\n\n\nbatch = {\"x\": x, \"y\": y}\nexpected_state = train_step(state, batch)\n\n################################################################################\n# Pipeline Parallelism with Manual Assignment\n# -------------------------------------------\n# Pipeline parallelism requires partitioning the model into several pipeline\n# stages. To manually assign stages, we can use ``alpa.mark_pipeline_boundary``\n# to mark the boundary of each pipeline stage in the forward function.\n# Note that each pipeline stage is also automatically parallelized by the\n# shard parallel pass.\n\n\n# Define an MLP model with manual stage boundaries.\nclass ManualPipelineMLPModel(nn.Module):\n    hidden_dim: int\n\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Dense(features=self.hidden_dim * 4)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=self.hidden_dim)(x)\n        x = nn.relu(x)\n        # Use this boundary marker to separate the model into two stages.\n        alpa.mark_pipeline_boundary()\n        x = nn.Dense(features=self.hidden_dim * 4)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=self.hidden_dim)(x)\n        x = nn.relu(x)\n        return x\n\n\n# Initialize the train state with the same parameters as the single-device\n# model.\nmanual_pipeline_model = ManualPipelineMLPModel(hidden_dim=dim)\nmanual_pipeline_state = TrainState.create(apply_fn=manual_pipeline_model.apply,\n                                          params=copy.deepcopy(params),\n                                          tx=tx)\n\n\n# Define the training step.\n# We use the \"alpa.PipeshardParallel\" option to let alpa use both\n# pipeline parallelism and shard parallelism. 
To make pipeline parallelism\n# efficient, we need to fill the pipeline with many micro batches,\n# so a `num_micro_batches` should be specified.\n@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,\n                                                layer_option=\"manual\"))\ndef manual_pipeline_train_step(state, batch):\n\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    # We use `alpa.grad` here to separate the apply gradient stage with the\n    # forward/backward stages in the pipeline. This is necessary to ensure that\n    # the gradient accumulation is correct.\n    grads = alpa.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n\nmanual_pipeline_actual_state = manual_pipeline_train_step(\n    manual_pipeline_state, batch)\nassert_allclose(expected_state.params,\n                manual_pipeline_actual_state.params,\n                atol=5e-3)\n\nalpa.shutdown()\n\n####################\n#\n# .. note::\n#\n#   In addition, Alpa supports more flexible manual assignments of pipeline\n#   parallelism strategies. In the above example, each partitioned stages will\n#   be assigned an equal number of devices to run. If you want to control the\n#   device assignment of each stage, you can use the more advanced\n#   ``stage_option=alpa.ManualStageOption``.\n\n################################################################################\n# Pipeline Parallelism with Automatic Assignment\n# ----------------------------------------------\n# Alpa also supports automatically partitioning the model into multiple\n# pipeline stages and assign each pipeline stage a device mesh such that\n# the total execution latency is minimized. Specifically, the automatic\n# partitioning algorithm consists of the following steps:\n#\n# 1. 
**Layer Construction:** In this step, the operators in the model are\n#    clustered into \"layers\" based on a graph clustering algorithm. The\n#    user needs to specify the total number of layers (i.e. clusters) as\n#    a hyperparameter.\n# 2. **Stage Construction and Mesh Slicing:** In this step, we partition\n#    the device cluster (device mesh) into multiple submeshes and assign\n#    layers to submeshes to form pipeline stages to minimize the total\n#    pipeline execution latency.\n\nalpa.init(cluster=\"ray\")\n\n# Define the parallel method.\n# `alpa.AutoLayerOption(layer_num=2)` means we use the auto layer construction\n# algorithm to cluster primitive operators into two layers.\n# `stage_option=\"auto\"` means we enable the auto stage construction algorithm.\nmethod = alpa.PipeshardParallel(num_micro_batches=16,\n                                layer_option=alpa.AutoLayerOption(layer_num=2),\n                                stage_option=\"auto\")\n\n\n# Define the training step. 
The function body is the same as the above one.\n@alpa.parallelize(method=method)\ndef auto_pipeline_train_step(state, batch):\n\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    # Again, we use `alpa.grad` here to separate the apply gradient stage with\n    # the forward/backward stages in the pipeline.\n    grads = alpa.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n\n# In the first call, alpa triggers the compilation.\n# The compilation first profiles several costs and solves an optimization\n# problem to get the optimal pipeline assignments.\nauto_pipeline_actual_state = auto_pipeline_train_step(state, batch)\nassert_allclose(expected_state.params,\n                auto_pipeline_actual_state.params,\n                atol=5e-3)\n\nalpa.shutdown()\n\n################################################################################\n# Interpret the Results\n# ---------------------\n# **Some basic concepts**\n# - Cluster mesh and submeshes\n#     - Cluster mesh is a computer cluster that contains GPUs. A ``N×M`` cluster mesh means the cluster has ``N`` physical machines and each machine has ``M`` GPUs.\n#     - Submeshes can be obtained by slicing from the cluster mesh. For example, given a ``N×M`` cluster mesh, a submesh ``(1, M)`` means using all GPUs in one physical machine.\n#     - For more details on how Alpa uses submeshes to solve *inter-operator parallelism*, you can read the **Section 5: Inter-Operator Parallelism** in the `Alpa paper <https://arxiv.org/pdf/2201.12023.pdf>`_.\n# - Device mesh and logical mesh\n#     - A device mesh is a 2-dimensional logical view of a set of physical devices.\n#     - For a set of physical devices, there can be multiple logical views. 
For example, given 2 nodes and 8 GPUs per node (i.e., 16 devices in total), we can view them as a 2×8, 1×16, 4×4, 8×2, or 16×1 device mesh.\n#     - The mapping between physical devices and the logical device mesh view is optimized by the inter-op pass\n#         - Hence, you can see ``Result mesh_shapes`` and the corresponding ``Result logical_mesh_shapes`` in the optimization output.\n#\n# With the basic concepts in mind, you now can better understand the ``ModuleProfileResult``:\n# - ``ModuleProfileResult``: ``result[(i, j, s, c), m]`` means this stage contains forward layers ``i, i+1, ..., j`` and corresponding backward layers, and runs under the ``s``-th submesh and the ``c``-th auto sharding config for the submesh. The ``m = 0`` means the result is for the forward pass, and ``m = 1`` for backward pass."
  },
  {
    "path": "docs/gallery/tutorials/quickstart.py",
    "content": "\"\"\"\n.. _alpa-quickstart:\n\nAlpa Quickstart\n===============\n\nAlpa is built on top of a tensor computation framework `Jax <https://jax.readthedocs.io/en/latest/index.html>`_ .\nAlpa can automatically parallelize jax functions and runs them on a distributed cluster.\nAlpa analyses the computational graph and generates a distributed execution plan\ntailored for the computational graph and target cluster.\nThe generated execution plan can combine state-of-the-art distributed training techniques\nincluding data parallelism, operator parallelism, and pipeline parallelism.\n\nAlpa provides a simple API ``alpa.parallelize`` and automatically generates the best execution\nplan by solving optimization problems. Therefore, you can efficiently scale your jax computation\non a distributed cluster, without any expertise in distributed computing.\n\nIn this tutorial, we show the usage of Alpa with an MLP example.\n\"\"\"\n\n################################################################################\n# Import Libraries\n# ----------------\n# We first import the required libraries.\n# Flax and optax are libraries on top of jax for training neural networks.\n# Although we use these libraries in this example, Alpa works on jax's and XLA's internal\n# intermediate representations and does not depend on any specific high-level libraries.\n\nfrom functools import partial\n\nimport alpa\nfrom alpa.testing import assert_allclose\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport jax\nimport jax.numpy as jnp\nfrom jax import random\nimport numpy as np\nimport optax\n\n\n################################################################################\n# Train an MLP on a Single Device\n# -------------------------------\n# To begin with, we implement the model and training loop on a single device. We will\n# parallelize it later. 
We train an MLP to learn a function y = Wx + b.\n\nclass MLPModel(nn.Module):\n    hidden_dim: int\n    num_layers: int\n\n    @nn.compact\n    def __call__(self, x):\n        for i in range(self.num_layers):\n            if i % 2 == 0:\n                x = nn.Dense(features=self.hidden_dim * 4)(x)\n            else:\n                x = nn.Dense(features=self.hidden_dim)(x)\n            x = nn.relu(x)\n        return x\n\ndim = 2048\nbatch_size = 2048\nnum_layers = 10\n\n# Generate ground truth W and b\nrngkey = jax.random.PRNGKey(0)\nk1, k2 = random.split(rngkey)\nW = random.normal(k1, (dim, dim))\nb = random.normal(k2, (dim,))\n\n# Generate the training data\nksample, knoise = random.split(k1)\nx = random.normal(ksample, (batch_size, dim))\ny = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))\n\n# Initialize a train state, which includes the model parameters and optimizer state.\nmodel = MLPModel(hidden_dim=dim, num_layers=num_layers)\nparams = model.init(rngkey, x)\ntx = optax.adam(learning_rate=1e-3)\nstate = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n# Define the training function and execute one step\ndef train_step(state, batch):\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = jax.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\nbatch = {\"x\": x, \"y\": y}\nexpected_state = train_step(state, batch)\n\n################################################################################\n# Auto-parallelization with ``alpa.parallelize``\n# ----------------------------------------------\n# Alpa provides a transformation ``alpa.parallelize`` to parallelize a jax function.\n# ``alpa.parallelize`` is similar to ``jax.jit`` . 
``jax.jit`` compiles a jax\n# function for a single device, while ``alpa.parallelize`` compiles a jax function\n# for a distributed device cluster.\n# You may know that jax has some built-in transformations for parallelization,\n# such as ``pmap``, ``pjit``, and ``xmap``. However, these transformations are not\n# fully automatic, because they require users to manually specify the parallelization\n# strategies such as parallelization axes and device mapping schemes. You also need to\n# manually call communication primitives such as ``lax.pmean`` and ``lax.all_gather``,\n# which is nontrivial if you want to do advanced model parallelization.\n# Unlike these transformations, ``alpa.parallelize`` can do all things automatically for\n# you. ``alpa.parallelize`` finds the best parallelization strategy for the given jax\n# function and does the code transformation. You only need to write the code as if you are\n# writing for a single device.\n\n# Define the training step. The body of this function is the same as the\n# ``train_step`` above. The only difference is to decorate it with\n# ``alpa.parallelize``.\n\n@alpa.parallelize\ndef alpa_train_step(state, batch):\n    def loss_func(params):\n        out = state.apply_fn(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = jax.grad(loss_func)(state.params)\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n# Test correctness\nactual_state = alpa_train_step(state, batch)\nassert_allclose(expected_state.params, actual_state.params, atol=5e-3)\n\n################################################################################\n# After being decorated by ``alpa.parallelize``, the function can still take numpy\n# arrays or jax arrays as inputs. The function will first distribute the input\n# arrays into correct devices according to the parallelization strategy and then\n# execute the function distributedly. 
The returned result arrays are also\n# stored distributedly.\n\nprint(\"Input parameter type:\", type(state.params[\"params\"][\"Dense_0\"][\"kernel\"]))\nprint(\"Output parameter type:\", type(actual_state.params[\"params\"][\"Dense_0\"][\"kernel\"]))\n\n# We can use `np.array` to convert a distributed array back to a numpy array.\nkernel_np = np.array(actual_state.params[\"params\"][\"Dense_0\"][\"kernel\"])\n\n################################################################################\n# Execution Speed Comparison\n# --------------------------\n# By parallelizing a jax function, we can accelerate the computation and reduce\n# the memory usage per GPU, so we can train larger models faster.\n# We benchmark the execution speed of ``jax.jit`` and ``alpa.parallelize``\n# on a 8-GPU machine.\n\nstate = actual_state  # We need this assignment because the original `state` is \"donated\" and freed.\nfrom alpa.util import benchmark_func\n\n# Benchmark serial execution with jax.jit\njit_train_step = jax.jit(train_step, donate_argnums=(0,))\n\ndef sync_func():\n    jax.local_devices()[0].synchronize_all_activity()\n\ndef serial_execution():\n    global state\n    state = jit_train_step(state, batch)\n\ncosts = benchmark_func(serial_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3\nprint(f\"Serial execution time. Mean: {np.mean(costs):.2f} ms, Std: {np.std(costs):.2f} ms\")\n\n# Benchmark parallel execution with alpa\n# We distribute arguments in advance for the benchmarking purpose.\nstate, batch = alpa_train_step.preshard_dynamic_args(state, batch)\n\ndef alpa_execution():\n    global state\n    state = alpa_train_step(state, batch)\n\nalpa_costs = benchmark_func(alpa_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3\nprint(f\"Alpa execution time.   
Mean: {np.mean(alpa_costs):.2f} ms, Std: {np.std(alpa_costs):.2f} ms\")\n\n################################################################################\n# Memory Usage Comparison\n# -----------------------\n# We can also compare the memory usage per GPU.\n\nGB = 1024 ** 3\n\nexecutable = jit_train_step.lower(state, batch).compile().runtime_executable()\nprint(f\"Serial execution per GPU memory usage: {executable.total_allocation_size() / GB:.2f} GB\")\n\nalpa_executable = alpa_train_step.get_executable(state, batch)\nprint(f\"Alpa execution per GPU memory usage:   {alpa_executable.get_total_allocation_size() / GB:.2f} GB\")\n\n################################################################################\n# Comparison against Data Parallelism (or ``jax.pmap``)\n# -----------------------------------------------------\n# The most common parallelization technique in deep learning is data parallelism.\n# In jax, we can use ``jax.pmap`` to implement data parallelism.\n# However, data parallelism only is not enough for training large models due to\n# both memory and communication costs. 
Here, we use the same model to benchmark the\n# execution speed and memory usage of ``jax.pmap`` on the same 8-GPU machine.\n\n@partial(jax.pmap, axis_name=\"batch\")\ndef pmap_train_step(state, batch):\n    def loss_func(params):\n        out = model.apply(params, batch[\"x\"])\n        loss = jnp.mean((out - batch[\"y\"])**2)\n        return loss\n\n    grads = jax.grad(loss_func)(state.params)\n    # all-reduce gradients\n    grads = jax.lax.pmean(grads, axis_name=\"batch\")\n    new_state = state.apply_gradients(grads=grads)\n    return new_state\n\n# Replicate model and distribute batch\ndevices = jax.local_devices()\nstate = jax.device_put_replicated(state, devices)\ndef shard_batch(x):\n    x = x.reshape((len(devices), -1) + x.shape[1:])\n    return jax.device_put_sharded(list(x), devices)\nbatch = jax.tree_map(shard_batch, batch)\n\n# Benchmark data parallel execution\ndef data_parallel_execution():\n    global state\n    state = pmap_train_step(state, batch)\n\ncosts = benchmark_func(data_parallel_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3\nprint(f\"Data parallel execution time. Mean: {np.mean(costs):.2f} ms, Std: {np.std(costs):.2f} ms\")\nprint(f\"Alpa execution time.          
Mean: {np.mean(alpa_costs):.2f} ms, Std: {np.std(alpa_costs):.2f} ms\\n\")\n\nexecutable = pmap_train_step.lower(state, batch).compile().runtime_executable()\nprint(f\"Data parallel execution per GPU memory usage: {executable.total_allocation_size() / GB:.2f} GB\")\nprint(f\"Alpa execution per GPU memory usage:          {alpa_executable.get_total_allocation_size() / GB:.2f} GB\")\n\n################################################################################\n# As you can see, ``alpa.parallelize`` achieves better execution speed and\n# requires less memory compared with data parallelism.\n# This is because data parallelism only works well if the activation size is much\n# larger than the model size, which is not the case in this benchmark.\n# In contrast, ``alpa.parallelize`` analyzes the computational graph and\n# finds the best parallelization strategy.\n"
  },
  {
    "path": "docs/index.rst",
    "content": "Alpa Documentation\n==================\n.. raw:: html\n\n  <a class=\"github-button\" href=\"https://github.com/alpa-projects/alpa\" data-size=\"large\" data-show-count=\"true\" aria-label=\"Star alpa-projects/alpa on GitHub\">Star</a>\n  <a class=\"github-button\" href=\"https://github.com/alpa-projects/alpa/fork\" data-icon=\"octicon-repo-forked\" data-size=\"large\" data-show-count=\"true\" aria-label=\"Fork alpa-projects/alpa on GitHub\">Fork</a>\n  <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n  <br></br>\n\nAlpa is a system for training and serving large-scale neural networks.\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Getting Started\n\n   install.rst\n   tutorials/quickstart.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Tutorials\n\n   tutorials/pipeshard_parallelism.rst\n   tutorials/alpa_vs_pmap.rst\n   tutorials/opt_serving.rst\n   tutorials/perf_tuning_guide.rst\n   tutorials/icml_big_model_tutorial.rst\n   tutorials/alpa_on_slurm.rst\n   tutorials/faq.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Architecture\n\n   architecture/overview.rst\n   architecture/alpa_compiler_walk_through.rst\n   architecture/intra_op_solver.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Benchmark\n\n   benchmark/benchmark.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Publications\n\n   publications/publications.rst\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Developer Guide\n\n   developer/developer_guide.rst\n"
  },
  {
    "path": "docs/install.rst",
    "content": "Install Alpa\n============\n\nThis page provides instructions to install alpa from Python wheels or from source. The minimum supported python version is 3.7.\n\nPrerequisites\n-------------\n\nRegardless of installing from wheels or from source, there are a few prerequisite packages:\n\n1. CUDA toolkit:\n\n  Follow the official guides to install `CUDA <https://developer.nvidia.com/cuda-toolkit>`_ and `cuDNN <https://developer.nvidia.com/cudnn>`_.\n  Alpa requires CUDA >= 11.1 and  cuDNN >= 8.0.5.\n\n2. Update pip version and install cupy:\n\n  .. code:: bash\n\n    # Update pip\n    pip3 install --upgrade pip\n\n    # Install cupy\n    pip3 install cupy-cuda11x\n\n  Then, check whether your system already has NCCL installed.\n\n  .. code:: bash\n\n    python3 -c \"from cupy.cuda import nccl\"\n\n  If it prints nothing, then NCCL has already been installed.\n  Otherwise, follow the printed instructions to install NCCL.\n\nMethods\n-------\nChoose one of the methods below.\n\n.. _install-from-wheels:\n\nMethod 1: Install from Python Wheels\n####################################\n\nAlpa provides wheels for the following CUDA (cuDNN) and Python versions:\n\n- CUDA (cuDNN): 11.1 (8.0.5), 11.2 (8.1.0), 11.3 (8.2.0)\n- Python: 3.7, 3.8, 3.9\n\nIf you need to use other CUDA, cuDNN, or Python versions, please follow the next section to :ref:`install from source<install-from-source>`.\n\n1. Install Alpa python package.\n\n  .. code:: bash\n\n    pip3 install alpa\n\n2. Install Alpa-modified Jaxlib. Make sure that the jaxlib version corresponds to the version of\n   the existing CUDA and cuDNN installation you want to use.\n   You can specify a particular CUDA and cuDNN version for jaxlib explicitly via:\n\n  .. 
code:: bash\n\n    pip3 install jaxlib==0.3.22+cuda{cuda_version}.cudnn{cudnn_version} -f https://alpa-projects.github.io/wheels.html\n\n  For example, to install the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5, use the following command:\n\n  .. code:: bash\n\n    pip3 install jaxlib==0.3.22+cuda111.cudnn805 -f https://alpa-projects.github.io/wheels.html\n\n  You can see all available wheel versions we provided at our `PyPI index <https://alpa-projects.github.io/wheels.html>`_.\n\n.. note::\n\n  As of now, Alpa modified the original jaxlib at the version ``jaxlib==0.3.22``. Alpa regularly rebases the official jaxlib repository to catch up with the upstream.\n\n\n.. _install-from-source:\n\nMethod 2: Install from Source\n#############################\n\n1. Clone repos\n\n  .. code:: bash\n\n    git clone --recursive https://github.com/alpa-projects/alpa.git\n\n2. Install Alpa python package.\n\n  .. code:: bash\n\n    cd alpa\n    pip3 install -e \".[dev]\"  # Note that the suffix `[dev]` is required to build custom modules.\n\n3. Build and install Alpa-modified Jaxlib. The Jaxlib contains c++ code of Alpa.\n\n  .. code:: bash\n\n    cd build_jaxlib\n    python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa\n    cd dist\n\n    pip3 install -e .\n\n\n.. note::\n\n  Building the latest Alpa-modified jaxlib requires new C++17 standards. It is known that some compiler versions such as ``gcc==7.3`` or ``gcc==9.4`` cannot correctly compile the jaxlib code.\n  See `this thread <https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90415>`_ about the know issues.\n\n  If you meet compilation errors, please install our recommended gcc version ``gcc==7.5``; newer gcc versions might also work.\n  Then please clean the bazel cache (``rm -rf ~/.cache/bazel``) and try to build jaxlib again.\n\n.. 
note::\n\n  All installations are in development mode, so you can modify python code and it will take effect immediately.\n  To modify c++ code in tensorflow, you only need to run the command below from step 3 to recompile jaxlib::\n\n    python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa\n\n.. note::\n\n   Alpa python package and Alpa-modified Jaxlib are two separate libraries. If you only want to develop the python source code, you can install\n   Alpa python package from source and install Alpa-modified Jaxlib from wheels.\n\nCheck Installation\n------------------\nYou can check the installation by running the following commands.\n\n.. code:: bash\n\n  ray start --head\n  python3 -m alpa.test_install\n\n[Optional] PyTorch Frontend\n-------------------------------------\n\nWhile Alpa is mainly designed for Jax, Alpa also provides an experimental PyTorch frontend.\nAlpa supports PyTorch models that meet the following requirements:\n\n1. No input-dependent control flow\n2. No weight sharing\n\nTo enable Alpa for PyTorch, install the following dependencies:\n\n  .. code:: bash\n\n    # Install torch and torchdistx\n    pip3 uninstall -y torch torchdistx\n    pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx\n\n    # Build functorch from source\n    git clone https://github.com/pytorch/functorch\n    cd functorch/\n    git checkout 76976db8412b60d322c680a5822116ba6f2f762a\n    python3 setup.py install\n\nPlease look at ``tests/torch_frontend/test_simple.py`` for usage examples.\n\nTroubleshooting\n---------------\n\nUnhandled Cuda Error\n####################\nIf you see errors like ``cupy_backends.cuda.libs.nccl.NcclError: NCCL_ERROR_UNHANDLED_CUDA_ERROR: unhandled cuda error``, it is mainly due to the compatibility issues between CUDA, NCCL, and GPU driver versions. 
Please double check these versions and see `Issue #496 <https://github.com/alpa-projects/alpa/issues/496>`_ for more details.\n\nUsing Alpa on Slurm\n###################\nSince Alpa relies on Ray to manage the cluster nodes, Alpa can run on a Slurm cluster as long as Ray can run on it.\nIf you have trouble running Alpa on a Slurm cluster, we recommend to follow `this guide <https://docs.ray.io/en/latest/cluster/slurm.html>`__ to setup Ray on Slurm and make sure simple Ray examples\ncan run without any problem, then move forward to install and run Alpa in the same environment.\n\nCommon issues of running Alpa on Slurm include:\n\n- The Slurm cluster has installed additional networking proxies, so XLA client connections time out. Example errors can be found in `this thread <https://github.com/alpa-projects/alpa/issues/452#issuecomment-1134260817>`_.\n  The slurm cluster users might need to check and fix those proxies on their slurm cluster and make sure processes spawned by Alpa can see each other.\n\n- When launching a Slurm job using ``SRUN``, the users do not request enough CPU threads or GPU resources for Ray to spawn many actors on Slurm.\n  The users need to adjust the value for the argument ``--cpus-per-task`` passed to ``SRUN`` when launching Alpa. See `Slurm documentation <https://slurm.schedmd.com/srun.html>`_ for more information.\n\nYou might also find the discussion under `Issue #452 <https://github.com/alpa-projects/alpa/issues/452>`__ helpful.\n\nJaxlib, Jax, Flax Version Problems\n##################################\nAlpa is only tested against specific versions of Jax and Flax.\nThe recommended Jax and Flax versions are specified by ``install_require_list`` in `setup.py <https://github.com/alpa-projects/alpa/blob/main/setup.py>`_ .\n(You can checkout the file to specific version tag if you are not using the latest HEAD.)\n\nIf you see version errors like below\n\n.. 
code:: bash\n\n  >>> import alpa\n    ......\n    RuntimeError: jaxlib version 0.3.7 is newer than and incompatible with jax version 0.3.5. Please update your jax and/or jaxlib packages\n\nMake sure your Jax, Flax and Optax/Chex versions are compatible with the versions specified in Alpa's ``setup.py``.\nMake sure you re-install **Alpa-modified Jaxlib** by either using :ref:`our prebuilt wheels<install-from-wheels>` or :ref:`Install from Source<install-from-source>` to overwrite the default Jaxlib.\n\nNumpy Version Problems\n#######################\nIf you start with a clean Python virtual environment and have followed the procedures in this guide strictly, you should not see problems about Numpy versions.\n\nHowever, sometimes due to the installation of other Python packages, another version of numpy might be silently installed before compiling jaxlib,\nand you might see numpy version errors similar to the following one when launching Alpa after installing from source:\n\n.. code:: bash\n\n  >>> python3 tests/test_install.py\n    ......\n    RuntimeError: module compiled against API version 0xf but this version of numpy is 0xd\n    ImportError: numpy.core._multiarray_umath failed to import\n    ImportError: numpy.core.umath failed to import\n    2022-05-20 21:57:35.710782: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:83] Check failed: tensorflow::RegisterNumpyBfloat16()\n    Aborted (core dumped)\n\nThis is because you have used a higher version of numpy when compiling jaxlib, but later used a lower version of numpy to run Alpa.\n\nTo address the problem, please first downgrade the numpy in your Python environment to ``numpy==1.20`` via ``pip install numpy==1.20``,\nthen follow the procedures in :ref:`install from source<install-from-source>` to rebuild and reinstall jaxlib.\nOptionally, you can switch back to use the higher version of numpy (``numpy>=1.20``) to run Alpa and your other applications, thanks to numpy's backward 
compatibility.\n\nSee `Issue #461 <https://github.com/alpa-projects/alpa/issues/461>`_ for more discussion.\n\nTests Hang with no Errors on Multi-GPU Nodes\n############################################\nThis could be an indication that IO virtualization (VT-d, or IOMMU) is interfering with the NCCL library. On multi-GPU systems, PCI point-to-point traffic can be redirected to the CPU by these systems causing performance reductions or programs to hang. These settings can typically be disabled from the BIOS, or sometimes from the OS. You can find more information on NVIDIA's NCCL troubleshooting guide `here <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html>`_. Note that disabling IO virtualization can introduce security vulnerabilities, with peripherals having read/write access to DRAM through the DMA (Direct Memory Access) protocol.\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\r\n\r\npushd %~dp0\r\n\r\nREM Command file for Sphinx documentation\r\n\r\nif \"%SPHINXBUILD%\" == \"\" (\r\n\tset SPHINXBUILD=sphinx-build\r\n)\r\nset SOURCEDIR=.\r\nset BUILDDIR=_build\r\n\r\nif \"%1\" == \"\" goto help\r\n\r\n%SPHINXBUILD% >NUL 2>NUL\r\nif errorlevel 9009 (\r\n\techo.\r\n\techo.The 'sphinx-build' command was not found. Make sure you have Sphinx\r\n\techo.installed, then set the SPHINXBUILD environment variable to point\r\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\r\n\techo.may add the Sphinx directory to PATH.\r\n\techo.\r\n\techo.If you don't have Sphinx installed, grab it from\r\n\techo.http://sphinx-doc.org/\r\n\texit /b 1\r\n)\r\n\r\n%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\ngoto end\r\n\r\n:help\r\n%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%\r\n\r\n:end\r\npopd\r\n"
  },
  {
    "path": "docs/publications/publications.rst",
    "content": "Publications\n============\n\nAlpa is developed as a research project with collaborators from multiple institutions.\nThis page includes references to publications describing the ideas behind Alpa.\n\n| `Alpa: Automating Inter- and Intra-Operator Parallelism for Distributed Deep Learning <https://arxiv.org/abs/2201.12023>`_\n| Lianmin Zheng*, Zhuohan Li*, Hao Zhang*, Yonghao Zhuang, Zhifeng Chen, Yanping Huang, Yida Wang, Yuanzhong Xu, Danyang Zhuo, Eric P. Xing, Joseph E. Gonzalez, Ion Stoica\n| *OSDI 2022*\n| \n| `On Optimizing the Communication of Model Parallelism <https://arxiv.org/abs/2211.05322>`_\n| Yonghao Zhuang*, Hexu Zhao*, Lianmin Zheng, Zhuohan Li, Eric P. Xing, Qirong Ho, Joseph E. Gonzalez, Ion Stoica, Hao Zhang\n| *MLSys 2023*\n| \n| `AlpaServe: Statistical Multiplexing with Model Parallelism for Deep Learning Serving <https://arxiv.org/abs/2302.11665>`_\n| Zhuohan Li*, Lianmin Zheng*, Yinmin Zhong*, Vincent Liu, Ying Sheng, Xin Jin, Yanping Huang, Zhifeng Chen, Hao Zhang, Joseph E. Gonzalez, Ion Stoica\n| *OSDI 2023*\n"
  },
  {
    "path": "docs/publish.py",
    "content": "#!/usr/bin/python3\n\nimport os\nfrom datetime import datetime\n\n\ndef run_cmd(cmd):\n    print(cmd)\n    os.system(cmd)\n\n\nrun_cmd(f\"cd $ALPA_SITE_PATH; git pull\")\n\n# (Optional) Remove old files\n# run_cmd(\"rm -rf $ALPA_SITE_PATH/*\")\n\nrun_cmd(\"cp -r _build/html/* $ALPA_SITE_PATH\")\n\ncmd_message = f\"Archive {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\"\nrun_cmd(\n    f\"cd $ALPA_SITE_PATH; git add .; git commit -m '{cmd_message}'; git push origin master\"\n)\n"
  },
  {
    "path": "examples/ViT/README.md",
    "content": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\nAdopted from https://github.com/huggingface/transformers/tree/main/examples/flax/vision\n\nUse `alpa.parallelize` to parallelize the training loop.\n\n# Image Classification training examples\n\nThe following example showcases how to train/fine-tune `ViT` for image-classification using the JAX/Flax backend.\n\nJAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.\nModels written in JAX/Flax are **immutable** and updated in a purely functional\nway which enables simple and efficient model parallelism.\n\n\nIn this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.\n\n## Prepare the dataset\n\nWe will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).\n\n\n### Download and extract the data.\n\n```bash\nwget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz\ntar -xvzf imagenette2.tgz\n```\n\nThis will create a `imagenette2` dir with two subdirectories `train` and `val` each with multiple subdirectories per class. 
The training script expects the following directory structure:\n\n```bash\nroot/dog/xxx.png\nroot/dog/xxy.png\nroot/dog/[...]/xxz.png\n\nroot/cat/123.png\nroot/cat/nsdf3.png\nroot/cat/[...]/asd932_.png\n```\n\n### Train the model\n\nFinally, we can run the example script to pretrain the model:\n\n#### Launch a Ray cluster\n1. Use the command below to launch Ray on a head node  \n  ```ray start --head```\n2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the IP address and password printed by the previous command.   \n  ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'```\n\n#### Run\n```bash\npython run_image_classification.py \\\n    --output_dir ./vit-base-patch16-imagenette \\\n    --model_name_or_path google/vit-base-patch16-224-in21k \\\n    --train_dir=\"imagenette2/train\" \\\n    --validation_dir=\"imagenette2/val\" \\\n    --num_train_epochs 5 \\\n    --num_micro_batches 2 \\\n    --learning_rate 1e-3 \\\n    --per_device_train_batch_size 64 \\\n    --per_device_eval_batch_size 64 \\\n    --overwrite_output_dir \\\n    --preprocessing_num_workers 32\n```\nTraining should converge at a loss of 0.0614 and validation accuracy of ~98% after 5 epochs. This should take ~7 minutes on a single machine with 2 P100 GPUs. Training statistics can be accessed on https://tensorboard.dev/experiment/3Vz06C4xQKaqaHENFeIrGg/"
  },
  {
    "path": "examples/ViT/run_image_classification.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPre-training/Fine-tuning ViT for image classification .\nHere is the full list of checkpoints on the hub that can be fine-tuned by this script:\nhttps://huggingface.co/models?filter=vit\n\"\"\"\n\nimport logging\nimport os\nimport sys\nimport time\nfrom dataclasses import asdict, dataclass, field\nfrom enum import Enum\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\n# for dataset and preprocessing\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nfrom tqdm import tqdm\n\nimport alpa\nfrom alpa.model.model_util import TrainState\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport transformers\nfrom flax.training.common_utils import onehot\nfrom huggingface_hub import Repository\nfrom transformers import (\n    CONFIG_MAPPING,\n    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,\n    AutoConfig,\n    FlaxAutoModelForImageClassification,\n    HfArgumentParser,\n    is_tensorboard_available,\n    set_seed,\n)\nfrom transformers.utils import get_full_repo_name, send_example_telemetry\n\nalpa.init(cluster=\"ray\")\nlogger = logging.getLogger(__name__)\n\n\nMODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())\nMODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n\n\n@dataclass\nclass 
TrainingArguments:\n    output_dir: str = field(\n        metadata={\"help\": \"The output directory where the model predictions and checkpoints will be written.\"},\n    )\n    overwrite_output_dir: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Overwrite the content of the output directory. \"\n                \"Use this to continue training if output_dir points to a checkpoint directory.\"\n            )\n        },\n    )\n    do_train: bool = field(default=False, metadata={\"help\": \"Whether to run training.\"})\n    do_eval: bool = field(default=False, metadata={\"help\": \"Whether to run eval on the dev set.\"})\n    num_micro_batches: int = field(default=1, metadata={\"help\": \"The number of micro batches for gradient accumulation.\"})\n    per_device_train_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for training.\"}\n    )\n    per_device_eval_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for evaluation.\"}\n    )\n    learning_rate: float = field(default=5e-5, metadata={\"help\": \"The initial learning rate for AdamW.\"})\n    weight_decay: float = field(default=0.0, metadata={\"help\": \"Weight decay for AdamW if we apply some.\"})\n    adam_beta1: float = field(default=0.9, metadata={\"help\": \"Beta1 for AdamW optimizer\"})\n    adam_beta2: float = field(default=0.999, metadata={\"help\": \"Beta2 for AdamW optimizer\"})\n    adam_epsilon: float = field(default=1e-8, metadata={\"help\": \"Epsilon for AdamW optimizer.\"})\n    adafactor: bool = field(default=False, metadata={\"help\": \"Whether or not to replace AdamW by Adafactor.\"})\n    num_train_epochs: float = field(default=3.0, metadata={\"help\": \"Total number of training epochs to perform.\"})\n    warmup_steps: int = field(default=0, metadata={\"help\": \"Linear warmup over warmup_steps.\"})\n    logging_steps: int = 
field(default=500, metadata={\"help\": \"Log every X updates steps.\"})\n    save_steps: int = field(default=500, metadata={\"help\": \"Save checkpoint every X updates steps.\"})\n    eval_steps: int = field(default=None, metadata={\"help\": \"Run an evaluation every X steps.\"})\n    seed: int = field(default=42, metadata={\"help\": \"Random seed that will be set at the beginning of training.\"})\n    push_to_hub: bool = field(\n        default=False, metadata={\"help\": \"Whether or not to upload the trained model to the model hub after training.\"}\n    )\n    hub_model_id: str = field(\n        default=None, metadata={\"help\": \"The name of the repository to keep in sync with the local `output_dir`.\"}\n    )\n    hub_token: str = field(default=None, metadata={\"help\": \"The token to use to push to the Model Hub.\"})\n\n    def __post_init__(self):\n        if self.output_dir is not None:\n            self.output_dir = os.path.expanduser(self.output_dir)\n\n    def to_dict(self):\n        \"\"\"\n        Serializes this instance while replace `Enum` by their values (for JSON serialization support). 
It obfuscates\n        the token values by removing their value.\n        \"\"\"\n        d = asdict(self)\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n    \"\"\"\n\n    model_name_or_path: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch.\"\n            )\n        },\n    )\n    model_type: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n    )\n    config_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n    )\n    cache_dir: Optional[str] = field(\n        default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n    )\n    dtype: Optional[str] = field(\n        default=\"float32\",\n        metadata={\n            \"help\": (\n                \"Floating-point format in which the model weights should be initialized and trained. 
Choose one of\"\n                \" `[float32, float16, bfloat16]`.\"\n            )\n        },\n    )\n    use_auth_token: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Will use the token generated when running `huggingface-cli login` (necessary to use this script \"\n                \"with private models).\"\n            )\n        },\n    )\n\n@dataclass\nclass DataTrainingArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    train_dir: str = field(\n        metadata={\"help\": \"Path to the root training directory which contains one subdirectory per class.\"}\n    )\n    validation_dir: str = field(\n        metadata={\"help\": \"Path to the root validation directory which contains one subdirectory per class.\"},\n    )\n    image_size: Optional[int] = field(default=224, metadata={\"help\": \" The size (resolution) of each image.\"})\n    max_train_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    max_eval_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n    )\n\n\ndef write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):\n    
summary_writer.scalar(\"train_time\", train_time, step)\n\n    train_metrics = alpa.util.get_metrics(train_metrics)\n    for key, vals in train_metrics.items():\n        tag = f\"train_{key}\"\n        for i, val in enumerate(vals):\n            summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n\n    for metric_name, value in eval_metrics.items():\n        summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n\n\ndef create_learning_rate_fn(\n    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n) -> Callable[[int], jnp.array]:\n    \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n    steps_per_epoch = train_ds_size // train_batch_size\n    num_train_steps = steps_per_epoch * num_train_epochs\n    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n    decay_fn = optax.linear_schedule(\n        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n    )\n    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n    return schedule_fn\n\n\ndef main():\n    # See all possible arguments in src/transformers/training_args.py\n    # or by passing the --help flag to this script.\n    # We now keep distinct sets of args, for a cleaner separation of concerns.\n\n    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))\n    if len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n        # If we pass only one argument to the script and it's the path to a json file,\n        # let's parse it to get our arguments.\n        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))\n    else:\n        model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    # Sending telemetry. 
Tracking the example usage helps us better allocate resources to maintain them. The\n    # information sent is the one passed as arguments along with your Python/PyTorch versions.\n    send_example_telemetry(\"run_image_classification\", model_args, data_args, framework=\"flax\")\n\n    if (\n        os.path.exists(training_args.output_dir)\n        and os.listdir(training_args.output_dir)\n        and training_args.do_train\n        and not training_args.overwrite_output_dir\n    ):\n        raise ValueError(\n            f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n            \"Use --overwrite_output_dir to overcome.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)\n    if jax.process_index() == 0:\n        transformers.utils.logging.set_verbosity_info()\n    else:\n        transformers.utils.logging.set_verbosity_error()\n\n    # Set the verbosity to info of the Transformers logger (on main process only):\n    logger.info(f\"Training/evaluation parameters {training_args}\")\n\n    # set seed for random transforms and torch dataloaders\n    set_seed(training_args.seed)\n\n    # Handle the repository creation\n    if training_args.push_to_hub:\n        if training_args.hub_model_id is None:\n            repo_name = get_full_repo_name(\n                Path(training_args.output_dir).absolute().name, token=training_args.hub_token\n            )\n        else:\n            repo_name = training_args.hub_model_id\n        repo = Repository(training_args.output_dir, clone_from=repo_name)\n\n    # Initialize datasets and pre-processing transforms\n  
  # We use torchvision here for faster pre-processing\n    # Note that here we are using some default pre-processing, for maximum accuray\n    # one should tune this part and carefully select what transformations to use.\n    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n    train_dataset = torchvision.datasets.ImageFolder(\n        data_args.train_dir,\n        transforms.Compose(\n            [\n                transforms.RandomResizedCrop(data_args.image_size),\n                transforms.RandomHorizontalFlip(),\n                transforms.ToTensor(),\n                normalize,\n            ]\n        ),\n    )\n\n    eval_dataset = torchvision.datasets.ImageFolder(\n        data_args.validation_dir,\n        transforms.Compose(\n            [\n                transforms.Resize(data_args.image_size),\n                transforms.CenterCrop(data_args.image_size),\n                transforms.ToTensor(),\n                normalize,\n            ]\n        ),\n    )\n\n    # Load pretrained model and tokenizer\n    if model_args.config_name:\n        config = AutoConfig.from_pretrained(\n            model_args.config_name,\n            num_labels=len(train_dataset.classes),\n            image_size=data_args.image_size,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    elif model_args.model_name_or_path:\n        config = AutoConfig.from_pretrained(\n            model_args.model_name_or_path,\n            num_labels=len(train_dataset.classes),\n            image_size=data_args.image_size,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        config = CONFIG_MAPPING[model_args.model_type]()\n        logger.warning(\"You are instantiating a new config instance from scratch.\")\n\n    if model_args.model_name_or_path:\n        model = 
FlaxAutoModelForImageClassification.from_pretrained(\n            model_args.model_name_or_path,\n            config=config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        model = FlaxAutoModelForImageClassification.from_config(\n            config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n        )\n\n    # Store some constant\n    num_epochs = int(training_args.num_train_epochs)\n    train_batch_size = int(training_args.per_device_train_batch_size) * alpa.get_global_num_devices()\n    eval_batch_size = int(training_args.per_device_eval_batch_size) * alpa.get_global_num_devices()\n    steps_per_epoch = len(train_dataset) // train_batch_size\n    total_train_steps = steps_per_epoch * num_epochs\n\n    def collate_fn(examples):\n        pixel_values = torch.stack([example[0] for example in examples])\n        labels = torch.tensor([example[1] for example in examples])\n\n        batch = {\"pixel_values\": pixel_values, \"labels\": labels}\n        batch = {k: v.numpy() for k, v in batch.items()}\n\n        return batch\n\n    # Create data loaders\n    train_loader = torch.utils.data.DataLoader(\n        train_dataset,\n        batch_size=train_batch_size,\n        shuffle=True,\n        num_workers=data_args.preprocessing_num_workers,\n        persistent_workers=True,\n        drop_last=True,\n        collate_fn=collate_fn,\n    )\n\n    eval_loader = torch.utils.data.DataLoader(\n        eval_dataset,\n        batch_size=eval_batch_size,\n        shuffle=False,\n        num_workers=data_args.preprocessing_num_workers,\n        persistent_workers=True,\n        drop_last=False,\n        collate_fn=collate_fn,\n    )\n\n    # Enable tensorboard only on the master node\n    has_tensorboard = is_tensorboard_available()\n    if has_tensorboard and jax.process_index() == 0:\n    
    try:\n            from flax.metrics.tensorboard import SummaryWriter\n\n            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n        except ImportError as ie:\n            has_tensorboard = False\n            logger.warning(\n                f\"Unable to display metrics through TensorBoard because some package are not installed: {ie}\"\n            )\n    else:\n        logger.warning(\n            \"Unable to display metrics through TensorBoard because the package is not installed: \"\n            \"Please run pip install tensorboard to enable.\"\n        )\n\n    # Initialize our training\n    rng = jax.random.PRNGKey(training_args.seed)\n    rng, dropout_rng = jax.random.split(rng)\n\n    # Create learning rate schedule\n    linear_decay_lr_schedule_fn = create_learning_rate_fn(\n        len(train_dataset),\n        train_batch_size,\n        training_args.num_train_epochs,\n        training_args.warmup_steps,\n        training_args.learning_rate,\n    )\n\n    # create adam optimizer\n    adamw = optax.adamw(\n        learning_rate=linear_decay_lr_schedule_fn,\n        b1=training_args.adam_beta1,\n        b2=training_args.adam_beta2,\n        eps=training_args.adam_epsilon,\n        weight_decay=training_args.weight_decay,\n    )\n\n    # Setup train state\n    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dynamic_scale=None)\n\n    def loss_fn(logits, labels):\n        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))\n        return loss.mean()\n\n    # Define gradient update step fn\n    def train_step(state, batch):\n\n        def compute_loss(params):\n            labels = batch.pop(\"labels\")\n            logits = state.apply_fn(**batch, params=params, train=True)[0]\n            loss = loss_fn(logits, labels)\n            return loss\n\n        grad_fn = alpa.value_and_grad(compute_loss)\n        loss, grad = grad_fn(state.params)\n        new_state = 
state.apply_gradients(grads=grad)\n\n        metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}\n\n        return new_state, metrics\n\n    # Define eval fn\n    def eval_step(params, batch):\n        labels = batch.pop(\"labels\")\n        logits = model(**batch, params=params, train=False)[0]\n        loss = loss_fn(logits, labels)\n\n        # summarize metrics\n        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()\n        metrics = {\"loss\": loss, \"accuracy\": accuracy}\n        return metrics\n\n    # Create parallel version of the train and eval step\n    method = alpa.Zero2Parallel()\n    p_train_step = alpa.parallelize(train_step,\n                                    method=method,\n                                    donate_argnums=(0,))\n    p_eval_step = alpa.parallelize(eval_step)\n    dump_debug_info_train_step = dump_debug_info_eval_step = True\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {num_epochs}\")\n    logger.info(f\"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}\")\n    logger.info(f\"  Total train batch size (w. parallel & distributed) = {train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {total_train_steps}\")\n\n    train_time = 0\n    last_time = time.time()\n    epochs = tqdm(range(num_epochs), desc=f\"Epoch ... 
(1/{num_epochs})\", position=0)\n\n    for epoch in epochs:\n        # ======================== Training ================================\n        train_start = time.time()\n\n        # Create sampling rng\n        rng, input_rng = jax.random.split(rng)\n        train_metrics = []\n\n        steps_per_epoch = len(train_dataset) // train_batch_size\n        train_step_progress_bar = tqdm(total=steps_per_epoch, desc=\"Training...\", position=1, leave=False)\n        # train\n        for step, batch in enumerate(train_loader):\n            state, train_metric = p_train_step(state, batch)\n            train_metrics.append(train_metric)\n\n            cur_step = epoch * (len(train_dataset) // train_batch_size) + step\n\n            if dump_debug_info_train_step:\n                dump_debug_info_train_step = False\n                executable = p_train_step.get_last_executable()\n                executable.sync()\n                executable.dump_debug_info(\"alpa_debug_info\")\n                epochs.write(f\"Initial compilation completed. \"\n                             f\"Time elapsed: {time.time() - train_start:.2f} s\")\n                             \n            train_step_progress_bar.update(1)\n\n        latency = time.time() - last_time\n        images_per_second = len(train_dataset) / latency\n        train_time += time.time() - train_start\n        last_time = time.time()\n\n        train_step_progress_bar.close()\n        epochs.write(\n            f\"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:\"\n            f\" {train_metric['learning_rate']}), \"\n            f\"Throughput: {images_per_second:.2f} images/s\"\n        )\n\n        # ======================== Evaluating ==============================\n        eval_metrics = []\n        eval_steps = max(len(eval_dataset) // eval_batch_size, 1)\n        eval_step_progress_bar = tqdm(total=eval_steps, desc=\"Evaluating...\", position=2, leave=False)\n        for batch in eval_loader:\n            # Model forward\n            metrics = p_eval_step(state.params, batch)\n            eval_metrics.append(metrics)\n\n            if dump_debug_info_eval_step:\n                dump_debug_info_eval_step = False\n                executable = p_eval_step.get_last_executable()\n                executable.dump_debug_info(\"alpa_debug_info\")\n\n            eval_step_progress_bar.update(1)\n\n        # normalize eval metrics\n        eval_metrics = alpa.util.get_metrics(eval_metrics)\n        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n\n        # Print metrics and update progress bar\n        eval_step_progress_bar.close()\n        desc = (\n            f\"Epoch... 
({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | \"\n            f\"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})\"\n        )\n        epochs.write(desc)\n        epochs.desc = desc\n\n        # Save metrics\n        if has_tensorboard and jax.process_index() == 0:\n            cur_step = epoch * (len(train_dataset) // train_batch_size)\n            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)\n\n        # save checkpoint after each epoch and push checkpoint to the hub\n        if jax.process_index() == 0:\n            alpa.prefetch(state.params)\n            params = alpa.util.map_to_nparray(state.params)\n            model.save_pretrained(training_args.output_dir, params=params)\n            if training_args.push_to_hub:\n                repo.push_to_hub(commit_message=f\"Saving weights and logs of step {cur_step}\", blocking=False)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "examples/__init__.py",
    "content": ""
  },
  {
    "path": "examples/gpt2/README.md",
    "content": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n--------------------------------------------------------------------------------\n\nAdapted from https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling\n\nUse `alpa.parallelize` to parallelize the training loop.\n\n--------------------------------------------------------------------------------\n\n# Language model training examples\n\nThe following example showcases how to train a language model from scratch \nusing the JAX/Flax backend.\n\nJAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.\nModels written in JAX/Flax are **immutable** and updated in a purely functional\nway which enables simple and efficient model parallelism.\n\n## Causal language modeling\n\nIn the following, we demonstrate how to train an auto-regressive causal transformer model \nin JAX/Flax.\nMore specifically, we pretrain a randomly initialized 124M-parameter [**`gpt2`**](https://huggingface.co/gpt2) model\nin Norwegian.\n\nThe example script uses the 🤗 Datasets library. 
You can easily customize the script to your needs if you need extra processing on your datasets.\n\n\nTo set up all relevant files for training, let's create a directory.\n\n```bash\nmkdir ./norwegian-gpt2\n```\n\n### Train tokenizer\n\nIn the first step, we train a tokenizer to efficiently process the text input for the model. Similar to how it is shown in [How to train a new language model from scratch using Transformers and Tokenizers](https://huggingface.co/blog/how-to-train), we use a **`ByteLevelBPETokenizer`**.\nThe tokenizer is trained on the complete Norwegian dataset of OSCAR\nand subsequently saved in the model directory created above.\nThis can take up to 10 minutes depending on your hardware ☕.\n\n```python\nfrom datasets import load_dataset\nfrom tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n\n# load dataset\ndataset = load_dataset(\"oscar\", \"unshuffled_deduplicated_no\", split=\"train\")\n\n# Instantiate tokenizer\ntokenizer = ByteLevelBPETokenizer()\n\ndef batch_iterator(batch_size=1000):\n    for i in range(0, len(dataset), batch_size):\n        yield dataset[i: i + batch_size][\"text\"]\n\n# Customized training\ntokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[\n    \"<s>\",\n    \"<pad>\",\n    \"</s>\",\n    \"<unk>\",\n    \"<mask>\",\n])\n\n# Save files to disk\ntokenizer.save(\"./norwegian-gpt2/tokenizer.json\")\n```\n\n### Create configuration\n\nNext, we create the model's configuration file. This is as simple \nas loading and storing [**`gpt2`**](https://huggingface.co/gpt2)\nin the local model folder:\n\n```python\nfrom transformers import GPT2Config\n\nconfig = GPT2Config.from_pretrained(\"gpt2\", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256)\nconfig.save_pretrained(\"./norwegian-gpt2\")\n```\n\nGreat, we have set up our model repository. 
During training, we will now automatically\npush the training logs and model weights to the repo.\n\n### Train model\n\nFinally, we can run the example script to pretrain the model:\n\n#### Launch a Ray cluster\n1. Use the command below to launch Ray on a head node  \n  ```ray start --head```\n2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the IP address and password printed by the previous command.   \n  ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'```\n\n#### Run\n```bash\npython3 run_clm_flax.py \\\n    --output_dir=\"./norwegian-gpt2\" \\\n    --model_type=\"gpt2\" \\\n    --config_name=\"./norwegian-gpt2\" \\\n    --tokenizer_name=\"./norwegian-gpt2\" \\\n    --dataset_name=\"oscar\" \\\n    --dataset_config_name=\"unshuffled_deduplicated_no\" \\\n    --do_train --do_eval \\\n    --block_size=\"512\" \\\n    --per_device_train_batch_size=\"96\" \\\n    --per_device_eval_batch_size=\"96\" \\\n    --num_micro_batches=\"4\" \\\n    --dtype=\"float16\" \\\n    --learning_rate=\"1e-3\" --warmup_steps=\"1000\" \\\n    --adam_beta1=\"0.9\" --adam_beta2=\"0.98\" --weight_decay=\"0.01\" \\\n    --overwrite_output_dir \\\n    --num_train_epochs=\"20\" \\\n    --logging_steps=\"100\" \\\n    --save_steps=\"2500\" \\\n    --eval_steps=\"2500\"\n```\n\nTraining should converge to a loss and perplexity \nof 3.24 and 25.72 respectively after 20 epochs.\nThis should take less than ~21 hours on a single TPUv3-8 or a machine with 8 V100 GPUs.\nTraining statistics can be accessed on [tensorboard.dev](https://tensorboard.dev/experiment/2zEhLwJ0Qp2FAkI3WVH9qA).\n\nFor a step-by-step walkthrough of how to do causal language modeling in Flax, please have a \nlook at [this](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/causal_language_modeling_flax.ipynb) Google Colab.\n"
  },
  {
    "path": "examples/gpt2/create_config.py",
    "content": "from transformers import GPT2Config\n\nconfig = GPT2Config.from_pretrained(\"gpt2\", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50256)\nconfig.save_pretrained(\"./norwegian-gpt2\")\n"
  },
  {
    "path": "examples/gpt2/run_clm_flax.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n\nHere is the full list of checkpoints on the hub that can be fine-tuned by this script:\nhttps://huggingface.co/models?filter=text-generation\n\"\"\"\n# You can also adapt this script on your own causal language modeling task. 
Pointers for this are left as comments.\n\nimport json\nimport logging\nimport math\nimport os\nimport sys\nimport time\nfrom dataclasses import asdict, dataclass, field\nfrom enum import Enum\nimport functools\nfrom itertools import chain\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\nimport datasets\nimport numpy as np\nfrom datasets import Dataset, load_dataset\nfrom tqdm import tqdm\n\nimport alpa\nfrom alpa.model.model_util import DynamicScale, TrainState\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport transformers\nimport tensorflow as tf\nfrom flax import jax_utils, traverse_util\nfrom flax.training import train_state\nfrom flax.training.common_utils import onehot, shard, shard_prng_key\nfrom huggingface_hub import Repository\nfrom transformers import (\n    CONFIG_MAPPING,\n    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n    AutoConfig,\n    AutoTokenizer,\n    FlaxAutoModelForCausalLM,\n    HfArgumentParser,\n    is_tensorboard_available,\n    set_seed,\n)\nfrom transformers.testing_utils import CaptureLogger\nfrom transformers.utils import get_full_repo_name, send_example_telemetry\n\nalpa.init(cluster=\"ray\")\ntf.config.experimental.set_visible_devices([], 'GPU')\n\nlogger = logging.getLogger(__name__)\n\nMODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\nMODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n\n\n@dataclass\nclass TrainingArguments:\n    output_dir: str = field(\n        metadata={\"help\": \"The output directory where the model predictions and checkpoints will be written.\"},\n    )\n    overwrite_output_dir: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Overwrite the content of the output directory. 
\"\n                \"Use this to continue training if output_dir points to a checkpoint directory.\"\n            )\n        },\n    )\n    do_train: bool = field(default=False, metadata={\"help\": \"Whether to run training.\"})\n    do_eval: bool = field(default=False, metadata={\"help\": \"Whether to run eval on the dev set.\"})\n    per_device_train_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for training.\"}\n    )\n    per_device_eval_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for evaluation.\"}\n    )\n    num_micro_batches: int = field(default=1, metadata={\"help\": \"The number of micro batches for gradient accumulation.\"})\n    learning_rate: float = field(default=5e-5, metadata={\"help\": \"The initial learning rate for AdamW.\"})\n    weight_decay: float = field(default=0.0, metadata={\"help\": \"Weight decay for AdamW if we apply some.\"})\n    adam_beta1: float = field(default=0.9, metadata={\"help\": \"Beta1 for AdamW optimizer\"})\n    adam_beta2: float = field(default=0.999, metadata={\"help\": \"Beta2 for AdamW optimizer\"})\n    adam_epsilon: float = field(default=1e-8, metadata={\"help\": \"Epsilon for AdamW optimizer.\"})\n    adafactor: bool = field(default=False, metadata={\"help\": \"Whether or not to replace AdamW by Adafactor.\"})\n    num_train_epochs: float = field(default=3.0, metadata={\"help\": \"Total number of training epochs to perform.\"})\n    warmup_steps: int = field(default=0, metadata={\"help\": \"Linear warmup over warmup_steps.\"})\n    logging_steps: int = field(default=500, metadata={\"help\": \"Log every X updates steps.\"})\n    save_steps: int = field(default=500, metadata={\"help\": \"Save checkpoint every X updates steps.\"})\n    eval_steps: int = field(default=None, metadata={\"help\": \"Run an evaluation every X steps.\"})\n    seed: int = field(default=42, metadata={\"help\": \"Random seed that will be 
set at the beginning of training.\"})\n    push_to_hub: bool = field(\n        default=False, metadata={\"help\": \"Whether or not to upload the trained model to the model hub after training.\"}\n    )\n    hub_model_id: str = field(\n        default=None, metadata={\"help\": \"The name of the repository to keep in sync with the local `output_dir`.\"}\n    )\n    hub_token: str = field(default=None, metadata={\"help\": \"The token to use to push to the Model Hub.\"})\n\n    def __post_init__(self):\n        if self.output_dir is not None:\n            self.output_dir = os.path.expanduser(self.output_dir)\n\n    def to_dict(self):\n        \"\"\"\n        Serializes this instance while replace `Enum` by their values (for JSON serialization support). It obfuscates\n        the token values by removing their value.\n        \"\"\"\n        d = asdict(self)\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n    \"\"\"\n\n    model_name_or_path: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch.\"\n            )\n        },\n    )\n    model_type: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n    )\n    config_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n    )\n    
tokenizer_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n    )\n    cache_dir: Optional[str] = field(\n        default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n    )\n    use_fast_tokenizer: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n    )\n    dtype: Optional[str] = field(\n        default=\"float32\",\n        metadata={\n            \"help\": (\n                \"Floating-point format in which the model weights should be initialized and trained. Choose one of\"\n                \" `[float32, float16, bfloat16]`.\"\n            )\n        },\n    )\n    use_auth_token: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Will use the token generated when running `transformers-cli login` (necessary to use this script \"\n                \"with private models).\"\n            )\n        },\n    )\n\n\n@dataclass\nclass DataTrainingArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    dataset_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n    )\n    dataset_config_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n    )\n    train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n    validation_file: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n    )\n    max_train_samples: Optional[int] = field(\n        
default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    max_eval_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n    )\n    validation_split_percentage: Optional[int] = field(\n        default=5,\n        metadata={\n            \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n        },\n    )\n    block_size: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"Optional input sequence length after tokenization. \"\n                \"The training dataset will be truncated in block of this size for training. 
\"\n                \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n            )\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n    )\n    keep_linebreaks: bool = field(\n        default=True, metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n    )\n\n    def __post_init__(self):\n        if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n            raise ValueError(\"Need either a dataset name or a training/validation file.\")\n        else:\n            if self.train_file is not None:\n                extension = self.train_file.split(\".\")[-1]\n                assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n            if self.validation_file is not None:\n                extension = self.validation_file.split(\".\")[-1]\n                assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n\n\ndef data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int,\n                min_batch_size: int, shuffle: bool = False):\n    \"\"\"\n    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.\n    Shuffle batches if `shuffle` is `True`.\n    \"\"\"\n    if len(dataset) < batch_size:\n        assert len(dataset) >= min_batch_size\n        batch_size = len(dataset) // min_batch_size * min_batch_size\n\n    data_collator = transformers.DefaultDataCollator(\"np\")\n    tf_dataset = dataset.to_tf_dataset(batch_size=batch_size,\n                                       
columns=dataset.column_names,\n                                       collate_fn=data_collator,\n                                       shuffle=shuffle,\n                                       drop_remainder=True)\n\n    for batch in tf_dataset:\n        batch = {k: v._numpy() for k, v in batch.items()}\n        yield batch\n\n\ndef write_train_metric(summary_writer, train_metrics, train_time, step):\n    summary_writer.scalar(\"train_time\", train_time, step)\n\n    train_metrics = alpa.util.get_metrics(train_metrics)\n    for key, vals in train_metrics.items():\n        tag = f\"train_{key}\"\n        for i, val in enumerate(vals):\n            summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n\n\ndef write_eval_metric(summary_writer, eval_metrics, step):\n    for metric_name, value in eval_metrics.items():\n        summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n\n\ndef create_learning_rate_fn(\n    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n) -> Callable[[int], jnp.array]:\n    \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n    steps_per_epoch = train_ds_size // train_batch_size\n    num_train_steps = steps_per_epoch * num_train_epochs\n    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n    decay_fn = optax.linear_schedule(\n        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n    )\n    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n    return schedule_fn\n\n\ndef main():\n    # See all possible arguments in src/transformers/training_args.py\n    # or by passing the --help flag to this script.\n    # We now keep distinct sets of args, for a cleaner separation of concerns.\n\n    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))\n    if 
len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n        # If we pass only one argument to the script and it's the path to a json file,\n        # let's parse it to get our arguments.\n        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))\n    else:\n        model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The\n    # information sent is the one passed as arguments along with your Python/PyTorch versions.\n    send_example_telemetry(\"run_clm\", model_args, data_args, framework=\"flax\")\n\n    if (\n        os.path.exists(training_args.output_dir)\n        and os.listdir(training_args.output_dir)\n        and training_args.do_train\n        and not training_args.overwrite_output_dir\n    ):\n        raise ValueError(\n            f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n            \"Use --overwrite_output_dir to overcome.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO)\n    datasets.utils.logging.set_verbosity_warning()\n    transformers.utils.logging.set_verbosity_info()\n\n    # Set the verbosity to info of the Transformers logger (on main process only):\n    logger.info(f\"Training/evaluation parameters {training_args}\")\n\n    # Set seed before initializing model.\n    set_seed(training_args.seed)\n\n    # Handle the repository creation\n    if training_args.push_to_hub:\n        if training_args.hub_model_id is None:\n            repo_name = get_full_repo_name(\n                
Path(training_args.output_dir).absolute().name, token=training_args.hub_token\n            )\n        else:\n            repo_name = training_args.hub_model_id\n        repo = Repository(training_args.output_dir, clone_from=repo_name)\n\n    #  Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n    # (the dataset will be downloaded automatically from the datasets Hub).\n    #\n    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n    # 'text' is found. You can easily tweak this behavior (see below).\n    #\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if data_args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            data_args.dataset_name,\n            data_args.dataset_config_name,\n            cache_dir=model_args.cache_dir,\n            keep_in_memory=False,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n\n        if \"validation\" not in dataset.keys():\n            dataset[\"validation\"] = load_dataset(\n                data_args.dataset_name,\n                data_args.dataset_config_name,\n                split=f\"train[:{data_args.validation_split_percentage}%]\",\n                cache_dir=model_args.cache_dir,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n            dataset[\"train\"] = load_dataset(\n                data_args.dataset_name,\n                data_args.dataset_config_name,\n                split=f\"train[{data_args.validation_split_percentage}%:]\",\n                cache_dir=model_args.cache_dir,\n                use_auth_token=True if 
model_args.use_auth_token else None,\n            )\n    else:\n        data_files = {}\n        dataset_args = {}\n        if data_args.train_file is not None:\n            data_files[\"train\"] = data_args.train_file\n        if data_args.validation_file is not None:\n            data_files[\"validation\"] = data_args.validation_file\n        extension = data_args.train_file.split(\".\")[-1]\n        if extension == \"txt\":\n            extension = \"text\"\n            dataset_args[\"keep_linebreaks\"] = data_args.keep_linebreaks\n        dataset = load_dataset(\n            extension,\n            data_files=data_files,\n            cache_dir=model_args.cache_dir,\n            **dataset_args,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n\n        if \"validation\" not in dataset.keys():\n            dataset[\"validation\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[:{data_args.validation_split_percentage}%]\",\n                cache_dir=model_args.cache_dir,\n                **dataset_args,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n            dataset[\"train\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[{data_args.validation_split_percentage}%:]\",\n                cache_dir=model_args.cache_dir,\n                **dataset_args,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n    # https://huggingface.co/docs/datasets/loading_datasets.html.\n\n    # Load pretrained model and tokenizer\n\n    # Distributed training:\n    # The .from_pretrained methods guarantee that only one local process can concurrently\n    # download model & vocab.\n    if 
model_args.config_name:\n        config = AutoConfig.from_pretrained(\n            model_args.config_name,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    elif model_args.model_name_or_path:\n        config = AutoConfig.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        config = CONFIG_MAPPING[model_args.model_type]()\n        logger.warning(\"You are instantiating a new config instance from scratch.\")\n\n    if model_args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_args.tokenizer_name,\n            cache_dir=model_args.cache_dir,\n            use_fast=model_args.use_fast_tokenizer,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    elif model_args.model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=model_args.cache_dir,\n            use_fast=model_args.use_fast_tokenizer,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        raise ValueError(\n            \"You are instantiating a new tokenizer from scratch. 
This is not supported by this script.\"\n            \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n        )\n\n    if model_args.model_name_or_path:\n        model = FlaxAutoModelForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            config=config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        model = FlaxAutoModelForCausalLM.from_config(\n            config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n        )\n\n    # Preprocessing the datasets.\n    # First we tokenize all the texts.\n    if training_args.do_train:\n        column_names = dataset[\"train\"].column_names\n    else:\n        column_names = dataset[\"validation\"].column_names\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n\n    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n    tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n\n    def tokenize_function(examples):\n        with CaptureLogger(tok_logger) as cl:\n            output = tokenizer(examples[text_column_name])\n        # clm input could be much much longer than block_size\n        if \"Token indices sequence length is longer than the\" in cl.out:\n            tok_logger.warning(\n                \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits\"\n                \" before being passed to the model.\"\n            )\n        return output\n\n    logger.info(\"***** Tokenize dataset *****\")\n    tokenized_datasets = dataset.map(\n        tokenize_function,\n        batched=True,\n        num_proc=data_args.preprocessing_num_workers,\n        remove_columns=column_names,\n 
       load_from_cache_file=not data_args.overwrite_cache,\n    )\n\n    if data_args.block_size is None:\n        block_size = tokenizer.model_max_length\n        if block_size > config.max_position_embeddings:\n            logger.warning(\n                f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n                \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n            )\n            block_size = 1024\n    else:\n        if data_args.block_size > tokenizer.model_max_length:\n            logger.warning(\n                f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n                f\"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.\"\n            )\n        block_size = min(data_args.block_size, tokenizer.model_max_length)\n\n    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n    def group_texts(examples):\n        # Concatenate all texts.\n        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n        # customize this part to your needs.\n        if total_length >= block_size:\n            total_length = (total_length // block_size) * block_size\n        # Split by chunks of max_len.\n        result = {\n            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n            for k, t in concatenated_examples.items()\n        }\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n    # for each of those groups 
of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n    # to preprocess.\n    #\n    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:\n    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n\n    logger.info(\"***** Build dataset *****\")\n    lm_datasets = tokenized_datasets.map(\n        group_texts,\n        batched=True,\n        num_proc=data_args.preprocessing_num_workers,\n        load_from_cache_file=not data_args.overwrite_cache,\n    )\n\n    if training_args.do_train:\n        if \"train\" not in tokenized_datasets:\n            raise ValueError(\"--do_train requires a train dataset\")\n        train_dataset = lm_datasets[\"train\"]\n        if data_args.max_train_samples is not None:\n            max_train_samples = min(len(train_dataset), data_args.max_train_samples)\n            train_dataset = train_dataset.select(range(max_train_samples))\n\n    if training_args.do_eval:\n        if \"validation\" not in tokenized_datasets:\n            raise ValueError(\"--do_eval requires a validation dataset\")\n        eval_dataset = lm_datasets[\"validation\"]\n        if data_args.max_eval_samples is not None:\n            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)\n            eval_dataset = eval_dataset.select(range(max_eval_samples))\n\n    # Enable tensorboard only on the master node\n    has_tensorboard = is_tensorboard_available()\n    if has_tensorboard:\n        try:\n            from flax.metrics.tensorboard import SummaryWriter\n\n            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n        except ImportError as ie:\n            has_tensorboard = False\n            logger.warning(\n                f\"Unable to display metrics through TensorBoard because some package are not installed: {ie}\"\n            )\n    else:\n        logger.warning(\n          
  \"Unable to display metrics through TensorBoard because the package is not installed: \"\n            \"Please run pip install tensorboard to enable.\"\n        )\n\n    # Initialize our training\n    rng = jax.random.PRNGKey(training_args.seed)\n    rng, dropout_rng = jax.random.split(rng)\n\n    # Store some constant\n    num_epochs = int(training_args.num_train_epochs)\n    train_batch_size = int(training_args.per_device_train_batch_size) * alpa.get_global_num_devices()\n    eval_batch_size = int(training_args.per_device_eval_batch_size) * alpa.get_global_num_devices()\n    steps_per_epoch = len(train_dataset) // train_batch_size\n    total_train_steps = steps_per_epoch * num_epochs\n\n    # Create learning rate schedule\n    linear_decay_lr_schedule_fn = create_learning_rate_fn(\n        len(train_dataset),\n        train_batch_size,\n        training_args.num_train_epochs,\n        training_args.warmup_steps,\n        training_args.learning_rate,\n    )\n\n    # We use Optax's \"masking\" functionality to not apply weight decay\n    # to bias and LayerNorm scale parameters. 
decay_mask_fn returns a\n    # mask boolean with the same structure as the parameters.\n    # The mask is True for parameters that should be decayed.\n    # Note that this mask is specifically adapted for FlaxGPT2.\n    # For other models, one should correct the layer norm parameter naming\n    # accordingly.\n    def decay_mask_fn(params):\n        flat_params = traverse_util.flatten_dict(params)\n        flat_mask = {\n            path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n            for path in flat_params\n        }\n        return traverse_util.unflatten_dict(flat_mask)\n\n    # create adam optimizer\n    if training_args.adafactor:\n        # We use the default parameters here to initialize adafactor,\n        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74\n        optimizer = optax.adafactor(\n            learning_rate=linear_decay_lr_schedule_fn,\n        )\n    else:\n        optimizer = optax.chain(\n            optax.clip_by_global_norm(1.0),\n            optax.adamw(\n                learning_rate=linear_decay_lr_schedule_fn,\n                b1=training_args.adam_beta1,\n                b2=training_args.adam_beta2,\n                eps=training_args.adam_epsilon,\n                weight_decay=training_args.weight_decay,\n                mask=decay_mask_fn)\n        )\n\n    # Setup train state\n    if model_args.dtype == \"float16\":\n        use_master_copy = True\n        dynamic_scale = DynamicScale()\n        # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462)\n        alpa.global_config.flax_always_use_fp16_embedding = True\n    else:\n        use_master_copy = dynamic_scale = None\n    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,\n                              
dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)\n\n    def loss_fn(logits, labels):\n        shift_logits = logits[..., :-1, :]\n        shift_labels = labels[..., 1:]\n        loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))\n        return loss.mean()\n\n    # Define gradient update step fn\n    def train_step(state, batch):\n\n        def compute_loss(params):\n            labels = batch.pop(\"labels\")\n            logits = state.apply_fn(**batch, params=params, train=True)[0]\n            loss = loss_fn(logits, labels)\n            return loss\n\n        dynamic_scale = state.dynamic_scale\n        if dynamic_scale:\n            grad_fn = dynamic_scale.value_and_grad(compute_loss)\n            dynamic_scale, is_fin, loss, grads = grad_fn(state.params)\n        else:\n            grad_fn = alpa.value_and_grad(compute_loss)\n            loss, grads = grad_fn(state.params)\n\n        new_state = state.apply_gradients(grads=grads)\n\n        if dynamic_scale:\n            new_state = new_state.replace(\n                opt_state=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.opt_state, state.opt_state),\n                params=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.params, state.params),\n                master_copy=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.master_copy, state.master_copy),\n                dynamic_scale=dynamic_scale)\n\n        metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}\n\n        return new_state, metrics\n\n    # Define eval fn\n    def eval_step(params, batch):\n        labels = batch.pop(\"labels\")\n        logits = model(**batch, params=params, train=False)[0]\n        loss = loss_fn(logits, labels)\n\n        # summarize metrics\n        metrics = 
{\"loss\": loss}\n        return metrics\n\n    # Create parallel version of the train and eval step\n    method = alpa.Zero2Parallel(num_micro_batches=training_args.num_micro_batches)\n    p_train_step = alpa.parallelize(train_step,\n                                    method=method,\n                                    donate_argnums=(0,))\n    p_eval_step = alpa.parallelize(eval_step)\n\n    min_batch_size = alpa.get_global_num_devices() * training_args.num_micro_batches\n    dump_debug_info_train_step = dump_debug_info_eval_step = True\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {num_epochs}\")\n    logger.info(f\"  Batch size per device (w. accumulation) = {training_args.per_device_train_batch_size}\")\n    logger.info(f\"  Global train batch size (w. parallel & distributed) = {train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {total_train_steps}\")\n\n    train_time = 0\n    train_metrics = []\n    epochs = tqdm(range(num_epochs), desc=\"Epoch ... \", position=0)\n\n    step_ct = 0\n    last_time = time.time()\n\n    epochs.write(\"Initial compilation. 
This might take some minutes...\")\n\n    for epoch in epochs:\n        # ======================== Training ================================\n        train_start = time.time()\n\n        # Create sampling rng\n        rng, input_rng = jax.random.split(rng)\n\n        # Generate an epoch by shuffling sampling indices from the train dataset\n        train_loader = data_loader(input_rng, train_dataset, train_batch_size, min_batch_size, shuffle=True)\n        steps_per_epoch = len(train_dataset) // train_batch_size\n        # train\n        for step in tqdm(range(steps_per_epoch), desc=\"Training...\", position=1, leave=False):\n            batch = next(train_loader)\n            state, train_metric = p_train_step(state, batch)\n            train_metrics.append(train_metric)\n\n            cur_step = epoch * (len(train_dataset) // train_batch_size) + step\n\n            if dump_debug_info_train_step:\n                dump_debug_info_train_step = False\n                executable = p_train_step.get_last_executable()\n                executable.sync()\n                executable.dump_debug_info(\"alpa_debug_info\")\n                epochs.write(f\"Initial compilation completed. 
\"\n                             f\"Time elapsed: {time.time() - train_start:.2f} s\")\n\n            step_ct += 1\n            if cur_step % training_args.logging_steps == 0 and cur_step > 0:\n                executable.sync()\n                latency = (time.time() - last_time) / step_ct\n                throughput_tokens = np.prod(batch[\"input_ids\"].shape) / latency\n                throughput_tflops = alpa.util.compute_gpt_tflops(\n                    batch_size=batch[\"input_ids\"].shape[0],\n                    seq_len=batch[\"input_ids\"].shape[1],\n                    num_layers=config.num_hidden_layers,\n                    hidden_size=config.hidden_size,\n                    vocab_size=config.vocab_size,\n                    num_gpus=alpa.get_global_num_devices(),\n                    latency=latency)\n                step_ct = 0\n\n                #print(f\"driver latency: {latency:.2f}, \"\n                #      f\"worker latency: {executable.get_execution_time_costs()[-1]:.2f}\")\n\n                # Save metrics\n                train_time += time.time() - train_start\n                if has_tensorboard:\n                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n\n                train_metric = jax.tree_map(np.mean, train_metric)\n\n                epochs.write(\n                    f\"Step... 
{cur_step} | \"\n                    f\"Loss: {train_metric['loss'].mean():.4f}, \"\n                    f\"Learning Rate: {train_metric['learning_rate'].mean():.5f}, \"\n                    f\"Throughput: {throughput_tokens:.2f} token/s, \"\n                    f\"{throughput_tflops:.2f} TFLOP/s\"\n                )\n\n                train_metrics = []\n                last_time = time.time()\n\n            if cur_step % training_args.eval_steps == 0 and cur_step > 0:\n                # ======================== Evaluating ==============================\n                eval_metrics = []\n                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, min_batch_size)\n                eval_steps = max(len(eval_dataset) // eval_batch_size, 1)\n                for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n                    # Model forward\n                    batch = next(eval_loader)\n                    metrics = p_eval_step(state.params, batch)\n                    eval_metrics.append(metrics)\n\n                    if dump_debug_info_eval_step:\n                        dump_debug_info_eval_step = False\n                        executable = p_eval_step.get_last_executable()\n                        executable.dump_debug_info(\"alpa_debug_info\")\n\n                # normalize eval metrics\n                eval_metrics = alpa.util.get_metrics(eval_metrics)\n                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n\n                try:\n                    eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n                except OverflowError:\n                    eval_metrics[\"perplexity\"] = float(\"inf\")\n\n                # Print metrics and update progress bar\n                desc = (\n                    f\"Step... 
({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:\"\n                    f\" {eval_metrics['perplexity']})\"\n                )\n                epochs.write(desc)\n                epochs.desc = desc\n\n                # Save metrics\n                if has_tensorboard:\n                    write_eval_metric(summary_writer, eval_metrics, cur_step)\n\n            if cur_step % training_args.save_steps == 0 and cur_step > 0:\n                # save checkpoint after each epoch and push checkpoint to the hub\n                alpa.prefetch(state.params)\n                params = alpa.util.map_to_nparray(state.params)\n                model.save_pretrained(training_args.output_dir, params=params)\n                tokenizer.save_pretrained(training_args.output_dir)\n                if training_args.push_to_hub:\n                    repo.push_to_hub(commit_message=f\"Saving weights and logs of step {cur_step}\", blocking=False)\n\n    # Eval after training\n    if training_args.do_eval:\n        eval_metrics = []\n        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, min_batch_size)\n        eval_steps = max(len(eval_dataset) // eval_batch_size, 1)\n        for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n            # Model forward\n            batch = next(eval_loader)\n            metrics = p_eval_step(state.params, batch)\n            eval_metrics.append(metrics)\n\n        # normalize eval metrics\n        eval_metrics = alpa.util.get_metrics(eval_metrics)\n        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)\n\n        try:\n            eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n        except OverflowError:\n            eval_metrics[\"perplexity\"] = float(\"inf\")\n\n        eval_metrics = {f\"eval_{metric_name}\": value for metric_name, value in eval_metrics.items()}\n        path = os.path.join(training_args.output_dir, 
\"eval_results.json\")\n        with open(path, \"w\") as f:\n            json.dump(eval_metrics, f, indent=4, sort_keys=True)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/gpt2/train_tokenizer.py",
    "content": "from datasets import load_dataset\nfrom tokenizers import trainers, Tokenizer, normalizers, ByteLevelBPETokenizer\n\n# load dataset\ndataset = load_dataset(\"oscar\", \"unshuffled_deduplicated_no\", split=\"train\")\n\n# Instantiate tokenizer\ntokenizer = ByteLevelBPETokenizer()\n\ndef batch_iterator(batch_size=1000):\n    for i in range(0, len(dataset), batch_size):\n        yield dataset[i: i + batch_size][\"text\"]\n\n# Customized training\ntokenizer.train_from_iterator(batch_iterator(), vocab_size=50256, min_frequency=2, special_tokens=[\n    \"<s>\",\n    \"<pad>\",\n    \"</s>\",\n    \"<unk>\",\n    \"<mask>\",\n])\n\n# Save files to disk\ntokenizer.save(\"./norwegian-gpt2/tokenizer.json\")\n"
  },
  {
    "path": "examples/imagenet/README.md",
    "content": "--------------------------------------------------------------------------------\n\nAdopted from https://github.com/google/flax/tree/main/examples/imagenet.\n\nUse `alpa.parallelize` to parallelize the training loop.\n\nQuick run:\n```\nray start --head\npython3 main.py --workdir=./imagenet --config=configs/v100_x8.py --config.batch_size 1024\n```\n\n--------------------------------------------------------------------------------\n\n## ImageNet classification\n\nTrains a ResNet50 model ([He *et al.*, 2016]) for the ImageNet classification task\n([Russakovsky *et al.*, 2015]).\n\nThis example uses linear learning rate warmup and cosine learning rate schedule.\n\n[He *et al.*, 2016]: https://arxiv.org/abs/1512.03385\n[Russakovsky *et al.*, 2015]: https://arxiv.org/abs/1409.0575\n\nYou can run this code and even modify it directly in Google Colab, no\ninstallation required:\n\nhttps://colab.research.google.com/github/google/flax/blob/main/examples/imagenet/imagenet.ipynb\n\nThe Colab also demonstrates how to load pretrained checkpoints from Cloud\nstorage at\n[gs://flax_public/examples/imagenet/](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet)\n\nTable of contents:\n\n- [Requirements](#requirements)\n- [Example runs](#example-runs)\n- [Running locally](#running-locally)\n  - [Overriding parameters on the command line](#overriding-parameters-on-the-command-line)\n- [Running on Cloud](#running-on-cloud)\n  - [Preparing the dataset](#preparing-the-dataset)\n  - [Google Cloud TPU](#google-cloud-tpu)\n  - [Google Cloud GPU](#google-cloud-gpu)\n\n### Requirements\n\n* TensorFlow dataset `imagenet2012:5.*.*`\n* `≈180GB` of RAM if you want to cache the dataset in memory for faster IO\n\n### Example runs\n\nWhile the example should run on a variety of hardware,\nwe have tested the following GPU and TPU configurations:\n\n|          Name           | Steps  | Walltime | Top-1 accuracy |                                               
                        Metrics                                                                        |                                                                               Workdir                                                                                |\n| :---------------------- | -----: | :------- | :------------- | :--------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| TPU v3-32                | 125100 | 2.1h     | 76.54%         | [tfhub.dev](https://tensorboard.dev/experiment/GhPHRoLzTqu7c8vynTk6bg/)                                                                              | [gs://flax_public/examples/imagenet/tpu_v3_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v3_32)                                         |\n| TPU v2-32                | 125100 | 2.5h     | 76.67%         | [tfhub.dev](https://tensorboard.dev/experiment/qBJ7T9VPSgO5yeb0HAKbIA/)                                                                              | [gs://flax_public/examples/imagenet/tpu_v2_32](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu_v2_32)                                         |\n| TPU v3-8                | 125100 | 4.4h     | 76.37%         | [tfhub.dev](https://tensorboard.dev/experiment/JwxRMYrsR4O6V6fnkn3dmg/)                                                                              | [gs://flax_public/examples/imagenet/tpu](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/tpu)                                         |\n| v100_x8                 | 250200 | 13.2h    | 76.72%         | 
[tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0&regexInput=%5Eimagenet/v100_x8%24)                 | [gs://flax_public/examples/imagenet/v100_x8](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8)                                 |\n| v100_x8_mixed_precision |  62500 | 4.3h     | 76.27%         | [tfhub.dev](https://tensorboard.dev/experiment/venzpsNXR421XLkvvzSkqQ/#scalars&_smoothingWeight=0&regexInput=%5Eimagenet/v100_x8_mixed_precision%24) | [gs://flax_public/examples/imagenet/v100_x8_mixed_precision](https://console.cloud.google.com/storage/browser/flax_public/examples/imagenet/v100_x8_mixed_precision) |\n\n\n### Running locally\n\n```shell\npython main.py --workdir=./imagenet --config=configs/default.py\n```\n\n#### Overriding parameters on the command line\n\nSpecify a hyperparameter configuration by the means of setting `--config` flag.\nConfiguration flag is defined using\n[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).\n`config_flags` allows overriding configuration fields. 
This can be done as\nfollows:\n\n```shell\npython main.py --workdir=./imagenet_default --config=configs/default.py \\\n--config.num_epochs=100\n```\n\n### Running on Cloud\n\n#### Preparing the dataset\n\nFor running the ResNet50 model on imagenet dataset,\nyou first need to prepare the `imagenet2012` dataset.\nDownload the data from http://image-net.org/ as described in the\n[tensorflow_datasets catalog](https://www.tensorflow.org/datasets/catalog/imagenet2012).\nThen point the environment variable `$IMAGENET_DOWNLOAD_PATH`\nto the directory where the downloads are stored and prepare the dataset\nby running\n\n```shell\npython -c \"\nimport tensorflow_datasets as tfds\ntfds.builder('imagenet2012').download_and_prepare(\n    download_config=tfds.download.DownloadConfig(\n        manual_dir='$IMAGENET_DOWNLOAD_PATH'))\n\"\n```\n\nThe contents of the directory `~/tensorflow_datasets` should be copied to your\ngcs bucket. Point the environment variable `GCS_TFDS_BUCKET` to your bucket and\nrun the following command:\n\n```shell\ngsutil cp -r ~/tensorflow_datasets gs://$GCS_TFDS_BUCKET/datasets\n```\n\n#### Google Cloud TPU\n\nSetup the TPU VM and install the Flax dependencies on it as described\n[here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or\n[here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single\nv3-8 TPU.\n\nIf running on the single v3-8 TPU (i.e. 8 accelerators connected to a single\nhost), simply connect to the machine with\n`gcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE` and then start the\ntraining with below command:\n\n```shell\nexport TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets\npython3 main.py --workdir=./imagenet_tpu --config=configs/tpu.py\n```\n\nWhen running on pod slices, after creating the TPU VM, there are different ways\nof running the training in SPMD fashion on the hosts connected to the TPUs that\nmake up the slice. 
We simply send the same installation/execution shell commands\nto all hosts in parallel with the command below. If anything fails it's\nusually a good idea to connect to a single host and execute the commands\ninteractively.\n\nFor convenience, the TPU creation commands are inlined below.\n\n```shell\nVM_NAME=imagenet\nREPO=https://github.com/google/flax\nBRANCH=main\nWORKDIR=gs://$YOUR_BUCKET/flax/examples/imagenet/$(date +%Y%m%d_%H%M)\n\ngcloud alpha compute tpus tpu-vm create $VM_NAME \\\n    --zone=$ZONE \\\n    --version v2-alpha --accelerator-type v3-32\nFLAGS=\"--config.batch_size=$((32*256))\"\n\ngcloud alpha compute tpus tpu-vm ssh $VM_NAME --zone $ZONE \\\n--worker=all --command \"\npip install 'jax[tpu]>=0.2.21' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html &&\npip install --user git+$REPO.git &&\ngit clone --depth=1 -b $BRANCH $REPO &&\ncd flax/examples/imagenet &&\npip install -r requirements.txt &&\nexport TFDS_DATA_DIR=gs://$GCS_TFDS_BUCKET/datasets &&\npython3 main.py --workdir=$WORKDIR --config=configs/tpu.py $FLAGS\n\"\n```\n\n#### Google Cloud GPU\n\nCan be launched with utility script described in\n[../cloud/README.md](../cloud/README.md)\n\nThere are two configurations available:\n\n- `configs/v100_x8.py` : Full precision GPU training\n- `configs/v100_x8_mixed_precision.py` : Mixed precision GPU training. Note that\n  mixed precision handling is implemented manually with\n  [`optim.dynamic_scale`](https://github.com/google/flax/blob/main/flax/optim/dynamic_scale.py)\n"
  },
  {
    "path": "examples/imagenet/configs/default.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright 2021 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Default Hyperparameter configuration.\"\"\"\n\nimport ml_collections\n\n\ndef get_config():\n  \"\"\"Get the default hyperparameter configuration.\"\"\"\n  config = ml_collections.ConfigDict()\n\n  # As defined in the `models` module.\n  config.model = 'ResNet50'\n  # `name` argument of tensorflow_datasets.builder()\n  config.dataset = 'imagenet2012:5.*.*'\n\n  config.learning_rate = 0.1\n  config.warmup_epochs = 5.0\n  config.momentum = 0.9\n  config.batch_size = 128\n\n  config.num_epochs = 100.0\n  config.log_every_steps = 50\n\n  config.cache = True\n  config.half_precision = False\n\n  # If num_train_steps==-1 then the number of training steps is calculated from\n  # num_epochs using the entire dataset. 
Similarly for steps_per_eval.\n  config.num_train_steps = -1\n  config.steps_per_eval = -1\n  return config\n"
  },
  {
    "path": "examples/imagenet/configs/fake_data_benchmark.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Hyperparameter configuration for Fake data benchmark.\"\"\"\n\nimport jax\n\nfrom configs import default as default_lib\n\n\ndef get_config():\n  \"\"\"Get the hyperparameter configuration for Fake data benchmark.\"\"\"\n  # Override default configuration to avoid duplication of field definition.\n  config = default_lib.get_config()\n  config.batch_size = 256 * jax.device_count()\n  config.half_precision = True\n  config.num_epochs = 5\n\n  # Previously the input pipeline computed:\n  # `steps_per_epoch` as input_pipeline.TRAIN_IMAGES // batch_size\n  config.num_train_steps = 1024 // config.batch_size\n  # and `steps_per_eval` as input_pipeline.EVAL_IMAGES // batch_size\n  config.steps_per_eval = 512 // config.batch_size\n\n  return config\n"
  },
  {
    "path": "examples/imagenet/configs/tpu.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n# Copyright 2021 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"Hyperparameter configuration to run the example on TPUs.\"\"\"\n\nimport ml_collections\n\n\ndef get_config():\n  \"\"\"Get the hyperparameter configuration to train on TPUs.\"\"\"\n  config = ml_collections.ConfigDict()\n\n  # As defined in the `models` module.\n  config.model = 'ResNet50'\n  # `name` argument of tensorflow_datasets.builder()\n  config.dataset = 'imagenet2012:5.*.*'\n\n  config.learning_rate = 0.1\n  config.warmup_epochs = 5.0\n  config.momentum = 0.9\n\n  config.num_epochs = 100.0\n  config.log_every_steps = 100\n\n  # If num_train_steps==-1 then the number of training steps is calculated from\n  # num_epochs using the entire dataset. 
Similarly for steps_per_eval.\n  config.num_train_steps = -1\n  config.steps_per_eval = -1\n\n  # Consider setting the batch size to max(tpu_chips * 256, 8 * 1024) if you\n  # train on a larger pod slice.\n  config.batch_size = 1024\n  config.cache = True\n  config.half_precision = True\n\n  return config\n"
  },
  {
    "path": "examples/imagenet/configs/v100_x8.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.\"\"\"\n\nfrom configs import default as default_lib\n\n\ndef get_config():\n  \"\"\"Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.\"\"\"\n  # Override default configuration to avoid duplication of field definition.\n  config = default_lib.get_config()\n\n  config.batch_size = 512\n  config.cache = True\n\n  return config\n"
  },
  {
    "path": "examples/imagenet/configs/v100_x8_mixed_precision.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Hyperparameter configuration to run the example on 8 x Nvidia V100 GPUs.\"\"\"\n\nfrom configs import default as default_lib\n\n\ndef get_config():\n  \"\"\"Get the hyperparameter configuration to train on 8 x Nvidia V100 GPUs.\"\"\"\n  # Override default configuration to avoid duplication of field definition.\n  config = default_lib.get_config()\n\n  config.batch_size = 2048\n  config.cache = True\n  config.half_precision = True\n\n  return config\n"
  },
  {
    "path": "examples/imagenet/input_pipeline.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"ImageNet input pipeline.\n\"\"\"\n\nimport jax\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\n\n\nIMAGE_SIZE = 224\nCROP_PADDING = 32\nMEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]\nSTDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]\n\n\ndef distorted_bounding_box_crop(image_bytes,\n                                bbox,\n                                min_object_covered=0.1,\n                                aspect_ratio_range=(0.75, 1.33),\n                                area_range=(0.05, 1.0),\n                                max_attempts=100):\n  \"\"\"Generates cropped_image using one of the bboxes randomly distorted.\n\n  See `tf.image.sample_distorted_bounding_box` for more documentation.\n\n  Args:\n    image_bytes: `Tensor` of binary image data.\n    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`\n        where each coordinate is [0, 1) and the coordinates are arranged\n        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole\n        image.\n    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped\n        area of the image must contain at least this fraction of any bounding\n        box supplied.\n    aspect_ratio_range: An optional list of `float`s. 
The cropped area of the\n        image must have an aspect ratio = width / height within this range.\n    area_range: An optional list of `float`s. The cropped area of the image\n        must contain a fraction of the supplied image within this range.\n    max_attempts: An optional `int`. Number of attempts at generating a cropped\n        region of the image of the specified constraints. After `max_attempts`\n        failures, return the entire image.\n  Returns:\n    cropped image `Tensor`\n  \"\"\"\n  shape = tf.io.extract_jpeg_shape(image_bytes)\n  sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(\n      shape,\n      bounding_boxes=bbox,\n      min_object_covered=min_object_covered,\n      aspect_ratio_range=aspect_ratio_range,\n      area_range=area_range,\n      max_attempts=max_attempts,\n      use_image_if_no_bounding_boxes=True)\n  bbox_begin, bbox_size, _ = sample_distorted_bounding_box\n\n  # Crop the image to the specified bounding box.\n  offset_y, offset_x, _ = tf.unstack(bbox_begin)\n  target_height, target_width, _ = tf.unstack(bbox_size)\n  crop_window = tf.stack([offset_y, offset_x, target_height, target_width])\n  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)\n\n  return image\n\n\ndef _resize(image, image_size):\n  return tf.image.resize([image], [image_size, image_size],\n                         method=tf.image.ResizeMethod.BICUBIC)[0]\n\n\ndef _at_least_x_are_equal(a, b, x):\n  \"\"\"At least `x` of `a` and `b` `Tensors` are equal.\"\"\"\n  match = tf.equal(a, b)\n  match = tf.cast(match, tf.int32)\n  return tf.greater_equal(tf.reduce_sum(match), x)\n\n\ndef _decode_and_random_crop(image_bytes, image_size):\n  \"\"\"Make a random crop of image_size.\"\"\"\n  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])\n  image = distorted_bounding_box_crop(\n      image_bytes,\n      bbox,\n      min_object_covered=0.1,\n      aspect_ratio_range=(3. / 4, 4. 
/ 3.),\n      area_range=(0.08, 1.0),\n      max_attempts=10)\n  original_shape = tf.io.extract_jpeg_shape(image_bytes)\n  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)\n\n  image = tf.cond(\n      bad,\n      lambda: _decode_and_center_crop(image_bytes, image_size),\n      lambda: _resize(image, image_size))\n\n  return image\n\n\ndef _decode_and_center_crop(image_bytes, image_size):\n  \"\"\"Crops to center of image with padding then scales image_size.\"\"\"\n  shape = tf.io.extract_jpeg_shape(image_bytes)\n  image_height = shape[0]\n  image_width = shape[1]\n\n  padded_center_crop_size = tf.cast(\n      ((image_size / (image_size + CROP_PADDING)) *\n       tf.cast(tf.minimum(image_height, image_width), tf.float32)),\n      tf.int32)\n\n  offset_height = ((image_height - padded_center_crop_size) + 1) // 2\n  offset_width = ((image_width - padded_center_crop_size) + 1) // 2\n  crop_window = tf.stack([offset_height, offset_width,\n                          padded_center_crop_size, padded_center_crop_size])\n  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)\n  image = _resize(image, image_size)\n\n  return image\n\n\ndef normalize_image(image):\n  image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)\n  image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)\n  return image\n\n\ndef preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):\n  \"\"\"Preprocesses the given image for training.\n\n  Args:\n    image_bytes: `Tensor` representing an image binary of arbitrary size.\n    dtype: data type of the image.\n    image_size: image size.\n\n  Returns:\n    A preprocessed image `Tensor`.\n  \"\"\"\n  image = _decode_and_random_crop(image_bytes, image_size)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  image = tf.image.random_flip_left_right(image)\n  image = normalize_image(image)\n  image = tf.image.convert_image_dtype(image, dtype=dtype)\n  return image\n\n\ndef 
preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):\n  \"\"\"Preprocesses the given image for evaluation.\n\n  Args:\n    image_bytes: `Tensor` representing an image binary of arbitrary size.\n    dtype: data type of the image.\n    image_size: image size.\n\n  Returns:\n    A preprocessed image `Tensor`.\n  \"\"\"\n  image = _decode_and_center_crop(image_bytes, image_size)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  image = normalize_image(image)\n  image = tf.image.convert_image_dtype(image, dtype=dtype)\n  return image\n\n\ndef create_split(dataset_builder, batch_size, train,\n                 split_start, split_end,\n                 dtype=tf.float32,\n                 image_size=IMAGE_SIZE, cache=False):\n  \"\"\"Creates a split from the ImageNet dataset using TensorFlow Datasets.\n\n  Args:\n    dataset_builder: TFDS dataset builder for ImageNet.\n    batch_size: the batch size returned by the data pipeline.\n    train: Whether to load the train or evaluation split.\n    dtype: data type of the image.\n    image_size: The target size of the images.\n    cache: Whether to cache the dataset.\n  Returns:\n    A `tf.data.Dataset`.\n  \"\"\"\n  # Hide any GPUs from TensorFlow. 
Otherwise TF might reserve memory and make\n  # it unavailable to JAX.\n  tf.config.experimental.set_visible_devices([], 'GPU')\n  if train:\n    train_examples = dataset_builder.info.splits['train'].num_examples\n    split = f'train[{split_start}:{split_end}]'\n  else:\n    validate_examples = dataset_builder.info.splits['validation'].num_examples\n    split = f'validation[{split_start}:{split_end}]'\n\n  def decode_example(example):\n    if train:\n      image = preprocess_for_train(example['image'], dtype, image_size)\n    else:\n      image = preprocess_for_eval(example['image'], dtype, image_size)\n    return {'image': image, 'label': example['label']}\n\n  ds = dataset_builder.as_dataset(split=split, decoders={\n      'image': tfds.decode.SkipDecoding(),\n  })\n  options = tf.data.Options()\n  options.experimental_threading.private_threadpool_size = 48\n  ds = ds.with_options(options)\n\n  if cache:\n    ds = ds.cache()\n\n  if train:\n    ds = ds.repeat()\n    ds = ds.shuffle(16 * batch_size, seed=0)\n\n  ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n  ds = ds.batch(batch_size, drop_remainder=True)\n\n  if not train:\n    ds = ds.repeat()\n\n  ds = ds.prefetch(10)\n\n  return ds\n"
  },
  {
    "path": "examples/imagenet/main.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Main file for running the ImageNet example.\n\nThis file is intentionally kept short. The majority for logic is in libraries\nthat can be easily tested and imported in Colab.\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nfrom clu import platform\nimport jax\nfrom ml_collections import config_flags\nimport tensorflow as tf\n\nimport train\n\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('workdir', None, 'Directory to store model data.')\nconfig_flags.DEFINE_config_file(\n    'config',\n    None,\n    'File path to the training hyperparameter configuration.',\n    lock_config=True)\n\n\ndef main(argv):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n\n  # Hide any GPUs from TensorFlow. 
Otherwise TF might reserve memory and make\n  # it unavailable to JAX.\n  tf.config.experimental.set_visible_devices([], 'GPU')\n\n  #logging.info('JAX process: %d / %d', jax.process_index(), jax.process_count())\n  #logging.info('JAX local devices: %r', jax.local_devices())\n\n  # Add a note so that we can tell which task is which JAX host.\n  # (Depending on the platform task 0 is not guaranteed to be host 0)\n  #platform.work_unit().set_task_status(f'process_index: {jax.process_index()}, '\n  #                                     f'process_count: {jax.process_count()}')\n  platform.work_unit().create_artifact(platform.ArtifactType.DIRECTORY,\n                                       FLAGS.workdir, 'workdir')\n\n  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)\n\n\nif __name__ == '__main__':\n  flags.mark_flags_as_required(['config', 'workdir'])\n  app.run(main)\n"
  },
  {
    "path": "examples/imagenet/models.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Flax implementation of ResNet V1.\"\"\"\n\n# See issue #620.\n# pytype: disable=wrong-arg-count\n\nfrom functools import partial\nfrom typing import Any, Callable, Sequence, Tuple\n\nfrom flax import linen as nn\nimport jax.numpy as jnp\n\nModuleDef = Any\n\n\nclass ResNetBlock(nn.Module):\n  \"\"\"ResNet block.\"\"\"\n  filters: int\n  conv: ModuleDef\n  norm: ModuleDef\n  act: Callable\n  strides: Tuple[int, int] = (1, 1)\n\n  @nn.compact\n  def __call__(self, x,):\n    residual = x\n    y = self.conv(self.filters, (3, 3), self.strides)(x)\n    y = self.norm()(y)\n    y = self.act(y)\n    y = self.conv(self.filters, (3, 3))(y)\n    y = self.norm(scale_init=nn.initializers.zeros)(y)\n\n    if residual.shape != y.shape:\n      residual = self.conv(self.filters, (1, 1),\n                           self.strides, name='conv_proj')(residual)\n      residual = self.norm(name='norm_proj')(residual)\n\n    return self.act(residual + y)\n\n\nclass BottleneckResNetBlock(nn.Module):\n  \"\"\"Bottleneck ResNet block.\"\"\"\n  filters: int\n  conv: ModuleDef\n  norm: ModuleDef\n  act: Callable\n  strides: Tuple[int, int] = (1, 1)\n\n  @nn.compact\n  def __call__(self, x):\n    residual = x\n    y = self.conv(self.filters, (1, 1))(x)\n    y = self.norm()(y)\n    y = self.act(y)\n    y = self.conv(self.filters, (3, 3), self.strides)(y)\n    y = self.norm()(y)\n    y = 
self.act(y)\n    y = self.conv(self.filters * 4, (1, 1))(y)\n    y = self.norm(scale_init=nn.initializers.zeros)(y)\n\n    if residual.shape != y.shape:\n      residual = self.conv(self.filters * 4, (1, 1),\n                           self.strides, name='conv_proj')(residual)\n      residual = self.norm(name='norm_proj')(residual)\n\n    return self.act(residual + y)\n\n\nclass ResNet(nn.Module):\n  \"\"\"ResNetV1.\"\"\"\n  stage_sizes: Sequence[int]\n  block_cls: ModuleDef\n  num_classes: int\n  num_filters: int = 64\n  dtype: Any = jnp.float32\n  act: Callable = nn.relu\n  conv: ModuleDef = nn.Conv\n\n  @nn.compact\n  def __call__(self, x, train: bool = True):\n    conv = partial(self.conv, use_bias=False, dtype=self.dtype)\n    norm = partial(nn.BatchNorm,\n                   use_running_average=not train,\n                   momentum=0.9,\n                   epsilon=1e-5,\n                   dtype=self.dtype)\n\n    x = conv(self.num_filters, (7, 7), (2, 2),\n             padding=[(3, 3), (3, 3)],\n             name='conv_init')(x)\n    x = norm(name='bn_init')(x)\n    x = nn.relu(x)\n    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')\n    for i, block_size in enumerate(self.stage_sizes):\n      for j in range(block_size):\n        strides = (2, 2) if i > 0 and j == 0 else (1, 1)\n        x = self.block_cls(self.num_filters * 2 ** i,\n                           strides=strides,\n                           conv=conv,\n                           norm=norm,\n                           act=self.act)(x)\n    x = jnp.mean(x, axis=(1, 2))\n    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)\n    x = jnp.asarray(x, self.dtype)\n    return x\n\n\nResNet18 = partial(ResNet, stage_sizes=[2, 2, 2, 2],\n                   block_cls=ResNetBlock)\nResNet34 = partial(ResNet, stage_sizes=[3, 4, 6, 3],\n                   block_cls=ResNetBlock)\nResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3],\n                   block_cls=BottleneckResNetBlock)\nResNet101 = 
partial(ResNet, stage_sizes=[3, 4, 23, 3],\n                    block_cls=BottleneckResNetBlock)\nResNet152 = partial(ResNet, stage_sizes=[3, 8, 36, 3],\n                    block_cls=BottleneckResNetBlock)\nResNet200 = partial(ResNet, stage_sizes=[3, 24, 36, 3],\n                    block_cls=BottleneckResNetBlock)\n\n\nResNet18Local = partial(ResNet, stage_sizes=[2, 2, 2, 2],\n                        block_cls=ResNetBlock, conv=nn.ConvLocal)\n\n\n# Used for testing only.\n_ResNet1 = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock)\n_ResNet1Local = partial(ResNet, stage_sizes=[1], block_cls=ResNetBlock,\n                        conv=nn.ConvLocal)\n"
  },
  {
    "path": "examples/imagenet/train.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"ImageNet example.\n\nThis script trains a ResNet-50 on the ImageNet dataset.\nThe data is loaded using tensorflow_datasets.\n\"\"\"\n\nimport functools\nimport os\nimport time\nfrom typing import Any\n\nimport alpa\nfrom absl import logging\nfrom clu import metric_writers\nfrom clu import periodic_actions\nimport flax\nfrom flax import jax_utils\nfrom flax.training import train_state, dynamic_scale as dynamic_scale_lib\nfrom flax.training import checkpoints, common_utils\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\nfrom jax import random\nimport ml_collections\nimport numpy as np\nimport optax\nimport ray\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\n\nimport input_pipeline\nimport models\n\n\nNUM_CLASSES = 1000\n\n\ndef create_model(*, model_cls, half_precision, **kwargs):\n  platform = jax.local_devices()[0].platform\n  if half_precision:\n    if platform == 'tpu':\n      model_dtype = jnp.bfloat16\n    else:\n      model_dtype = jnp.float16\n  else:\n    model_dtype = jnp.float32\n  return model_cls(num_classes=NUM_CLASSES, dtype=model_dtype, **kwargs)\n\n\ndef initialized(key, image_size, model):\n  input_shape = (1, image_size, image_size, 3)\n  @jax.jit\n  def init(*args):\n    return model.init(*args)\n  variables = init({'params': key}, jnp.ones(input_shape, model.dtype))\n  return variables['params'], 
variables['batch_stats']\n\n\ndef cross_entropy_loss(logits, labels):\n  one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)\n  xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)\n  return jnp.mean(xentropy)\n\n\ndef compute_metrics(logits, labels):\n  loss = cross_entropy_loss(logits, labels)\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  metrics = {\n      'loss': loss,\n      'accuracy': accuracy,\n  }\n  return metrics\n\n\ndef create_learning_rate_fn(\n    config: ml_collections.ConfigDict,\n    base_learning_rate: float,\n    steps_per_epoch: int):\n  \"\"\"Create learning rate schedule.\"\"\"\n  warmup_fn = optax.linear_schedule(\n      init_value=0., end_value=base_learning_rate,\n      transition_steps=config.warmup_epochs * steps_per_epoch)\n  cosine_epochs = max(config.num_epochs - config.warmup_epochs, 1)\n  cosine_fn = optax.cosine_decay_schedule(\n      init_value=base_learning_rate,\n      decay_steps=cosine_epochs * steps_per_epoch)\n  schedule_fn = optax.join_schedules(\n      schedules=[warmup_fn, cosine_fn],\n      boundaries=[config.warmup_epochs * steps_per_epoch])\n  return schedule_fn\n\n\ndef train_step(state, batch, learning_rate_fn):\n  \"\"\"Perform a single training step.\"\"\"\n  def loss_fn(params):\n    \"\"\"loss function used for training.\"\"\"\n    logits, new_model_state = state.apply_fn(\n        {'params': params, 'batch_stats': state.batch_stats},\n        batch['image'],\n        mutable=['batch_stats'])\n    loss = cross_entropy_loss(logits, batch['label'])\n    weight_penalty_params = jax.tree_leaves(params)\n    weight_decay = 0.0001\n    weight_l2 = sum(jnp.sum(x ** 2)\n                     for x in weight_penalty_params\n                     if x.ndim > 1)\n    weight_penalty = weight_decay * 0.5 * weight_l2\n    loss = loss + weight_penalty\n    return loss, (new_model_state, logits)\n\n  step = state.step\n  dynamic_scale = state.dynamic_scale\n  lr = 
learning_rate_fn(step)\n\n  if dynamic_scale:\n    grad_fn = dynamic_scale.value_and_grad(\n        loss_fn, has_aux=True)\n    dynamic_scale, is_fin, aux, grads = grad_fn(state.params)\n    # dynamic loss takes care of averaging gradients across replicas\n  else:\n    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n    aux, grads = grad_fn(state.params)\n  new_model_state, logits = aux[1]\n  metrics = compute_metrics(logits, batch['label'])\n  metrics['learning_rate'] = lr\n\n  new_state = state.apply_gradients(\n      grads=grads, batch_stats=new_model_state['batch_stats'])\n  if dynamic_scale:\n    # if is_fin == False the gradients contain Inf/NaNs and optimizer state and\n    # params should be restored (= skip this step).\n    new_state = new_state.replace(\n        opt_state=jax.tree_map(\n            functools.partial(jnp.where, is_fin),\n            new_state.opt_state,\n            state.opt_state),\n        params=jax.tree_map(\n            functools.partial(jnp.where, is_fin),\n            new_state.params,\n            state.params),\n        dynamic_scale=dynamic_scale)\n    metrics['scale'] = dynamic_scale.scale\n\n  return new_state, metrics\n\n\ndef eval_step(state, batch):\n  variables = {'params': state.params, 'batch_stats': state.batch_stats}\n  logits = state.apply_fn(\n      variables, batch['image'], train=False, mutable=False)\n  return compute_metrics(logits, batch['label'])\n\n\ndef create_input_iter(dataset_builder, batch_size, image_size, dtype,\n                      placement_specs, train, cache):\n\n  def input_iter_func(start, end, batch_size):\n      ds = input_pipeline.create_split(\n          dataset_builder, batch_size, train,\n          start, end,\n          image_size=image_size, dtype=dtype, cache=cache)\n      return map(lambda xs: (xs[\"image\"]._numpy(), xs[\"label\"]._numpy()), ds)\n\n  split_name = \"train\" if train else \"validation\"\n\n  it = alpa.MeshDriverDataLoader(\n      batch_size, 
dataset_builder.info.splits[split_name].num_examples,\n      input_iter_func, placement_specs, prefetch_size=4, repeat=True)\n  it = map(lambda x: {\"image\": x[0], \"label\": x[1]}, it)\n  return it\n\n\nclass TrainState(train_state.TrainState):\n  batch_stats: Any\n  dynamic_scale: dynamic_scale_lib.DynamicScale\n\n\ndef restore_checkpoint(state, workdir):\n  return checkpoints.restore_checkpoint(workdir, state)\n\n\ndef save_checkpoint(state, workdir):\n  alpa.prefetch(state)\n  state = alpa.util.map_to_nparray(state)\n  step = int(state.step)\n  checkpoints.save_checkpoint(workdir, state, step, keep=3)\n\n\n# pmean only works inside pmap because it needs an axis name.\n# This function will average the inputs across all devices.\ncross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')\n\n\ndef sync_batch_stats(state):\n  \"\"\"Sync the batch statistics across replicas.\"\"\"\n  # Each device has its own version of the running average batch statistics and\n  # we sync them before evaluation.\n  return state.replace(batch_stats=cross_replica_mean(state.batch_stats))\n\n\ndef create_train_state(rng, config: ml_collections.ConfigDict,\n                       model, image_size, learning_rate_fn):\n  \"\"\"Create initial training state.\"\"\"\n  dynamic_scale = None\n  platform = jax.local_devices()[0].platform\n  if config.half_precision and platform == 'gpu':\n    dynamic_scale = dynamic_scale_lib.DynamicScale()\n  else:\n    dynamic_scale = None\n\n  params, batch_stats = initialized(rng, image_size, model)\n  tx = optax.sgd(\n      learning_rate=learning_rate_fn,\n      momentum=config.momentum,\n      nesterov=True,\n  )\n  state = TrainState.create(\n      apply_fn=model.apply,\n      params=params,\n      tx=tx,\n      batch_stats=batch_stats,\n      dynamic_scale=dynamic_scale)\n  return state\n\n\ndef train_and_evaluate(config: ml_collections.ConfigDict,\n                       workdir: str) -> TrainState:\n  \"\"\"Execute model training and 
evaluation loop.\n\n  Args:\n    config: Hyperparameter configuration for training and evaluation.\n    workdir: Directory where the tensorboard summaries are written to.\n\n  Returns:\n    Final TrainState.\n  \"\"\"\n  # Initialize ray.\n  # The `runtime_env` argument is used to upload local python scripts to\n  # remote workers while excluding checkpoints, profiling events, etc.\n  ray.init(address=\"auto\",\n           runtime_env={\"working_dir\": \".\",\n\t                \"excludes\": [os.path.relpath(workdir)]})\n  # Initialize alpa.\n  alpa.init(cluster=\"ray\")\n\n  writer = metric_writers.create_default_writer(\n      logdir=workdir, just_logging=jax.process_index() != 0)\n\n  rng = random.PRNGKey(0)\n\n  image_size = 224\n\n  if config.batch_size % jax.device_count() > 0:\n    raise ValueError('Batch size must be divisible by the number of devices')\n  local_batch_size = config.batch_size // jax.process_count()\n\n  platform = jax.local_devices()[0].platform\n\n  if config.half_precision:\n    if platform == 'tpu':\n      input_dtype = tf.bfloat16\n    else:\n      input_dtype = tf.float16\n  else:\n    input_dtype = tf.float32\n\n  dataset_builder = tfds.builder(config.dataset)\n  steps_per_epoch = (\n      dataset_builder.info.splits['train'].num_examples // config.batch_size\n  )\n\n  if config.num_train_steps == -1:\n    num_steps = int(steps_per_epoch * config.num_epochs)\n  else:\n    num_steps = config.num_train_steps\n\n  if config.steps_per_eval == -1:\n    num_validation_examples = dataset_builder.info.splits[\n        'validation'].num_examples\n    steps_per_eval = num_validation_examples // config.batch_size\n  else:\n    steps_per_eval = config.steps_per_eval\n\n  steps_per_checkpoint = steps_per_epoch * 10\n\n  base_learning_rate = config.learning_rate * config.batch_size / 256.\n\n  model_cls = getattr(models, config.model)\n  model = create_model(\n      model_cls=model_cls, half_precision=config.half_precision)\n\n  learning_rate_fn = 
create_learning_rate_fn(\n      config, base_learning_rate, steps_per_epoch)\n\n  state = create_train_state(rng, config, model, image_size, learning_rate_fn)\n  state = restore_checkpoint(state, workdir)\n  # step_offset > 0 if restarting from checkpoint\n  step_offset = int(state.step)\n\n  p_train_step = alpa.parallelize(\n      functools.partial(train_step, learning_rate_fn=learning_rate_fn))\n  p_eval_step = alpa.parallelize(eval_step, donate_argnums=())\n\n  logging.info('Initial compilation. This might take some minutes...')\n  batch = {\n    \"image\": jax.core.ShapedArray(\n        (config.batch_size, image_size, image_size, 3), jnp.float32),\n    \"label\": jax.core.ShapedArray((config.batch_size,), jnp.int32),\n  }\n  executable = p_train_step.get_executable(state, batch)\n  executable.dump_debug_info(\"alpa_debug_info\")\n  logging.info('Initial compilation completed.')\n\n  batch_placement_specs = executable.get_input_placement_specs()[1]\n\n  train_iter = create_input_iter(\n      dataset_builder, local_batch_size, image_size, input_dtype,\n      batch_placement_specs, train=True, cache=config.cache)\n  eval_iter = create_input_iter(\n      dataset_builder, local_batch_size, image_size, input_dtype,\n      batch_placement_specs, train=False, cache=config.cache)\n\n  train_metrics = []\n  hooks = []\n  if jax.process_index() == 0:\n    hooks += [periodic_actions.Profile(num_profile_steps=5, logdir=workdir)]\n  train_metrics_last_t = time.time()\n  for step, batch in zip(range(step_offset, num_steps), train_iter):\n    state, metrics = p_train_step(state, batch)\n    for h in hooks:\n      h(step)\n\n    if config.get('log_every_steps'):\n      train_metrics.append(metrics)\n      if (step + 1) % config.log_every_steps == 0:\n        train_metrics = alpa.util.get_metrics(train_metrics)\n        summary = {\n            f'train_{k}': v\n            for k, v in jax.tree_map(lambda x: x.mean(), train_metrics).items()\n        }\n        summary['ips'] = 
config.batch_size * config.log_every_steps / (\n            time.time() - train_metrics_last_t)\n        writer.write_scalars(step + 1, summary)\n        train_metrics = []\n        train_metrics_last_t = time.time()\n\n    if (step + 1) % steps_per_epoch == 0:\n      epoch = step // steps_per_epoch\n      eval_metrics = []\n\n      for _ in range(steps_per_eval):\n        eval_batch = next(eval_iter)\n        metrics = p_eval_step(state, eval_batch)\n        eval_metrics.append(metrics)\n      eval_metrics = alpa.util.get_metrics(eval_metrics)\n      summary = jax.tree_map(lambda x: x.mean(), eval_metrics)\n      logging.info('eval epoch: %d, loss: %.4f, accuracy: %.2f',\n                   epoch, summary['loss'], summary['accuracy'] * 100)\n      writer.write_scalars(\n          step + 1, {f'eval_{key}': val for key, val in summary.items()})\n      writer.flush()\n    if (step + 1) % steps_per_checkpoint == 0 or step + 1 == num_steps:\n      save_checkpoint(state, workdir)\n\n  # Wait until computations are done before exiting\n  executable.sync()\n\n  return state\n"
  },
  {
    "path": "examples/llm_serving/README.rst",
    "content": "=======================================================\nServing OPT-175B, BLOOM-176B and CodeGen-16B using Alpa\n=======================================================\n\nThis tutorial shows how to setup a serving system to serve one of the largest available pretrained language models `OPT-175B <https://github.com/facebookresearch/metaseq/tree/main/projects/OPT>`_. The instructions for other models (BLOOM and CodeGen) are also listed at the end.\n\n👉 Try a live demo at `Alpa-OPT Demo <https://alpa-projects.github.io/opt>`_ 👈\n\nOverview\n========\nAs a serving system, Alpa offers the following unique advantages:\n\n* **Designed for large models**: Cannot fit the model into a single GPU? Not a problem, Alpa is designed for training and serving big models like GPT-3.\n\n* **Support commodity hardware**: With Alpa, you can serve OPT-175B using your in-house GPU cluster, without needing the latest generations of A100 80GB GPUs nor fancy InfiniBand connections -- no hardware constraints!\n\n* **Flexible parallelism strategies**: Alpa will automatically figure out the appropriate model-parallel strategies based on your cluster setup and your model architecture.\n\nIn this example, we use Alpa to serve the open-source OPT model, supporting all sizes ranging from 125M to 175B. Specifically, Alpa provides:\n\n* A distributed backend to perform efficient model-parallel inference for the large OPT models.\n\n* A web frontend to collect and batch inference requests from users.\n\n.. note::\n\n  The pre-trained OPT model weights can be obtained from `Metaseq <https://github.com/facebookresearch/metaseq>`_, subject to their license.\n\n.. 
note::\n\n  You will need at least 350GB GPU memory on your entire cluster to serve the OPT-175B model.\n  For example, you can use 4 x AWS p3.16xlarge instances, which provide 4 (instance) x 8 (GPU/instance) x 16 (GB/GPU) = 512 GB memory.\n\n  You can also follow this guide to set up a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc.\n  Pick an appropriate size from `OPT weight downloading page <https://github.com/facebookresearch/metaseq/tree/main/projects/OPT>`_ based on your available resources.\n\nDemo\n====\nThe code below shows how to use huggingface/transformers interface and Alpa distributed backend for large model inference.\n\n.. code:: python\n\n  from transformers import AutoTokenizer\n  from llm_serving.model.wrapper import get_model\n\n  # Load the tokenizer. All OPT models with different sizes share the same tokenizer\n  tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-2.7b\")\n  tokenizer.add_bos_token = False\n\n  # Load the model. Alpa automatically downloads the weights to the specified path\n  model = get_model(model_name=\"alpa/opt-2.7b\", path=\"~/opt_weights/\")\n\n  # Generate\n  prompt = \"Paris is the capital city of\"\n\n  input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids\n  output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)\n  generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)\n\n  print(generated_string)\n\nRequirements\n============\n1. Install Alpa following the `installation guide <https://alpa-projects.github.io/install.html>`_. You can either install by python wheel or build from source.\n\n2. Install additional requirements for ``llm_serving``:\n\n  .. code:: shell\n\n    pip3 install \"transformers<=4.23.1\" fastapi uvicorn omegaconf jinja2\n\n    # Install torch corresponding to your CUDA version, e.g., for CUDA 11.3:\n    pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113\n\n3. 
Clone the ``alpa`` repo. If you install alpa by python wheel, please clone the alpa repo. If you install from source, you already did this step.\n\n  .. code:: shell\n\n    git clone git@github.com:alpa-projects/alpa.git\n\n4. Install ``llm_serving`` package. Go to the examples folder and install the package.\n\n  .. code:: shell\n\n    cd alpa/examples\n    pip3 install -e .\n\n\nConvert Weights Format\n======================\n\nThe weights of OPT 125M--66B models are publicly available. Huggingface hosts copies of these weights.\nFor OPT 125M--66B, you **do not need** to download or convert the weights manually. Alpa will automatically download the weights from huggingface to the given path if Alpa cannot find cached weights locally.\n\nThe weights of OPT-175B can be obtained from Meta by filling out a `request form <https://github.com/facebookresearch/metaseq/tree/main/projects/OPT>`_ .\nYou then need to manually convert the obtained weights into Alpa format.\n\nConvert OPT-175B weights into Alpa formats\n------------------------------------------\nWe provide detailed instructions below on how to convert the original OPT-175B weights into Alpa-compatible formats. You can skip this section if you only want to run smaller models.\n\n  .. note::\n\n    The procedures below for converting OPT-175B weights will take about 1 hour.\n\n1. Download and verify the original weights\n    First, download Metaseq's original OPT-175B weights in 992 shards, verify the `MD5 of each shard <https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/assets/opt175b_md5sum_shards.csv>`_ , and put the shards under a folder, say, ``PATH_TO_992_SHARDS/``.\n\n2. Consolidate the weights from 992 shards into one single checkpoint\n    Use the script `step_2_consolidate_992_shards_to_singleton.py <https://github.com/alpa-projects/alpa/tree/main/examples/llm_serving/scripts/step_2_consolidate_992_shards_to_singleton.py>`_ as:\n\n  .. 
code:: shell\n\n    python3 step_2_consolidate_992_shards_to_singleton.py --read-prefix [PATH_TO_992_SHARDS]/checkpoint_last --save-prefix [PATH_TO_SAVE_CHECKPOINT]\n\n  The consolidated checkpoint will be saved at ``PATH_TO_SAVE_CHECKPOINT`` as specified in the command.\n\n  .. note::\n\n    The above script will require a peak memory (RAM) usage as large as twice the model size.\n    For example, if you are performing consolidation for the 175B model, it will approximately have a peak memory usage of 175B x 2 bytes x 2 = 700GB.\n    Please make sure your RAM is sufficient to run the script without throwing an OOM exception.\n\n  .. note::\n\n    The above script will save the model weights as a single consolidated checkpoint at ``PATH_TO_SAVE_CHECKPOINT``, hence will require at least 350GB disk space available.\n\n3. Convert the single checkpoint into Alpa-compatible formats\n    Alpa ingests weights simply from numpy formats. Use the script `step_3_convert_to_numpy_weights.py <https://github.com/alpa-projects/alpa/tree/main/examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py>`_ to convert the\n    single checkpoint into numpy formats:\n\n    .. code:: shell\n\n      python3 step_3_convert_to_numpy_weights.py --ckpt-path PATH_TO_SAVE_CHECKPOINT --output-folder OUTPUT_PATH\n\n\n    The weights will be saved at the folder ``OUTPUT_PATH`` as specified in the command.\n\n  .. note::\n\n    The above script also requires 350GB free disk space to write the numpy-formatted weights.\n\nConverted weights for other models\n----------------------------------\nYou do not need to download the weights manually for OPT 125M--66B. However, you may have trouble with the automatic downloading or huggingface. 
We also provide the converted weights for the following models.\n\n  * `OPT-125M weights <https://drive.google.com/file/d/1Ps7DFD80wNO7u2t39YCYcBX-9XwypGzl/view?usp=sharing>`_\n  * `OPT-2.7B weights <https://drive.google.com/file/d/1ayIaKRhxF9osZWgcFG-3vSkjcepSWdQd/view?usp=sharing>`_\n  * `OPT-30B weights <https://drive.google.com/file/d/1_MBcgwTqHFboV0JkGWR03AOHusrxcHlu/view?usp=sharing>`_\n\nCopy Weights to Multiple Nodes\n------------------------------\nIf you want to run the model on multiple nodes, you can use one of the following methods to copy the weights to all nodes.\n\n1. Put the weights under a shared network file system, so all nodes can access it.\n2. Run the script first on a driver node. The driver node will download the weights to its local disk, but the script will fail later because worker nodes cannot access the weights.\n   You can then manually copy all downloaded weights under ``path`` from the driver node to all worker nodes.\n\nRun Generation in the Command Line\n==================================\n\nThe code of this tutorial is under `examples/llm_serving <https://github.com/alpa-projects/alpa/tree/main/examples/llm_serving>`_.\n\n- Run generation using the 125M model with PyTorch/HuggingFace backend on a single GPU:\n\n  .. code:: shell\n\n    python3 textgen.py --model facebook/opt-125m\n\n\n- Run generation using the 125M model with JAX backend on a single GPU:\n\n  .. code:: shell\n\n    python3 textgen.py --model jax/opt-125m\n\n\n- Run model-parallel generation using the 2.7B model with Alpa on multiple GPUs:\n\n  .. 
code:: shell\n\n    # Start ray on the node\n    ray start --head\n\n    python3 textgen.py --model alpa/opt-2.7b\n\n\n- Run distributed generation using the 175B model with Alpa on a cluster of GPU nodes.\n  Note you will need >350GB total GPU memory in the entire cluster to successfully run the inference.\n\n  Before running the command below, start Ray on the cluster following `this guide <https://docs.ray.io/en/latest/cluster/cloud.html#manual-cluster>`_. You can check the cluster status by ``ray status``. You should be able to see all GPUs and all nodes in the output.\n\n  .. code:: shell\n\n    python3 textgen.py --model alpa/opt-175b\n\nLaunch a Web Server to Serve the OPT Models\n===========================================\n\nWe need to run two scripts: one for the web server and another for the model serving worker.\nThey will use two ports. The port of the website is defined in the command line and the port of the worker is defined in ``service/constants.py``.\n\n.. code:: shell\n\n  # Launch the model worker\n  python3 launch_model_worker.py --model alpa/opt-175b\n\n  # Launch the website (in a new terminal)\n  uvicorn launch_website:app --host 0.0.0.0 --port 8001\n\nThen open ``http://[IP-ADDRESS]:8001`` in your browser to try out the model!\n\nThere is also a client library which can be used to query the model worker\nvia a python script. Please check ``test_completions.py`` for the usage.\n\nImproving Generation Speed\n==========================\nHere are some tips for improving the generation speed.\n\n1. Batching. Single sequence generation cannot fully utilize the GPU power.\n   Applying batching can greatly boost the performance. See ``textgen.py`` for the usage.\n2. Tune the ``encoder_chunk_sizes`` argument of ``get_model``.\n   Alpa compiles multiple executables and uses these executables to encode a prompt chunk by chunk. This argument controls the possible chunk sizes. Depending on the length of your prompt, you can try different combinations. 
For example, if your prompt lengths are around 1000-1500, a good combination is ``[1, 256, 1024]``.\n3. Tune parallelization strategy. If you are familiar with alpa, you can tune the ``method`` argument of ``alpa.parallelize`` and try different parallelization methods.\n\nIf you find the generation speed too slow and want to accelerate it, please join `Alpa slack <https://forms.gle/YEZTCrtZD6EAVNBQ7>`_ and tell us your use cases. We are actively working on improving the performance.\n\nOPT License\n===========\nThe use of the OPT pretrained weights is subject to the `Model License <https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md>`_ by Metaseq.\n\nOther Models (BLOOM)\n====================\nAlpa also supports `BLOOM <https://huggingface.co/bigscience/bloom>`_.\nYou can use commands similar to OPT but with a different model name.\n\n  .. code:: shell\n\n    # Huggingface/pytorch backend\n    python3 textgen.py --model bigscience/bloom-560m\n\n    # Jax backend\n    python3 textgen.py --model jax/bloom-560m\n\n    # Alpa backend\n    python3 textgen.py --model alpa/bloom-560m\n\nOther Models (CodeGen)\n======================\nAlpa also supports `CodeGen <https://github.com/salesforce/CodeGen>`_.\nYou can use commands similar to OPT but with a different model name.\n\n  .. code:: shell\n\n    # Huggingface/pytorch backend\n    python3 codegen.py --model Salesforce/codegen-2B-mono\n\n    # Alpa backend\n    python3 codegen.py --model alpa/codegen-2B-mono\n"
  },
  {
    "path": "examples/llm_serving/__init__.py",
    "content": ""
  },
  {
    "path": "examples/llm_serving/benchmark/benchmark_1d.py",
    "content": "import argparse\nimport math\nimport time\nimport random\n\nimport numpy as np\nimport torch\n\nfrom alpa.util import write_tsv\nfrom llm_serving.generator import pad_batch\nfrom llm_serving.model.wrapper import get_model as get_model_2d\nfrom llm_serving.model.wrapper_1d import get_model as get_model_1d\n\n\ninput_id_list = [\n    [45942, 2866, 16, 5, 892, 9, 44042, 8],\n    [100, 261, 23888, 2426, 16, 10, 21624, 12, 4310, 3034, 9744, 25526, 11],\n    [133, 589, 9, 886, 6, 10817, 16, 10, 285],\n    [5625, 16, 10, 205, 183, 8, 38, 236, 7],\n    [2264, 16, 5, 7440, 9, 16673, 873, 24214, 116],\n    [32826, 16, 5, 812, 343, 9],\n    [2264, 109, 47, 206, 59, 5, 499, 9, 28850, 1975, 37079, 116],\n    [2264, 109, 47, 206, 59, 5, 3099, 9, 301, 116],\n    [19195, 140, 16, 5, 394, 9],\n    [534, 10311, 12, 246, 16, 10, 739, 2777, 1421, 14, 16, 4453, 9],\n]\n\n\ndef synthesize_inputs(low=32, high=512, n_prompt=256):\n    vocab_size = 50272\n    ret = []\n    prompt_length = np.random.randint(low, high, (n_prompt,))\n    for i in range(n_prompt):\n        p = np.random.randint(low=4, high=vocab_size, size=prompt_length[i]).tolist()\n        ret.append(p)\n    min_length = min(len(p) for p in ret)\n    max_length = max(len(p) for p in ret)\n    mean_length = sum(len(p) for p in ret) / len(ret)\n    print(f\"- Synthetic dataset, size {len(ret)}, min {min_length}, max {max_length}, mean {mean_length}\")\n    return ret\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"opt-1.3b\")\n    parser.add_argument(\"--backend\", type=str, default=\"jax\")\n    parser.add_argument(\"--path\", type=str, default=\"~/opt_weights/\")\n    parser.add_argument(\"--n-warmup\", type=int, default=2)\n    parser.add_argument(\"--n-iter\", type=int, default=3)\n    parser.add_argument(\"--n-prompt\", type=int, default=8)\n    parser.add_argument(\"--use-synthetic\", action=\"store_true\")\n    
parser.add_argument(\"--low\", type=int, default=16)\n    parser.add_argument(\"--high\", type=int, default=128)\n    parser.add_argument(\"--batch-size-2d\", type=int, default=4)\n    parser.add_argument(\"--batch-size-1d\", type=int, default=256)\n    parser.add_argument(\"--cache-size\", type=int, default=4096 * 8)\n    parser.add_argument(\"--max-new-tokens\", type=int, default=128)\n    parser.add_argument(\"--tail-percentage\", type=float, default=10)\n    parser.add_argument(\"--verbose\", action=\"store_true\")\n    args = parser.parse_args()\n\n    def extend_input(input_list):\n        if args.n_prompt <= len(input_list):\n            ret = input_list[:args.n_prompt]\n        else:\n            factor = math.ceil(float(args.n_prompt) / float(len(input_list)))\n            ret = input_list * factor\n            random.shuffle(ret)\n            ret = ret[:args.n_prompt]\n        return ret\n\n    if not args.use_synthetic:\n        input = extend_input(input_id_list)\n    else:\n        input = synthesize_inputs(low=args.low, high=args.high, n_prompt=args.n_prompt)\n    n_batch_2d = math.ceil(len(input) / float(args.batch_size_2d))\n\n    def runner_2d(model, input):\n        output = []\n        latency = []\n        total_time = 0.0\n        start_idx = 0\n        for i in range(n_batch_2d):\n            end_idx = start_idx + args.batch_size_2d\n            end_idx = min(len(input), end_idx)\n            cur_batch = input[start_idx:end_idx]\n\n            effective_num_seq = len(cur_batch)\n            cur_batch = pad_batch(cur_batch, 1, args.batch_size_2d)\n            cur_batch = torch.from_numpy(np.array(cur_batch))\n\n            tic = time.time()\n            output_ids = model.generate(input_ids=cur_batch,\n                                        max_new_tokens=args.max_new_tokens,\n                                        do_sample=False)\n            toc = time.time()\n            batch_latency = toc - tic\n            total_time += batch_latency\n 
           latency.extend([batch_latency] * effective_num_seq)\n            output.extend(output_ids[:effective_num_seq])\n            start_idx += args.batch_size_2d\n\n        return latency, total_time, output\n\n    def runner_1d(model, input):\n        tic = time.time()\n        output_ids, latency = model.generate(input,\n                                             max_new_tokens=args.max_new_tokens,\n                                             do_sample=False)\n        toc = time.time()\n        total_time = toc - tic\n\n        return latency, total_time, output_ids\n\n    def benchmark(model, runner, input):\n        for i in range(args.n_warmup):\n            print(f\"  Warm-up iter {i}\")\n            runner(model, input)\n        latencies = np.zeros((args.n_iter, len(input)), dtype=float)\n        total_times = []\n        for i in range(args.n_iter):\n            latency, total_time, output = runner(model, input)\n            print(f\"  Benchmark iter {i}\")\n            if args.verbose:\n                print(f\"  {latency}\")\n            latencies[i, :] = latency\n            total_times.append(total_time)\n        mean_latency = np.mean(latencies, axis=0)\n        return mean_latency, sum(total_times) / args.n_iter, output\n\n    def estimate_throughput(input, output, latency, total_time):\n        req_per_sec = len(input) / total_time\n        decoded_tokens = [out[len(input[i]):] for i, out in enumerate(output)]\n        decode_token_per_sec = sum(len(seq) for seq in decoded_tokens) / total_time\n        return req_per_sec, decode_token_per_sec\n\n    model_name_2d = args.backend + \"/\" + args.model\n    model_2d = get_model_2d(model_name=model_name_2d,\n                            path=\"~/opt_weights\",\n                            batch_size=args.batch_size_2d)\n\n    model_name_1d = \"alpa/\" + args.model.replace(\"-\", \"-1d-\")\n    model_1d = get_model_1d(model_name=model_name_1d,\n                            path=\"~/opt_weights\",\n  
                          batch_size=args.batch_size_1d,\n                            cache_size=args.cache_size)\n\n    num_tail = int(args.tail_percentage / 100.0  * len(input))\n\n    print(\"- Benchmark 2D...\")\n    latency_2d, total_time_2d, output_2d = benchmark(model_2d, runner_2d, input)\n    rps_2d, tps_2d = estimate_throughput(input, output_2d, latency_2d, total_time_2d)\n    mean_latency_2d = np.mean(latency_2d)\n    tail_latency_2d = np.mean(latency_2d[np.argsort(latency_2d)[-num_tail:]])\n\n    print(\"- Benchmark 1D...\")\n    latency_1d, total_time_1d, output_1d = benchmark(model_1d, runner_1d, input)\n    rps_1d, tps_1d = estimate_throughput(input, output_1d, latency_1d, total_time_1d)\n    mean_latency_1d = np.mean(latency_1d)\n    tail_latency_1d = np.mean(latency_1d[np.argsort(latency_1d)[-num_tail:]])\n\n    heads = [\n        \"Model\", \"#Prompts\", \"BS (2D)\", \"BS (1D)\", \"Max new tokens\",\n        \"RPS (1D vs. 2D)\", \"TPS (1D vs. 2D)\",\n        \"Mean Latency (1D vs. 2D)\", \"Tail latency (1D vs. 2D)\"\n    ]\n    values = [\n        args.model, args.n_prompt, args.batch_size_2d, args.batch_size_1d, args.max_new_tokens,\n        f\"{rps_1d:.2f}/{rps_2d:.2f} ({rps_1d / rps_2d:.2f}x)\", f\"{tps_1d:.2f}/{tps_2d:.2f} ({tps_1d / tps_2d:.2f}x)\",\n        f\"{mean_latency_1d:.2f}/{mean_latency_2d:.2f} ({mean_latency_2d / mean_latency_1d:.1f}x)\",\n        f\"{tail_latency_1d:.2f}/{tail_latency_2d:.2f} ({tail_latency_2d / tail_latency_1d:.1f}x)\"\n    ]\n    write_tsv(heads, values, \"1d-vs-2d.tsv\")\n"
  },
  {
    "path": "examples/llm_serving/benchmark/benchmark_step_func.py",
    "content": "\"\"\"\nA simpler benchmark script that benchmarks the latency of alpa execution\nwithout the huggingface generator interface.\n\"\"\"\n\nimport argparse\nimport os\nimport time\n\nimport alpa\nfrom alpa.util import write_tsv\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom llm_serving.model import opt_model, bloom_model\nfrom llm_serving.model.wrapper import set_skip_shard_args_check\n\n\ndef run_benchmark(args):\n    name = args.model.split(\"/\")[1].lower()\n    path = os.path.join(args.path, f\"{name}-np\")\n\n    alpa.global_config.shard_parallel_sync_for_timer = True\n    alpa.global_config.pipeline_check_alive = False\n    alpa.global_config.pipeline_sync_for_timer = True\n    alpa.global_config.delete_remote_arrays_threshold = 100\n\n    batch_size = args.batch_size\n    seq_len = 10\n    dummy = args.dummy\n    if \"opt\" in name:\n        m = opt_model\n        def inference_step_with_cache(params, batch):\n            output = model.apply(params,\n                                 batch[\"input_ids\"],\n                                 batch[\"position_ids\"],\n                                 attention_mask=batch[\"mask\"],\n                                 attention_cache=batch[\"cache\"])\n            return output.logits, output.attention_cache\n    else:\n        m = bloom_model\n        def inference_step_with_cache(params, batch):\n            output = model.apply(params,\n                                 batch[\"input_ids\"],\n                                 attention_mask=batch[\"mask\"],\n                                 attention_cache=batch[\"cache\"])\n            return output.logits, output.attention_cache\n\n    if args.parallel_method == \"jit\":\n        config = m.get_config(name)\n        model, params_aval = m.init_model_aval(config)\n        params = m.load_params_np(params_aval, path, config, dummy)\n        cache = m.init_cache_np(config, batch_size)\n        params, cache = jax.tree_map(jnp.array, 
(params, cache))\n\n        infer_step = jax.jit(inference_step_with_cache)\n        sync_func = lambda: jax.local_devices()[0].synchronize_all_activity()\n        executable = None\n        num_gpus = 1\n    else:\n        if args.parallel_method in [\"shard_local\", \"shard_ray\"]:\n            assert dummy == True, 'Only support dummy weights. Plasese add \"--dummy\".'\n\n            config = m.get_config(name)\n            model, params_aval = m.init_model_aval(config)\n            if args.parallel_method == \"shard_local\":\n                alpa.init(cluster=\"local\")\n            else:\n                alpa.init(cluster=\"ray\")\n            num_gpus = alpa.get_global_num_devices()\n\n            method = alpa.ShardParallel(\n                auto_sharding_option=alpa.AutoShardingOption())\n            infer_step = alpa.parallelize(inference_step_with_cache,\n                                          method=method)\n        else:\n            assert args.parallel_method == \"pipeshard\"\n            alpa.init(cluster=\"ray\")\n            num_gpus = alpa.get_global_num_devices()\n            num_pp_stages = max(2, alpa.get_global_cluster().num_hosts)\n            config = m.get_config(name, num_pp_stages=num_pp_stages)\n            model, params_aval = m.init_model_aval(config)\n\n            method = alpa.PipeshardParallel(\n                num_micro_batches=1,\n                pipeline_schedule=\"inference\",\n                layer_option=\"manual\",\n                default_auto_sharding_option=alpa.AutoShardingOption(\n                    # Force operator model parallel\n                    force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,\n                    # Disabling all-to-all and all-gather generates better intra-op strategies.\n                    allow_all_to_all=False,\n                    allow_all_gather=False,\n                ))\n            infer_step = alpa.parallelize(inference_step_with_cache, method=method)\n            
alpa.global_config.always_donate_micro_batch_vars = False\n\n        executable = infer_step.get_executable(\n            params_aval, {\n                \"input_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"position_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"cache\":\n                    m.init_cache_aval(config, batch_size),\n                \"mask\":\n                    m.init_mask_aval(config, batch_size),\n            })\n        executable.dump_debug_info(\"tmp\")\n\n        params = m.load_params_dis_array(path, executable, params_aval, config,\n                                         dummy)\n        cache = m.init_cache_dis_array(executable, config, batch_size, dummy)\n        set_skip_shard_args_check(cache)\n        infer_step = executable\n        if args.parallel_method == \"local_shard\":\n            # Already synced by the local timer\n            sync_func = lambda: None\n        else:\n            sync_func = lambda: executable.sync()\n\n    input_ids = np.random.randint(0,\n                                  10000,\n                                  size=(batch_size, seq_len),\n                                  dtype=np.int32)\n    position_ids = opt_model.build_position_ids(input_ids, config.pad)\n    mask = np.ones((batch_size, 1, 1, config.max_seq_len), dtype=np.int8)\n\n    step_latencies = []\n    compute_latencies = []\n    shard_args_latencies = []\n    for i in range(input_ids.shape[1]):\n        input_ids_step = input_ids[:, i:i + 1]\n        position_ids_step = np.full_like(input_ids_step, i + config.pad + 1)\n\n        sync_func()\n        start_time = time.time()\n        infer_step(\n            params, {\n                \"input_ids\": input_ids_step,\n                \"position_ids\": position_ids_step,\n                \"mask\": mask,\n                \"cache\": cache,\n            })\n        sync_func()\n        end_time = 
time.time()\n\n        step_latencies.append(end_time - start_time)\n        if executable:\n            compute_latencies.append(executable.get_execution_time_costs()[-1])\n            shard_args_latencies.append(\n                executable.get_shard_args_time_costs()[-1])\n        else:\n            compute_latencies.append(step_latencies[-1])\n            shard_args_latencies.append(0)\n\n        print(f\"{i}, step_latency: {step_latencies[-1] * 1000:.2f} ms\")\n\n    warmup = 3\n    heads = [\n        \"Model\", \"Parallel Method\", \"Dummy\", \"#gpu\", \"Step Latency (ms)\",\n        \"Compute Latency (ms)\", \"ShardArgs Latency (ms)\"\n    ]\n    values = [\n        args.model, args.parallel_method, args.dummy, num_gpus,\n        f\"{np.mean(step_latencies[warmup:]) * 1e3:.2f}\",\n        f\"{np.mean(compute_latencies[warmup:]) * 1e3:.2f}\",\n        f\"{np.mean(shard_args_latencies[warmup:]) * 1e3:.2f}\"\n    ]\n    write_tsv(heads, values, \"result_step_func.tsv\")\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"alpa/opt-2.7b\")\n    parser.add_argument(\"--batch-size\", type=int, default=1)\n    parser.add_argument(\"--path\", type=str, default=\"/home/ubuntu/opt_weights/\")\n    parser.add_argument(\"--dummy\", action=\"store_true\")\n    parser.add_argument(\n        \"--parallel-method\",\n        type=str,\n        required=True,\n        choices=[\"jit\", \"shard_local\", \"shard_ray\", \"pipeshard\"])\n    args = parser.parse_args()\n\n    run_benchmark(args)\n"
  },
  {
    "path": "examples/llm_serving/benchmark/benchmark_text_gen.py",
    "content": "\"\"\"benchmark generation performance.\n\nUsages:\n1. benchmark huggingface torch-based OPT generation:\npython3 benchmark_text_gen.py --model facebook/opt-125m --debug\n\n2. benchmark jax.jit based OPT generation without alpa, on a single GPU:\npython3 benchmark_text_gen.py --model jax/opt-125m --debug\n\n3. benchmark alpa parallelized OPT generation:\npython3 benchmark_text_gen.py --model alpa/opt-2.7b --debug\n\n4. benchmark alpa parallelized OPT forward computation, batch_size, encoder length, and #micro_batches can be configured.\npython3 benchmark_text_gen.py --model alpa/opt-2.7b --forward\n    --forward-encoder-length 1024 --nb 1 --batch-size 256 --debug\n\"\"\"\nimport argparse\n\nimport alpa\nfrom alpa.global_env import global_config\nfrom alpa.util import write_tsv\nimport jax.numpy as jnp\nimport numpy as np\nimport time\nimport torch\nfrom transformers import AutoTokenizer\n\nfrom llm_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding\nfrom llm_serving.model.wrapper import get_model\n\ntest_prompts = [\n    \"Computer science is the study of computation and\",\n    \"Ion Stoica is a Romanian-American computer scientist specializing in\",\n    \"The University of California, Berkeley is a public\",\n    \"Today is a good day and I want to\", \"What is the valuation of Databricks?\",\n    \"Paris is the capital city of\", \"Which country has the most population?\",\n    \"What do you think about the future of Cryptocurrency?\",\n    \"What do you think about the meaning of life?\",\n    \"Donald Trump is the president of\",\n    \"GPT-3 is a large language model that is capable of\"\n]\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"alpa/opt-125m\")\n    parser.add_argument(\"--torch-device\", type=str)\n    parser.add_argument(\"--path\", type=str, default=\"~/opt_weights/\")\n    parser.add_argument(\"--dummy\", 
action=\"store_true\")\n    parser.add_argument(\"--forward\", action=\"store_true\")\n    parser.add_argument(\"--forward-encoder-length\", type=int, default=1024)\n    parser.add_argument(\"--nb\", type=int, default=1)\n    parser.add_argument(\"--batch-size\", type=int, default=1)\n    parser.add_argument(\"--n-warmup\", type=int, default=1)\n    parser.add_argument(\"--n-iter\", type=int, default=10)\n    parser.add_argument(\"--max-length\", type=int, default=256)\n    parser.add_argument(\"--pad-to-max-length\", type=int)\n    parser.add_argument(\"--num-beams\", type=int, default=1)\n    parser.add_argument(\"--debug\", action=\"store_true\")\n    parser.add_argument(\"--dtype\", type=str, default=\"fp16\")\n    args = parser.parse_args()\n\n    # Some global params\n    global_config.pipeline_sync_for_timer = True\n    global_config.shard_parallel_sync_for_timer = True\n\n    # Do some param check\n    n_warmup = args.n_warmup\n    n_iters = args.n_iter\n    max_length = args.max_length\n    num_micro_batches = args.nb\n    batch_size = args.batch_size\n    num_beams = args.num_beams\n    autoregressive = not args.forward\n    dtype = jnp.float16 if args.dtype == \"fp16\" else jnp.float32\n\n    if autoregressive:\n        assert num_micro_batches == 1, \"we only support num_micro_batches=1 for autoregressive!\"\n\n    if args.torch_device:\n        torch_device = args.torch_device\n    else:\n        if \"alpa\" in args.model or \"jax\" in args.model:\n            # alpa/jax prefer cpu backend of pytorch to avoid memory conflict\n            torch_device = \"cpu\"\n        else:\n            torch_device = \"cuda\"\n\n    decode_speeds = []\n    tflopss = []\n    compute_tflopss = []\n\n    if not autoregressive: # Forward mode\n        raise RuntimeError(\"This branch is deprecated\")\n        # Increase the frequency of deleting buffers to avoid OOM.\n        global_config.delete_remote_arrays_threshold = 1\n        seq_len = 
args.forward_encoder_length\n        encoder_chunk_sizes = [seq_len]\n\n        tic = time.time()\n        model, params, transformer_config = get_model(\n            args.model,\n            path=args.path,\n            torch_device=torch_device,\n            dummy=args.dummy,\n            autoregressive=autoregressive,\n            max_target_positions=seq_len,\n            dtype=dtype,\n            batch_size=batch_size,\n            encoder_chunk_sizes=encoder_chunk_sizes,\n            num_micro_batches=num_micro_batches)\n        load_time = time.time() - tic\n\n        # create batch\n        input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n\n        # get model config\n        H = transformer_config.H\n        L = transformer_config.L\n        seq_len = transformer_config.seq_len\n        vocab_size = transformer_config.vocab_size\n\n        num_gpus = alpa.get_global_cluster(\n        ).num_devices if \"alpa\" in args.model else 1\n\n        # warm up\n        for _ in range(n_warmup):\n            forward_results = model(params, {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids\n            })\n            model.sync()\n\n        # benchmark\n        for i in range(n_iters):\n            torch.manual_seed(8)\n\n            tic = time.time()\n            forward_results = model(params, {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids\n            })\n            model.sync()\n            # a = np.array(forward_results)\n            # print(a)\n            latency = time.time() - tic\n\n            compute_latency = model.get_execution_time_costs()[-1]\n            # print(f\"input length: {input_ids.shape[1]}, output_length: {input_ids.shape[1]}, num_gpus: {num_gpus}\")\n            assert seq_len == input_ids.shape[1]\n\n            memory_allocated = 
model.mesh_group.get_memory_allocated() / 1e9\n            max_memory_allocated = model.mesh_group.get_max_memory_allocated(\n            ) / 1e9\n\n            tflops = compute_gpt_tflops_inference_with_padding(\n                batch_size, seq_len, seq_len, L, H, vocab_size,\n                num_gpus, latency)\n            compute_tflops = compute_gpt_tflops_inference_with_padding(\n                batch_size, seq_len, seq_len, L, H, vocab_size,\n                num_gpus, compute_latency)\n            speed = np.prod(input_ids.shape) / latency\n\n            if args.debug:\n                print(\n                    f\"speed: {speed:.2f} token/s, E2E tflops: {tflops:.4f}, compute tflops: {compute_tflops:.4f}, \"\n                    f\"memory: {memory_allocated}, max memory: {max_memory_allocated}\"\n                )\n            decode_speeds.append(speed)\n            tflopss.append(tflops)\n            compute_tflopss.append(compute_tflops)\n    else: # Generation mode\n        encoder_chunk_sizes = (1, 64)\n        generate_args = {\n            \"do_sample\": False,\n            \"num_beams\": num_beams,\n            \"return_dict_in_generate\": True\n        }\n\n        # Note(Hao): we need to use \"opt-30b\" and disable \"add_bos_token\".\n        tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-30b\",\n                                                  use_fast=False)\n        tokenizer.add_bos_token = False\n\n        tic = time.time()\n        model = get_model(args.model,\n                          args.path,\n                          torch_device=torch_device,\n                          dummy=args.dummy,\n                          dtype=dtype,\n                          encoder_chunk_sizes=encoder_chunk_sizes,\n                          **generate_args)\n        load_time = time.time() - tic\n\n        H = model.transformer_config.H\n        L = model.transformer_config.L\n        seq_len = model.transformer_config.seq_len\n        
vocab_size = model.transformer_config.vocab_size\n        if \"alpa\" in args.model:\n            num_gpus = alpa.get_global_num_devices()\n        else:\n            num_gpus = 1\n\n        # Benchmark all prompts\n        for i in range(min(args.n_iter, len(test_prompts))):\n            prompt = test_prompts[i]\n            torch.manual_seed(8)\n            if args.pad_to_max_length:\n                input_ids = tokenizer(prompt,\n                                      padding=\"max_length\",\n                                      max_length=args.pad_to_max_length,\n                                      return_tensors=\"pt\").input_ids.to(torch_device)\n            else:\n                input_ids = tokenizer(prompt,\n                                      return_tensors=\"pt\").input_ids.to(torch_device)\n\n            # Warm up\n            for _ in range(n_warmup):\n                model.generate(input_ids=input_ids,\n                               max_length=max_length,\n                               **generate_args)\n\n            # Benchmark a prompt\n            tic = time.time()\n            output = model.generate(input_ids=input_ids,\n                                    max_length=max_length,\n                                    **generate_args)\n            latency = time.time() - tic\n            generated_ids = output.sequences\n            generated_string = tokenizer.batch_decode(generated_ids,\n                                                      skip_special_tokens=True)\n\n            gen_len = generated_ids.shape[1]\n\n            if \"alpa\" in args.model:\n                compute_latency = sum(\n                    model.executable.get_execution_time_costs()[-gen_len:])\n            else:\n                compute_latency = latency\n            tflops = compute_gpt_tflops_inference_with_padding(\n                num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,\n                num_gpus, latency)\n            compute_tflops = 
compute_gpt_tflops_inference_with_padding(\n                num_beams * batch_size, gen_len, seq_len, L, H, vocab_size,\n                num_gpus, compute_latency)\n            speed = np.prod(generated_ids.shape) / latency\n            if args.debug:\n                print(\n                    f\"input length: {input_ids.shape[1]}, output_length: {generated_ids.shape[1]}, \"\n                    f\"num_gpus: {num_gpus}, speed: {speed:.2f} tokens/s, tflops: {tflops:.4f} tflops/s\"\n                )\n                print(generated_string)\n            decode_speeds.append(speed)\n            tflopss.append(tflops)\n            compute_tflopss.append(compute_tflops)\n\n    avg_speed = np.mean(decode_speeds)\n    avg_tflops = np.mean(tflopss)\n    avg_compute_tflops = np.mean(compute_tflopss)\n    latency_32_tokens = 32.0 / (avg_speed / batch_size)\n    num_pp_stages = 2\n\n    heads = [\n        \"Model\", \"Torch device\", \"Dummy\", \"Load (s)\", \"Autoregressive\", \"Batch size\",\n        \"#Microbatches\", \"#Beams\", \"#Stages\", \"Encoder chunk sizes\", \"TFlops\",\n        \"Compute TFlops\", \"Speed (token/s)\", \"latency (32 token)\"\n    ]\n    values = [\n        args.model, torch_device, args.dummy, f\"{load_time:.2f}\",\n        f\"{autoregressive}\", f\"{batch_size}\", f\"{num_micro_batches}\",\n        f\"{num_beams}\", f\"{num_pp_stages}\", f\"{encoder_chunk_sizes}\",\n        f\"{avg_tflops:.4f}\", f\"{avg_compute_tflops:.4f}\", f\"{avg_speed:.2f}\",\n        f\"{latency_32_tokens:.2f}\"\n    ]\n    write_tsv(heads, values, \"results.tsv\")\n"
  },
  {
    "path": "examples/llm_serving/client.py",
    "content": "import argparse\nfrom typing import Dict, Optional, Union, Sequence\n\nimport requests\n\nDEFAULT_URL = \"https://api.alpa.ai\"\n\nheaders = {\"User-Agent\": \"Alpa Client\"}\n\n\nclass Client(object):\n\n    def __init__(self,\n                 url: Optional[str] = None,\n                 api_key: Optional[str] = None,\n                 default_model: str = \"default\") -> None:\n        if url is None:\n            url = DEFAULT_URL\n\n        self.api_key = api_key\n        self.default_model = default_model\n        self.completions_url = url + \"/completions\"\n        self.logprobs_url = url + \"/logprobs\"\n\n    def completions(\n        self,\n        prompt: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]],\n        min_tokens: int = 0,\n        max_tokens: int = 32,\n        top_p: float = 1.0,\n        temperature: float = 1.0,\n        echo: bool = True,\n        model: Optional[str] = None,\n    ) -> Dict:\n        \"\"\"\n        Generation API.\n        Parameters match those of the OpenAI API.\n        https://beta.openai.com/docs/api-reference/completions/create\n\n        Args:\n          prompt: a list of tokenized inputs.\n          min_tokens: The minimum number of tokens to generate.\n          max_tokens: The maximum number of tokens to generate.\n          temperature: What sampling temperature to use.\n          top_p: The nucleus sampling probability.\n          echo: if true, returned text/tokens/scores includes the prompt.\n        \"\"\"\n        pload = {\n            \"model\": model or self.default_model,\n            \"prompt\": prompt,\n            \"min_tokens\": min_tokens,\n            \"max_tokens\": max_tokens,\n            \"temperature\": temperature,\n            \"top_p\": top_p,\n            \"echo\": echo,\n            \"api_key\": self.api_key\n        }\n        result = requests.post(self.completions_url, json=pload, headers=headers)\n        return self.result_or_error(result)\n\n    
def logprobs(\n        self,\n        prompt: Union[str, Sequence[str], Sequence[int], Sequence[Sequence[int]]],\n        top_k: int = 50,\n        cache_id: Optional = None,\n        model: Optional[str] = None) -> Dict:\n        \"\"\"Return the log probability of the next top-k tokens\"\"\"\n        pload = {\n            \"model\": model or self.default_model,\n            \"prompt\": prompt,\n            \"top_k\": top_k,\n            \"api_key\": self.api_key\n        }\n        if cache_id:\n            pload[\"cache_id\"] = cache_id\n        result = requests.post(self.logprobs_url, json=pload, headers=headers)\n        return self.result_or_error(result)\n\n    def result_or_error(self, result):\n        result = result.json()\n        if result.get(\"type\", \"\") == \"error\":\n            raise RuntimeError(\n                result[\"stacktrace\"] +\n                f'RuntimeError(\"{result[\"message\"]}\")')\n        else:\n            return result\n"
  },
  {
    "path": "examples/llm_serving/codegen.py",
    "content": "\"\"\"Use huggingface/transformers interface and Alpa backend for distributed inference.\"\"\"\nimport argparse\n\nimport numpy as np\nfrom transformers import AutoTokenizer\n\nfrom llm_serving.model.wrapper import get_model\n\ndef main(args):\n    # Load the tokenizer.\n    if \"codegen\" in args.model:\n        name = args.model.replace(\"alpa\", \"Salesforce\")\\\n                         .replace(\"jax\", \"Salesforce\")\n        tokenizer = AutoTokenizer.from_pretrained(name, padding_side = \"left\")\n        tokenizer.pad_token = 50256\n    generate_params = {\n        \"do_sample\": args.do_sample,\n        \"num_beams\": args.num_beams,\n        \"num_return_sequences\": args.num_return_sequences\n    }\n\n    # Load the model\n    model = get_model(model_name=args.model,\n                      path=\"~/codegen_weights\",\n                      batch_size=args.n_prompts,\n                      **generate_params)\n\n    # Generate\n    prompts = [\n        \"# This function prints hello world.\\n\",\n        \"def fib(k):\\n    # Returns the k-th Fibonacci number.\\n\",\n        \"def is_prime(n):\\n    # Return whether n is a prime number.\\n\",\n        \"def return_len(s):\\n    # Return the length of s.\\n\",\n    ]\n    prompts = prompts[:args.n_prompts]\n\n    input_ids = tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\").input_ids\n    \n    output_ids = model.generate(input_ids=input_ids,\n                                max_length=64,\n                                **generate_params)\n    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True,\n                                     truncate_before_pattern=[r\"\\n\\n^#\", \"^'''\", \"\\n\\n\\n\"])\n\n    # Print results\n    print(\"Outputs:\\n\" + 100 * '-')\n    for i, output in enumerate(outputs):\n        print(f\"{i}: {output}\")\n        print(100 * '-')\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    
parser.add_argument(\"--model\", type=str, default=\"alpa/codegen-2B-mono\")\n    # help: see https://github.com/salesforce/CodeGen for a list of available models.\n    parser.add_argument('--do-sample', action='store_true')\n    parser.add_argument('--num-beams', type=int, default=1)\n    parser.add_argument('--num-return-sequences', type=int, default=1)\n    parser.add_argument('--n-prompts', type=int, default=4)\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/llm_serving/generator.py",
    "content": "import time\nfrom typing import List, Optional\n\nimport numpy as np\nimport torch\nfrom transformers import AutoTokenizer\n\nfrom llm_serving.model.wrapper import get_model\nfrom llm_serving.model.opt_utils import compute_gpt_tflops_inference_with_padding\nfrom llm_serving.service.utils import build_logger\n\n\nclass Generator:\n    \"\"\"The generator interface.\n\n    This class wraps tokenizer and the langauge model.\n    \"\"\"\n\n    def __init__(self,\n                 model_name,\n                 path,\n                 torch_device=\"cpu\",\n                 tokenizer_name=None,\n                 add_bos_token=False,\n                 max_seq_len=1024,\n                 max_batch_size=4,\n                 do_sample=False,\n                 num_beams=1,\n                 num_return_sequences=1):\n        self.logger = build_logger()\n\n        # Model arguments\n        self.model_name = model_name\n        self.path = path\n        self.model_wrapper = None\n        self.torch_device = torch_device\n\n        # Tokenizer arguments\n        self.tokenizer_name = tokenizer_name\n        self.tokenizer = None\n        self.add_bos_token = add_bos_token\n\n        # Generation arguments\n        self.max_seq_len = max_seq_len\n        self.max_batch_size = max_batch_size\n        self.do_sample = do_sample\n        self.num_beams = num_beams\n        self.num_return_sequences = num_return_sequences\n\n        # Others\n        self.num_gpus = None\n        self.dataset_to_epoch_iter = dict()\n\n        # Initialize models\n        self.load_model()\n\n    def load_model(self):\n        \"\"\"Compile and load a model.\"\"\"\n        tic = time.time()\n\n        # Init model\n        self.model_wrapper = get_model(self.model_name, self.path,\n                                       torch_device=self.torch_device,\n                                       batch_size=self.max_batch_size,\n                                       encoder_chunk_sizes=[1, 
64],\n                                       max_seq_len=self.max_seq_len,\n                                       num_beams=self.num_beams,\n                                       num_return_sequences=self.num_return_sequences,\n                                       do_sample=self.do_sample)\n        load_time = time.time() - tic\n\n        # Init tokenizer\n        if self.tokenizer_name:\n            self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)\n        else:\n            if \"opt\" in self.model_name:\n                self.tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-30b\")\n                self.tokenizer.add_bos_token = False\n            elif \"bloom\" in self.model_name:\n                tokenizer_name = self.model_name.replace(\"alpa\", \"bigscience\")\\\n                                                .replace(\"jax\", \"bigscience\")\n                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n            else:\n                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)\n\n        if \"alpa\" in self.model_name:\n            import alpa\n            self.num_gpus = alpa.get_global_cluster().num_devices\n        else:\n            self.num_gpus = 1\n        self.logger.info(f\"Loading model time: {load_time:.2f}\")\n\n    def encode(self, s: str):\n        \"\"\"Tokenize strings\"\"\"\n        # note that web browsers send \\r\\n but our training data uses \\n.\n        s = s.replace(\"\\r\\n\", \"\\n\").replace(\"\\r\", \"\\n\")\n        return self.tokenizer.encode(s)\n\n    def generate(\n        self,\n        inputs: List[List[int]],\n        min_tokens: List[int],\n        max_tokens: List[int],\n        temperature: float,\n        top_p: float,\n        n: int,\n        echo: bool,\n        best_of: int,\n    ):\n        \"\"\"\n        Generation API.\n        Parameters match those of the OpenAI API.\n        
https://beta.openai.com/docs/api-reference/completions/create\n\n        Args:\n          inputs: a list of tokenized inputs.\n          min_tokens: The minimum number of tokens to generate.\n          max_tokens: The maximum number of tokens to generate.\n          temperature: What sampling temperature to use.\n          top_p: The nucleus sampling probability.\n          n: How many completions to generate for each prompt.\n          echo: if true, returned text/tokens/scores includes the prompt.\n          best_of: Generates best_of completions server-side and returns the \"best\" (the one with the highest log probability per token)\n        \"\"\"\n        start_time = time.time()\n        total_inference_time = 0\n        batch_id = next_serve_batch_uuid()\n        ori_bs = len(inputs)\n        self.logger.info(f\"Generate begin. batch id: {batch_id}, batch size: {ori_bs}\")\n\n        # Check arguments\n        assert best_of == self.num_beams, \"model must be instantiated and used with the same num_beams\"\n        assert n == self.num_return_sequences, \"model must be instantiated and used with the same num_return_sequences\"\n        if temperature <= 1e-3:\n            do_sample = False\n        else:\n            do_sample = self.do_sample\n        # Resolve the max sequence length allowed from multiple sources\n        max_seq_len = min(self.max_seq_len,\n                          self.model_wrapper.transformer_config.seq_len)\n\n        # Pad the batch to a maximum batch size\n        input_ids = pad_batch(inputs, self.tokenizer.pad_token_id, self.max_batch_size)\n        input_ids = torch.IntTensor(input_ids).to(self.torch_device)\n        input_lens = [len(x) for x in inputs]\n        batch_size = len(input_ids)\n\n        # Set generation args\n        if min_tokens is None:\n            min_tokens = [0] * batchsize\n        if max_tokens is None:\n            max_tokens = [max_seq_len] * batchsize\n        min_length = max(min_tokens) + 
max(input_lens)\n        max_length = min(max_seq_len, max(max_tokens) + max(input_lens))\n\n        generator_args = {\n            \"min_length\": min_length,\n            \"max_length\": max_length,\n            \"temperature\": temperature,\n            \"do_sample\": do_sample,\n            \"top_p\": top_p,\n            \"num_beams\": best_of,\n            \"num_return_sequences\": n,\n            \"early_stopping\": True,\n            \"repetition_penalty\": 1.0,\n            \"no_repeat_ngram_size\": 8,\n        }\n\n        self.logger.info(\n            f\"Call generate. batch id: {batch_id}, \"\n            f\"padded bs: {batch_size}, original bs: {ori_bs}, \"\n            f\"generator_args: {generator_args}.\")\n\n        inference_start_time = time.time()\n        output_ids = self.model_wrapper.generate(input_ids=input_ids, **generator_args)\n        inference_time = time.time() - inference_start_time\n        output_ids = torch.reshape(output_ids, (batch_size, self.num_return_sequences, -1))\n\n        tflops, speed, token_32_latency = self.estimate_performance(\n            output_ids, inference_time)\n\n        # Decode results to strings\n        ret = []\n        for i in range(ori_bs):\n            tmp_ret = []\n            for tokens in output_ids[i]:\n                prompt_len = input_lens[i]\n                if echo:\n                    tokens = tokens[:prompt_len + max_tokens[i]]\n                else:\n                    tokens = tokens[prompt_len:prompt_len + max_tokens[i]]\n                text = self.tokenizer.decode(tokens, skip_special_tokens=True)\n                result = {\"text\": text}\n                tmp_ret.append(result)\n            ret.append(tmp_ret)\n\n        self.logger.info(\n            f\"Generate end. batch id: {batch_id}. 
batch size: {ori_bs}, \"\n            f\"e2e latency: {time.time() - start_time:.2f} s, \"\n            f\"inference latency: {inference_time:.2f} s, \"\n            f\"speed: {speed:.2f} token/s, \"\n            f\"32 token latency: {token_32_latency:.2f} s, \"\n            f\"tflops: {tflops:.2f} TFLOPS\")\n        return ret\n\n    def forward(\n        self,\n        inputs,\n        cache_id,\n        pasts=None,\n    ):\n        self.logger.info(f\"Forward begin. cache_id: {cache_id}\")\n        time_start = time.time()\n\n        inputs = pad_batch(inputs, self.tokenizer.pad_token_id, self.max_batch_size)\n        input_ids = torch.IntTensor(inputs).to(self.torch_device)\n\n        attention_mask = self.model_wrapper._prepare_attention_mask_for_generation(input_ids, pad_token_id=self.model_wrapper.config.pad_token_id, eos_token_id=self.model_wrapper.config.eos_token_id)\n        model_inputs = self.model_wrapper.prepare_inputs_for_generation(input_ids, past=pasts[cache_id][1] if pasts is not None else None, attention_mask=attention_mask)\n        output = self.model_wrapper(**model_inputs)\n\n        self.logger.info(f\"Forward end. 
e2e latency: {time.time() - time_start:.2f}\")\n        return output\n\n    def estimate_performance(self, output_ids, latency):\n        \"\"\"Report the tflops, decoding speed, and latency for decoding 32 tokens.\"\"\"\n        # TODO(Hao): (1) we are still over-computing\n        transformer_config = self.model_wrapper.transformer_config\n\n        batch_size = self.num_beams * len(output_ids)\n        gen_len = max(t[0].shape[0] for t in output_ids)\n        seq_len = transformer_config.seq_len\n        H = transformer_config.H\n        L = transformer_config.L\n        vocab_size = transformer_config.vocab_size\n        tflops = compute_gpt_tflops_inference_with_padding(\n            batch_size, gen_len, seq_len, L, H, vocab_size, self.num_gpus,\n            latency)\n        speed = batch_size * gen_len / latency\n        token_32_latency = 32.0 / (speed / len(output_ids))\n        return tflops, speed, token_32_latency\n\n\ndef pad_batch(inputs, pad_value, max_batch_size):\n    \"\"\"Pad the batch to max_batch_size.\"\"\"\n    new_inputs = inputs\n    src_lens = [len(input) for input in inputs]\n    max_len = max(src_lens)\n    bs = len(inputs)\n\n    # Pad to max_len\n    for new_input in new_inputs:\n        ori_len = len(new_input)\n        if len(new_input) < max_len:\n            new_input.extend([pad_value for _ in range(max_len - ori_len)])\n\n    # Pad to max_batch_size\n    if bs < max_batch_size:\n        new_inputs.extend([[pad_value for _ in range(max_len)] for _ in range(max_batch_size - bs)])\n    return new_inputs\n\n\nserve_batch_counter = 0\n\ndef next_serve_batch_uuid(number=1):\n    \"\"\"Return the next uuid of a serve batch.\"\"\"\n    global serve_batch_counter\n    if number == 1:\n        ret = serve_batch_counter\n    else:\n        ret = np.arange(serve_batch_counter, serve_batch_counter + number)\n    serve_batch_counter = (serve_batch_counter + number) % (1 << 60)\n    return ret\n"
  },
  {
    "path": "examples/llm_serving/launch_model_worker.py",
    "content": "import asyncio\nimport argparse\nfrom collections import deque, defaultdict, namedtuple\nfrom dataclasses import dataclass, field\nimport json\nimport time\nfrom typing import Any\nimport uuid\n\nimport alpa\nfrom alpa.serve import run_controller, CONTROLLER_NAME\nimport ray\nimport torch\n\nfrom llm_serving.generator import Generator\nfrom llm_serving.service.constants import (\n    NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_PORT, USE_RECAPTCHA, USE_API_KEYS,\n    ALLOW_NON_KEY_ACCESS, KEYS_FILENAME, AuthGroups, AUTH_GROUP_WEIGHTS,\n    AUTH_GROUP_SCHEDULER_SCALE, API_KEY_SCHEDULER_SCALE,\n    API_KEY_DEFAULT_WEIGHT, LOGPROBS_PRIORITY_TIME_LIMIT_S)\nfrom llm_serving.service.recaptcha import load_recaptcha\nfrom llm_serving.service.scheduler import (\n    WeightedRoundRobin, NestedScheduler, FrontQueueScheduler, AsyncWrapper)\nfrom llm_serving.service.utils import build_logger\n\n\nGenerateItem = namedtuple(\"GenerateItem\", [\"uid\", \"return_queue\", \"data\"])\nLogprobsItem = namedtuple(\"LogprobsItem\", [\"uid\", \"return_queue\", \"data\"])\n\n\nclass LangaugeModelWorker:\n    def __init__(self,\n                 model_name: str,\n                 path: str,\n                 torch_device: str,\n                 tokenizer_name: str,\n                 num_beams: int,\n                 num_return_sequences: int,\n                 use_recaptcha: bool,\n                 use_api_keys: bool,\n                 allow_non_key_access: bool,\n                 max_seq_len: int = 1024,\n                 max_batch_size: int = 4,\n                 logprobs_past_cache_size_limit: int = 4,\n                 batch_wait_size_mult: int = 10,\n                 batch_timeout: float = 1.0,\n                 queue_timeout: float = 0.001):\n\n        self.logger = build_logger()\n        self.num_beams = num_beams\n        self.num_return_sequences = num_return_sequences\n        self.max_seq_len = max_seq_len\n\n        # Batch queues\n        self.max_bs = 
max_batch_size\n        self.batch_wait_size_mult = batch_wait_size_mult\n        self.batch_timeout = batch_timeout\n        self.queue_timeout = queue_timeout\n        self.logprobs_past_cache = defaultdict(lambda: (0, None, (), 0))\n        self.logprobs_past_cache_size_limit = logprobs_past_cache_size_limit\n        asyncio.get_event_loop().create_task(self.batch_loop())\n\n        # Load model\n        if num_beams > 1: # beam search is on, disable sampling\n            do_sample = False\n        else:\n            do_sample = True\n\n        self.generator = Generator(model_name,\n                                   path,\n                                   torch_device=torch_device,\n                                   tokenizer_name=tokenizer_name,\n                                   num_beams=num_beams,\n                                   num_return_sequences=num_return_sequences,\n                                   max_seq_len=self.max_seq_len,\n                                   max_batch_size=self.max_bs,\n                                   do_sample=do_sample)\n\n        # Authentication\n        self.allowed_api_keys = []\n        self.recaptcha = load_recaptcha(use_recaptcha)\n        self.allow_non_key_access = allow_non_key_access\n        api_key_weights = {}\n        if use_api_keys:\n            keys = json.load(open(KEYS_FILENAME, \"r\"))\n            self.allowed_api_keys = keys[\"allowed_api_keys\"]\n            if \"api_key_weights\" in keys:\n                api_key_weights = keys[\"api_key_weights\"]\n\n        # Scheduling\n        # Each authentication choice is assigned a separate queue, and\n        # these queues are given fixed weights independent of how many\n        # requests are within each group. 
Requests that use API keys are\n        # further organized based on the API key weights.\n        inner_schedulers = {}\n        for auth_group in AuthGroups:\n            if auth_group == AuthGroups.API_KEY_USER:\n                inner_schedulers[auth_group] = WeightedRoundRobin(\n                    api_key_weights,\n                    API_KEY_SCHEDULER_SCALE,\n                    API_KEY_DEFAULT_WEIGHT)\n            else:\n                inner_schedulers[auth_group] = deque()\n        self.request_queue = NestedScheduler(\n            WeightedRoundRobin(\n                AUTH_GROUP_WEIGHTS, AUTH_GROUP_SCHEDULER_SCALE, None),\n            inner_schedulers)\n        # To support batching completion requests without shuffling the order\n        # of logprob requests, we return the temporarily unqueued logprob\n        # requests to the front of the queue.\n        self.request_queue = AsyncWrapper(FrontQueueScheduler(\n            self.request_queue))\n\n    async def batch_loop(self):\n        while True:\n            item = (await self.request_queue.get())[1][1]\n\n            # Get the next batch\n            generate_batch = []\n            logprobs_item = None\n            non_batch = []\n            if isinstance(item, GenerateItem):\n                batch_wait_size = self.batch_wait_size_mult * self.max_bs\n                if self.request_queue.qsize() < batch_wait_size:\n                    # Wait for batch opportunity\n                    await asyncio.sleep(self.batch_timeout)\n                else:\n                    # Yield control until new requests are queued\n                    await asyncio.sleep(self.queue_timeout)\n                generate_batch.append(item)\n\n                while (not self.request_queue.empty() and\n                       len(generate_batch) < self.max_bs):\n                    queue_entry = self.request_queue.get_nowait()\n                    item = queue_entry[1][1]\n                    if isinstance(item, 
GenerateItem):\n                        generate_batch.append(item)\n                    else:\n                        non_batch.append(queue_entry)\n                        break\n\n                # Return non-batch items to the front of the request queue\n                while len(non_batch) > 0:\n                    self.request_queue.put_nowait_special(\n                        lambda scheduler, arg: scheduler.appendleft(arg),\n                        non_batch.pop())\n            elif isinstance(item, LogprobsItem):\n                logprobs_item = item\n            else:\n                raise RuntimeError(f\"Invalid item: {item}\")\n\n            # Process this batch\n            if generate_batch:\n                args = {\n                    \"inputs\": [],\n                    \"min_tokens\": [],\n                    \"max_tokens\": [],\n                }\n                for item in generate_batch:\n                    args[\"inputs\"].append(item.data[\"input\"])\n                    args[\"min_tokens\"].append(item.data[\"min_tokens\"])\n                    args[\"max_tokens\"].append(item.data[\"max_tokens\"])\n                    # FIXME: Now we assume all items have the same remaining args\n                    for key in [\n                        \"temperature\", \"top_p\", \"n\", \"best_of\", \"echo\",\n                    ]:\n                        args[key] = item.data[key]\n                results = self.generator.generate(**args)\n                for item, res in zip(generate_batch, results):\n                    item.return_queue.put_nowait((item.uid, res))\n\n            elif logprobs_item:\n                logprobs_past_cache = self.logprobs_past_cache\n                arg = logprobs_item.data\n                inputs = arg[\"input\"]\n                inputs_copy = tuple(tuple(s) for s in inputs)\n                num_inputs = len(inputs)\n                cache_id = arg[\"cache_id\"]\n                first_entry_time = None\n              
  if cache_id in self.logprobs_past_cache:\n                    prev_inputs = logprobs_past_cache[cache_id][2]\n                    try:\n                        assert len(prev_inputs) == num_inputs\n                        assert all(pl == cl[:-1] for (pl, cl) in\n                                   zip(prev_inputs, inputs_copy))\n                    except AssertionError:\n                        logprobs_item.return_queue.put_nowait(\n                            ValueError(\"Request does not extend cached request \"\n                                       \"by one token; you are probably using \"\n                                       \"the logprobs endpoint incorrectly.\"))\n                        del logprobs_past_cache[cache_id]\n                        continue\n                    first_entry_time = logprobs_past_cache[cache_id][3]\n                # do the actual generations\n                output = self.generator.forward(inputs, cache_id, pasts=logprobs_past_cache)\n                # add to or update the cache with newly computed values\n                curr_time = time.time()\n                if first_entry_time is None:\n                    first_entry_time = curr_time\n                logprobs_past_cache[cache_id] = (\n                    curr_time, output.past_key_values, inputs_copy, first_entry_time)\n                # delete oldest key in cache if cache too big\n                while len(logprobs_past_cache) > self.logprobs_past_cache_size_limit:\n                    oldest_key = min(list(logprobs_past_cache.keys()), key=lambda k: logprobs_past_cache[k][0])\n                    del logprobs_past_cache[oldest_key]\n\n                logits = output.logits[:num_inputs, -1]\n                logprobs = torch.log_softmax(logits, dim=-1)\n                top_k = min(arg[\"top_k\"], logprobs.shape[1])\n                top_logprobs, top_indices = logprobs.topk(top_k, dim=1)\n\n                # return at most top_k tokens, e.g. 
if network limited\n                return_dict = {\n                    'logprobs': top_logprobs.cpu().tolist(),\n                    'indices': top_indices.cpu().tolist()\n                }\n                # broadcast them back\n                logprobs_item.return_queue.put_nowait((logprobs_item.uid, return_dict))\n\n    async def handle_request(self, request):\n        args = await request.json()\n        authorization = self.get_authorization(args, request)\n\n        if \"completions\" in request.url.path:\n            return await self.completions(args, request, authorization)\n        elif \"logprobs\" in request.url.path:\n            return await self.logprobs(args, request, authorization)\n        else:\n            raise ValueError(\"Invalid url: {request.url}\")\n\n    def normalize_prompts(self, prompts):\n        # prompt can be 4 types:\n        # - case 1: str. Basic case. Return one generation.\n        # - case 2: List[str]. Multiple generations, one per prompt.\n        # - case 3: List[int]. Pretokenized. Return one generation.\n        # - case 4: List[List[int]]. 
Pretokenized multiple generations.\n        # our approach is to turn everything into the case 4\n        try:\n            if isinstance(prompts, str):  # case 1\n                prompts = [self.generator.encode(prompts)]\n            elif isinstance(prompts, list) and isinstance(prompts[0], str):\n                assert all(isinstance(v, str) for v in prompts)\n                prompts = [self.generator.encode(p) for p in prompts]\n            elif isinstance(prompts, list) and isinstance(prompts[0], int):\n                prompts = [prompts]\n            assert isinstance(prompts, list)\n            for sublist in prompts:\n                assert isinstance(sublist, list)\n                assert all(isinstance(v, int) for v in sublist)\n                assert all(v + (1 << 63) < (1 << 64) for v in sublist)\n        except AssertionError:\n            raise ValueError(\n                \"The prompt must be either a string, a list of strings, a \"\n                \"list of integers, or a list of integer lists.\")\n        if len(prompts[0]) <= 0 or \\\n                any(len(sublist) <= 0 for sublist in prompts):\n            raise ValueError(\"The prompt must be nonempty.\")\n        return prompts\n\n    async def completions(self, args, request, authorization):\n        logger = self.logger\n\n        # Normalize prompts\n        prompts = args[\"prompt\"]\n        prompts = self.normalize_prompts(prompts)\n\n        # Generation arguments\n        args[\"min_tokens\"] = int(args.get(\"min_tokens\", 0))\n        args[\"max_tokens\"] = int(args.get(\"max_tokens\", self.max_seq_len))\n\n        if self.num_beams > 1:\n            # if beam search is enabled, disable all sampling\n            args[\"temperature\"] = 0.0\n            args[\"top_p\"] = 0.0\n        else:\n            args[\"temperature\"] = round(float(args.get(\"temperature\", 1.0)), 1)\n            args[\"top_p\"] = round(float(args.get(\"top_p\", 1.0)), 1)\n\n        assert 0 <= args[\"top_p\"] 
<= 1\n        assert 0 <= args[\"temperature\"]\n\n        args[\"n\"] = int(args.get(\"n\", self.num_return_sequences))\n        args[\"echo\"] = bool(args.get(\"echo\", False))\n        args[\"best_of\"] = self.num_beams\n\n        if \"stop\" in args:\n            raise NotImplementedError(\"The stop argument is not implemented\")\n\n        logger.info(f\"Received new generate request: \"\n                    f\"prompt length {[len(p) for p in prompts]}, \"\n                    f\"max_len: {args.get('max_tokens', 0)}, \"\n                    f\"temperature: {args['temperature']}, \"\n                    f\"top_p: {args['top_p']}, \"\n                    f\"api_key: {args.get('api_key', None)}, \"\n                    f\"ip: {self.get_remote_ip(request)}, \"\n                    f\"tstamp: {request.scope['tstamp']}\")\n\n        cur_len = max(len(p) for p in prompts)\n        self.check_max_length_limit(cur_len, self.max_seq_len)\n\n        # Push the requests to the batch queue\n        return_queue = asyncio.Queue()\n        for i, prompt in enumerate(prompts):\n            data = {\"input\": prompt, **args}\n            queue_entry = GenerateItem(i, return_queue, data)\n            auth_group, api_key = authorization\n            queue_entry = (auth_group, (api_key, queue_entry))\n            self.request_queue.put_nowait(queue_entry)\n\n        unordered_results = []\n        for i in range(len(prompts)):\n            unordered_results.append(await return_queue.get())\n\n        # Sort results by the original ordering\n        reordered = sorted(unordered_results, key=lambda x: x[0])\n        results = []\n        for _, generations in reordered:\n            results += generations\n\n        # Transform the results into the openai format\n        return {\n            \"id\": str(uuid.uuid4()),\n            \"object\": \"text_completion\",\n            \"created\": int(time.time()),\n            \"choices\": [\n                {\n                    
\"text\": result[\"text\"],\n                    # TODO: align with what OpenAI returns\n                } for result in results\n            ],\n        }\n\n    async def logprobs(self, args, request, authorization):\n        logger = self.logger\n\n        # Normalize prompts\n        prompts = args[\"prompt\"]\n        prompts = self.normalize_prompts(prompts)\n\n        # we're going to cache the keys for all the prompts in the request all together, so limit batch size\n        assert len(prompts) <= self.max_bs, \"Please submit a smaller batch\"\n        prompt_length = len(prompts[0])\n        for prompt in prompts:\n            assert len(prompt) == prompt_length, \"All prompts must be the same length to work with current caching implementation\"\n\n        # Generation arguments\n        args[\"min_tokens\"] = int(args.get(\"min_tokens\", 0))\n        args[\"max_tokens\"] = int(args.get(\"max_tokens\", self.max_seq_len))\n\n        args[\"top_k\"] = int(args.get(\"top_k\", 100000))\n\n        args['top_p'] = -1\n        args[\"temperature\"] = -1\n        args[\"n\"] = int(args.get(\"n\", self.num_return_sequences))\n\n        logger.info(f\"Received new logprobs request: \"\n                    f\"prompt length {[len(p) for p in prompts]}, \"\n                    f\"top_k: {args['top_k']}, \"\n                    f\"api_key: {args.get('api_key', None)}, \"\n                    f\"ip: {self.get_remote_ip(request)}, \"\n                    f\"tstamp: {request.scope['tstamp']}\")\n\n        cur_len = max(len(p) for p in prompts)\n        self.check_max_length_limit(cur_len, self.max_seq_len)\n\n        # Push the request to the batch queue\n        cache_id = str(args[\"cache_id\"]) if \"cache_id\" in args else str(uuid.uuid4())\n        try:\n            uuid.UUID(cache_id)\n        except ValueError:\n            raise ValueError(\"Malformed \\\"cache_id\\\", you must use the \"\n                             \"the value returned in a prior server 
response\")\n        ret_queue = asyncio.Queue()\n        data = {\"input\": prompts, \"cache_id\": cache_id, **args}\n        queue_entry = LogprobsItem(0, ret_queue, data)\n        auth_group, api_key = authorization\n        queue_entry = (auth_group, (api_key, queue_entry))\n        earliest_allowed = time.time() - LOGPROBS_PRIORITY_TIME_LIMIT_S\n        if cache_id in self.logprobs_past_cache and \\\n                self.logprobs_past_cache[cache_id][3] >= earliest_allowed:\n            self.request_queue.put_nowait_special(\n                lambda scheduler, arg: scheduler.appendleft(arg), queue_entry)\n        else:\n            self.request_queue.put_nowait(queue_entry)\n        results = await ret_queue.get()\n        if isinstance(results, Exception):\n            raise results\n        return {\n            \"cache_id\": cache_id,\n            \"logprobs\": results[1]['logprobs'],\n            \"indices\": results[1]['indices']\n        }\n\n    def check_max_length_limit(self, cur_len, max_len):\n        if cur_len > max_len:\n            self.logger.info(f\"Rejected a request with max prompt length = {cur_len}.\")\n            raise ValueError(f\"Your prompt length  = {cur_len} is too long. \"\n                             f\"Please make sure len(prompt) + response length <= {max_len}. \"\n                             f\"Since this is a public service, we have limited the max length supported. 
\"\n                             f\"If you want to try longer sequence length, \"\n                             f\"please consider hosting your own service using Alpa.\")\n\n    def get_authorization(self, args, request):\n        api_key = args.get(\"api_key\", None)\n        if api_key in self.allowed_api_keys:\n            return (AuthGroups.API_KEY_USER, api_key)\n        elif api_key is not None:\n            self.logger.error(f\"Rejected a request with an incorrect key.\")\n            raise ValueError(\"API key is incorrect, please verify that you \"\n                             \"have passed the right value (as opposed to, \"\n                             \"say, an OpenAI API key).\")\n\n        recaptcha_response = str(args.get(\"g-recaptcha-response\", \"\"))\n        if recaptcha_response == \"\":\n            if self.allow_non_key_access:\n                return (AuthGroups.NON_KEY_USER, None)\n            else:\n                self.logger.error(f\"Rejected a request with no API key.\")\n                raise ValueError(\"No captcha data found. If you are using \"\n                                 \"client APIs, please contact alpa developers \"\n                                 \"to get an API key.\")\n\n        if not self.recaptcha.verify(recaptcha_response, request.client.host):\n            self.logger.error(f\"Rejected a request with invalid captcha.\")\n            raise ValueError(\"Invalid captcha. 
If you are using the website, please click the \"\n                             \"\\\"I'm not a robot\\\" button.\")\n        return (AuthGroups.RECAPTCHA_USER, None)\n\n    def get_remote_ip(self, request):\n        for x in request.scope['headers']:\n            if x[0] == b\"x-forwarded-for\":\n                v = x[1].decode()\n                v = v.split(\",\")[0] # Obtain the client IP\n                if \":\" in v:\n                    # Drop the port number\n                    return v[:v.index(\":\")]\n                return v\n        return request.client.host\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"alpa/opt-125m\")\n    parser.add_argument(\"--path\", type=str, default=\"~/opt_weights/\")\n    parser.add_argument(\"--host\", type=str, default=\"0.0.0.0\")\n    parser.add_argument(\"--torch-device\", type=str, default=\"cpu\")\n    parser.add_argument(\"--tokenizer\", type=str)\n    parser.add_argument(\"--no-recaptcha\", action=\"store_true\")\n    parser.add_argument(\"--no-api-keys\", action=\"store_true\")\n    parser.add_argument(\"--block-non-key-access\", action=\"store_true\")\n    parser.add_argument(\"--register-name\", type=str, default=\"default\")\n    parser.add_argument(\"--ssl-keyfile\", type=str)\n    parser.add_argument(\"--ssl-certfile\", type=str)\n    args = parser.parse_args()\n\n    ray.init(address=\"auto\", namespace=\"alpa_serve\")\n\n    try:\n        controller = ray.get_actor(CONTROLLER_NAME)\n    except ValueError:\n        controller = run_controller(args.host, ALPA_SERVE_PORT, \"/\",\n                                    ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile)\n\n    group_id = 0\n    controller.launch_mesh_group_manager.remote(group_id)\n    t = controller.register_model.remote(\n        args.register_name, LangaugeModelWorker,\n        (args.model, args.path, args.torch_device, args.tokenizer, NUM_BEAMS, 
NUM_RETURN_SEQ,\n         not args.no_recaptcha and USE_RECAPTCHA,\n         not args.no_api_keys and USE_API_KEYS,\n         not args.block_non_key_access and ALLOW_NON_KEY_ACCESS),\n        override=True)\n    ray.get(t)\n    t = controller.create_replica.remote(args.register_name, group_id)\n    ray.get(t)\n\n    while True:\n        pass\n"
  },
  {
    "path": "examples/llm_serving/launch_website.py",
    "content": "import json\nimport logging\nfrom typing import Union\n\nfrom fastapi import FastAPI, Request\nfrom fastapi.staticfiles import StaticFiles\nfrom fastapi.templating import Jinja2Templates\n\nfrom llm_serving.service.constants import (\n    NUM_BEAMS, NUM_RETURN_SEQ, ALPA_SERVE_URL, USE_RECAPTCHA)\nfrom llm_serving.service.recaptcha import load_recaptcha\n\napp = FastAPI()\n\napp.mount(\"/static\", StaticFiles(directory=\"service/static\"), name=\"static\")\ntemplates = Jinja2Templates(directory=\"service/static\")\n\nif NUM_BEAMS > 1: # beam search is on, disable sampling\n    sampling_css = \"display:none\"\nelse:\n    sampling_css = \"\"\n\nrecaptcha = load_recaptcha(USE_RECAPTCHA)\n\n\ndef log_scope(request):\n    scope = request.scope\n    del scope[\"app\"]\n    del scope[\"fastapi_astack\"]\n    del scope[\"router\"]\n    del scope[\"endpoint\"]\n    del scope[\"route\"]\n    scope[\"tstamp\"] = time.time()\n    logging.info(scope)\n    return scope\n\n\n##### Redirect Begin #####\nimport asyncio\nimport pickle\nimport time\n\nfrom alpa.serve.http_util import HTTPRequestWrapper, make_error_response, RelayException\nimport ray\nfrom starlette.responses import JSONResponse\nray.init(address=\"auto\", namespace=\"alpa_serve\")\n\nmanager = None\n\nasync def connect_manager():\n    global manager\n    while True:\n        if manager is None:\n            try:\n                manager = ray.get_actor(\"mesh_group_manager_0\")\n            except ValueError:\n                manager = None\n        await asyncio.sleep(1)\n\nasyncio.get_event_loop().create_task(connect_manager())\n\nasync def redirect(request):\n    global manager\n\n    body = await request.body()\n    scope = log_scope(request)\n    request = pickle.dumps(HTTPRequestWrapper(scope, body))\n    try:\n        ret = await manager.handle_request.remote(\"default\", request)\n    except ray.exceptions.RayActorError:\n        manager = None\n    if isinstance(ret, RelayException):\n        
ret = make_error_response(ret)\n        ret = JSONResponse(ret, status_code=400)\n    return ret\n\n\n@app.post(\"/completions\")\nasync def completions(request: Request):\n    return await redirect(request)\n\n\n@app.post(\"/logprobs\")\nasync def logprobs(request: Request):\n    return await redirect(request)\n\n\n@app.post(\"/call\")\nasync def logprobs(request: Request):\n    return await redirect(request)\n\n##### Redirect End #####\n\n@app.get(\"/\")\nasync def homepage(request: Request):\n    for x in request.scope['headers']:\n        if x[0] == b\"user-agent\" and b\"UptimeRobot\" not in x[1]:\n            log_scope(request)\n            break\n    return templates.TemplateResponse(\"index.html\", {\n        \"request\": request,\n        \"num_return_sequences\": NUM_RETURN_SEQ,\n        \"sampling_css\": sampling_css,\n        \"recaptcha\": recaptcha.get_code(),\n        \"alpa_serve_url\": ALPA_SERVE_URL,\n    })\n"
  },
  {
    "path": "examples/llm_serving/log_config.yaml",
    "content": "version: 1\nformatters:\n  simple:\n    format: \"%(asctime)s | %(levelname)s | %(name)s | %(message)s\"\n    datefmt: \"%Y-%m-%d %H:%M:%S\"\nhandlers:\n  console:\n    class : logging.StreamHandler\n    formatter: simple\n    level   : INFO\n    stream  : ext://sys.stdout\n  file:\n    class : logging.handlers.TimedRotatingFileHandler\n    filename: weblogs/llm_serving.website.log\n    when: \"D\"\n    utc: True\n    formatter: simple\n    level   : INFO\nroot:\n  level: INFO\n  handlers: [console, file]\n"
  },
  {
    "path": "examples/llm_serving/model/__init__.py",
    "content": ""
  },
  {
    "path": "examples/llm_serving/model/bloom_model.py",
    "content": "\"\"\"BLOOM model implementation.\n\nSome code is adapted from\nhttps://github.com/huggingface/bloom-jax-inference/blob/main/bloom_inference/modeling_bloom/modeling_bloom.py\n\"\"\"\nimport dataclasses\nfrom dataclasses import dataclass\nimport itertools\nfrom functools import partial\nimport math\nimport os\nfrom typing import Optional, Tuple, Sequence\n\nimport alpa\nfrom alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,\n                              MeshHostWorker, create_remote_array_refs)\nfrom alpa.model.model_util import ModelOutput\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nimport flax\nimport flax.linen as nn\nfrom flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask\nfrom flax.linen.activation import tanh\nimport jax\nfrom jax import lax\nfrom jax.interpreters import pxla\nimport jax.numpy as jnp\nfrom jax.tree_util import tree_flatten, tree_leaves\nimport jaxlib.xla_extension as jax_xla\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom llm_serving.model.opt_model import (init_cache_aval, init_mask_aval,\n    init_cache_np, init_cache_dis_array, init_multi_executable_cache_dis_array)\n\n\n@dataclass(frozen=True)\nclass BloomConfig:\n    model_type: str = \"bloom\"\n    vocab_size: int = 250880\n    max_seq_len: int = 2048\n    hidden_size: int = 64\n    n_head: int = 8\n    num_hidden_layers: int = 2\n    layer_norm_epsilon: float = 1e-5\n    initializer_range: float = 0.02\n    use_cache: bool = False\n    eos_token_id: int = 2\n    pad_token_id: int = 3\n    unk_token_id: int = 0\n    apply_residual_connection_post_layernorm: bool = False\n    hidden_dropout: float = 0.0\n    attention_dropout: float = 0.0\n    pretraining_tp: int = 1  # TP rank used when training with megatron\n    slow_but_exact: bool = False\n    tie_word_embeddings: bool = True\n    dtype: any = jnp.float16\n    pad: int = 1\n    # For parallel\n    mark_boundary: bool = True\n    
num_pp_stages: int = None\n\n\n@flax.struct.dataclass\nclass BloomModelOutput(ModelOutput):\n    last_hidden_state: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\n@flax.struct.dataclass\nclass BloomLMOutput(ModelOutput):\n    logits: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\ndef build_alibi_tensor_flax(attention_mask, n_head, dtype):\n    def get_slopes(n):\n        def get_slopes_power_of_2(n):\n            start = 2 ** (-(2 ** -(math.log2(n) - 3)))\n            ratio = start\n            return [start * ratio**i for i in range(n)]\n\n        if math.log2(n).is_integer():\n            return get_slopes_power_of_2(n)\n        else:\n            closest_power_of_2 = 2 ** math.floor(math.log2(n))\n            return (\n                get_slopes_power_of_2(closest_power_of_2)\n                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]\n            )\n\n    # Note: alibi will be added to the attention bias that is applied to the query, key product of attention\n    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)\n    # => here we set (batch_size=1, num_heads=n_head, query_length=1, key_length=max_length)\n    # => the query_length dimension will then be broadcast correctly\n    # This is more or less identical to T5's relative position bias:\n    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_flax_t5.py#L426\n    # batch_size = 1, n_head = n_head, query_length\n    # shape of attention_mask: [B, 1, 1, S_max]\n    batch_size = attention_mask.shape[0]\n    key_length 
= attention_mask.shape[-1]\n\n    # Handle a special kind of internal padding added by alpa.\n    # Where internal padding of 2 is used for encoder chunck size that can't divide input length.\n    attention_mask = (attention_mask == 1)\n\n    attention_mask = attention_mask.reshape((batch_size, key_length))\n    num_heads = n_head\n    query_length = 1\n\n    slopes = jnp.array(get_slopes(n_head))[None, :, None, None].astype(dtype)\n    arange_tensor = attention_mask.cumsum(-1, dtype=dtype)[:, None, None, :] - 1\n\n    slopes_broadcast = jnp.broadcast_to(slopes, (batch_size, num_heads, query_length, key_length))\n    arange_broadcast = jnp.broadcast_to(arange_tensor, (batch_size, num_heads, query_length, key_length))\n\n    alibi = slopes_broadcast * arange_broadcast\n\n    return alibi\n\n\nclass FlaxBloomAttention(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.hidden_size = self.config.hidden_size\n        self.num_heads = self.config.n_head\n        self.head_dim = self.hidden_size // self.num_heads\n\n        if self.head_dim * self.num_heads != self.hidden_size:\n            raise ValueError(\n                f\"`hidden_size` must be divisible by `num_heads` (got `hidden_size`: {self.hidden_size} and \"\n                f\"`num_heads`: {self.num_heads}).\"\n            )\n\n        dense = partial(\n            nn.Dense,\n            dtype=self.dtype,\n            kernel_init=jax.nn.initializers.normal(\n                self.config.initializer_range)\n        )\n\n        self.query_key_value = dense(self.hidden_size * 3)\n        self.dense = dense(self.hidden_size)\n        # Mismatch happens here, the self.dense is different from that of HF's\n        self.resid_dropout = nn.Dropout(\n            rate=self.config.hidden_dropout)\n\n    def __call__(\n        self,\n        hidden_states,\n        residual,\n        alibi,\n        attention_mask=None,\n        attention_cache=None,\n        
deterministic: bool = True,\n        output_attentions: bool = False\n    ):\n        # This chunk verified to be working\n        batch_size = hidden_states.shape[0]\n        seq_length = hidden_states.shape[1]\n        fused_qkv = self.query_key_value(hidden_states)\n        fused_qkv = fused_qkv.reshape(fused_qkv.shape[:-1] + (self.num_heads, self.head_dim * 3))\n        query, key, value = jnp.split(fused_qkv, 3, axis=-1)\n        key_len = attention_mask.shape[-1]\n        causal_attention_mask = make_causal_mask(jnp.ones((batch_size, key_len)), dtype=\"bool\")\n\n        # for fast decoding causal attention mask should be shifted\n        if attention_cache:\n            causal_attention_mask_shift = attention_cache[2][0]\n        else:\n            causal_attention_mask_shift = 0\n\n        # fast decoding for generate requires special attention_mask\n        if attention_cache:\n            max_decoder_length = attention_cache[0].shape[1]\n            causal_attention_mask = jax.lax.dynamic_slice(\n                causal_attention_mask,\n                (0, 0, causal_attention_mask_shift, 0),\n                (1, 1, seq_length, max_decoder_length)\n            )\n            # Handle a special kind of internal padding added by alpa.\n            # Note that this kind of internal padding is different from\n            # the padding added by the tokenizer. 
This internal padding\n            # should not update cache and step_ct\n            # shape: [B, 1, 1, S_max]\n            is_internal_padding = (attention_mask == 2)\n            num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1)\n            attention_mask = (attention_mask == 1)\n\n        attention_mask = combine_masks(attention_mask, causal_attention_mask)\n\n        # During fast autoregressive decoding, we feed one position at a time,\n        # and cache the keys and values step by step.\n        if attention_cache:\n            cache_key, cache_value, cache_index = attention_cache\n            *batch_dims, max_length, num_heads, depth_per_head = cache_key.shape\n            # update key, value caches with our new 1d spatial slices\n            cur_index = cache_index[0]\n            indices = (0, cur_index, 0, 0)\n            key = lax.dynamic_update_slice(cache_key, key, indices)\n            value = lax.dynamic_update_slice(cache_value, value, indices)\n            cache_key = key\n            cache_value = value\n            num_updated_cache_vectors = query.shape[1]\n            # A line added from bloom_model\n            attention_cache = key, value, cache_index + num_updated_cache_vectors - num_internal_pad\n            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.\n            pad_mask = jnp.broadcast_to(\n                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,\n                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),\n            )\n            attention_mask = combine_masks(pad_mask, attention_mask)\n\n        dropout_rng = None\n        if not deterministic and self.config.attention_dropout > 0.0:\n            dropout_rng = self.make_rng(\"dropout\")\n\n        # transform boolean mask into float mask\n        mask_value = 
jnp.finfo(self.dtype).min\n        attention_bias = lax.select(\n            attention_mask > 0,\n            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),\n            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),\n        )\n\n        attention_bias = attention_bias + alibi\n\n        attn_weights = dot_product_attention_weights(\n            query,\n            key,\n            bias=attention_bias,\n            dropout_rng=dropout_rng,\n            dropout_rate=self.config.attention_dropout,\n            deterministic=deterministic,\n            dtype=self.dtype,\n            precision=None\n        )\n\n        attn_output = jnp.einsum(\"...hqk,...khd->...qhd\", attn_weights, value)\n        attn_output = attn_output.reshape(hidden_states.shape[:2] + (self.hidden_size,))\n        attn_output = self.dense(attn_output)\n        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)\n        attn_output = attn_output + residual\n\n        outputs = (attn_output, attention_cache,\n                   attn_weights) if output_attentions else (attn_output,\n                                                            attention_cache)\n        return outputs\n\n\nclass BloomGELU(nn.Module):\n    def setup(self):\n        pass\n\n    def __call__(self, x):\n        return x * 0.5 * (1.0 + tanh(0.79788456 * x * (1 + 0.044715 * x * x)))\n\n\nclass FlaxBloomMLP(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        hidden_size = self.config.hidden_size\n\n        self.pretraining_tp = self.config.pretraining_tp\n        self.slow_but_exact = self.config.slow_but_exact\n\n        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)\n\n        self.dense_h_to_4h = nn.Dense(4 * hidden_size, dtype=self.dtype, kernel_init=kernel_init)\n        self.dense_4h_to_h = nn.Dense(hidden_size, dtype=self.dtype, kernel_init=kernel_init)\n        self.hidden_dropout = 
nn.Dropout(self.config.hidden_dropout)\n        self.act = BloomGELU()\n\n    def __call__(self, hidden_states, residual, deterministic: bool = True):\n        hidden_states = self.dense_h_to_4h(hidden_states)\n        hidden_states = self.act(hidden_states)\n\n        intermediate_output = self.dense_4h_to_h(hidden_states)\n\n        hidden_states = self.hidden_dropout(intermediate_output, deterministic=deterministic)\n        hidden_states += residual\n\n        return hidden_states\n\n\nclass FlaxBloomBlock(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.input_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n        self.self_attention = FlaxBloomAttention(self.config, dtype=self.dtype)\n        self.post_attention_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n\n        self.mlp = FlaxBloomMLP(self.config, dtype=self.dtype)\n\n        self.apply_residual_connection_post_layernorm = self.config.apply_residual_connection_post_layernorm\n        self.hidden_dropout = self.config.hidden_dropout\n\n    def __call__(\n        self,\n        hidden_states,\n        alibi,\n        attention_mask=None,\n        attention_cache=None,\n        deterministic: bool = True,\n        output_attentions: bool = False\n    ):\n        layernorm_output = self.input_layernorm(hidden_states)\n        # layer norm before saving residual if config calls for it\n        if self.apply_residual_connection_post_layernorm:\n            residual = layernorm_output\n        else:\n            residual = hidden_states\n\n        # self-attention\n        attn_outputs = self.self_attention(\n            layernorm_output,\n            residual=residual,\n            alibi=alibi,\n            attention_mask=attention_mask,\n            attention_cache=attention_cache,\n            deterministic=deterministic,\n            output_attentions=output_attentions\n   
     )\n        attention_output = attn_outputs[0]\n        attention_cache = attn_outputs[1]\n\n        post_layernorm = self.post_attention_layernorm(attention_output)\n\n        # set residual based on config\n        if self.apply_residual_connection_post_layernorm:\n            residual = post_layernorm\n        else:\n            residual = attention_output\n\n        output = self.mlp(post_layernorm, residual, deterministic=deterministic)\n\n        outputs = (output, attention_cache)\n        if output_attentions:\n            outputs += (attn_outputs[2],)\n        return outputs\n\n\nclass FlaxBloomBlockCollection(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.layers = [\n            FlaxBloomBlock(self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(\n        self,\n        hidden_states,\n        alibi,\n        attention_mask=None,\n        attention_cache=None,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        new_attention_cache = () if attention_cache is not None else None\n\n        if self.config.num_pp_stages is not None:\n            assert self.config.num_hidden_layers % self.config.num_pp_stages == 0\n            layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages\n\n        for layer_number, layer in enumerate(self.layers):\n            if self.config.num_pp_stages is not None:\n                if layer_number % layers_per_stage == 0 and layer_number != 0:\n                    if self.config.mark_boundary:\n                        mark_pipeline_boundary()\n            if output_hidden_states:\n                all_hidden_states += 
(hidden_states,)\n            layer_attention_cache = None\n            if attention_cache is not None:\n                layer_attention_cache = attention_cache[layer_number]\n            layer_outputs = layer(\n                hidden_states,\n                alibi=alibi,\n                attention_mask=attention_mask,\n                attention_cache=layer_attention_cache,\n                deterministic=deterministic,\n                output_attentions=output_attentions\n            )\n            hidden_states = layer_outputs[0]\n\n            if attention_cache is not None:\n                new_attention_cache += (layer_outputs[1],)\n\n            if output_attentions:\n                all_attentions += (layer_outputs[2],)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return BloomModelOutput(last_hidden_state=hidden_states,\n                              hidden_states=all_hidden_states,\n                              attentions=all_attentions,\n                              attention_cache=new_attention_cache)\n\n\nclass FlaxBloomModule(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.embed_dim = self.config.hidden_size\n\n        embedding_init = jax.nn.initializers.normal(stddev=self.config.initializer_range)\n\n        # word embeddings (no positional embedding layer)\n        self.word_embeddings = nn.Embed(\n            self.config.vocab_size,\n            self.embed_dim,\n            embedding_init=embedding_init,\n            dtype=self.dtype\n        )\n\n        # post-embedding layernorm\n        self.word_embeddings_layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n\n        # transformer layers\n        self.h = FlaxBloomBlockCollection(self.config, dtype=self.dtype)\n\n        
# final layernorm\n        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids=None,\n        attention_mask=None,\n        attention_cache=None,\n        deterministic=True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True\n    ):\n        inputs_embeds = self.word_embeddings(input_ids)\n        # do post-embedding layernorm\n        hidden_states = self.word_embeddings_layernorm(inputs_embeds)\n\n        # build alibi depending on `attention_mask`\n        alibi = build_alibi_tensor_flax(attention_mask, self.config.n_head, hidden_states.dtype)\n\n        outputs = self.h(\n            hidden_states,\n            alibi=alibi,\n            attention_mask=attention_mask,\n            attention_cache=attention_cache,\n            deterministic=deterministic,\n            output_hidden_states=output_hidden_states,\n            output_attentions=output_attentions,\n            return_dict=return_dict\n        )\n\n        hidden_states = outputs[0]\n\n        hidden_states = self.ln_f(hidden_states)\n\n        if output_hidden_states:\n            all_hidden_states = outputs.hidden_states + (hidden_states,)\n            outputs = BloomModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache)\n        else:\n            outputs = BloomModelOutput(last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache)\n\n        if not return_dict:\n            return (hidden_states,) + outputs[1:]\n\n        return outputs\n\n\nclass FlaxBloomForCausalLMModule(nn.Module):\n    config: BloomConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.transformer = FlaxBloomModule(self.config, dtype=self.dtype)\n        self.lm_head = 
nn.Dense(\n            self.config.vocab_size,\n            use_bias=False,\n            dtype=jnp.float32,\n            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n        )\n\n    def __call__(\n        self,\n        input_ids,\n        attention_mask=None,\n        attention_cache=None,\n        deterministic: bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True\n    ):\n        outputs = self.transformer(\n            input_ids,\n            attention_mask=attention_mask,\n            attention_cache=attention_cache,\n            deterministic=deterministic,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict\n        )\n\n        hidden_states = outputs[0]\n\n        if self.config.tie_word_embeddings:\n            shared_kernel = self.transformer.variables[\"params\"][\"word_embeddings\"][\"embedding\"].T\n            lm_logits = self.lm_head.apply({\"params\": {\"kernel\": shared_kernel}}, hidden_states)\n        else:\n            lm_logits = self.lm_head(hidden_states)\n\n        if not return_dict:\n            return (lm_logits,) + outputs[1:]\n        return BloomLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache)\n\n\ndef get_config(name, **kwargs):\n    if name in [\"bloom-560m\", \"bloomz-560m\"]:\n        config = BloomConfig(\n            hidden_size=1024, n_head=16, num_hidden_layers=24,\n            pretraining_tp=1, use_cache=True\n        )\n    elif name in [\"bloom-1b1\", \"bloomz-1b1\"]:\n        config = BloomConfig(\n            hidden_size=1536, n_head=16, num_hidden_layers=24,\n            pretraining_tp=1, use_cache=True\n        )\n    elif name in [\"bloom-1b7\", \"bloomz-1b7\"]:\n        config = BloomConfig(\n            hidden_size=2048, 
n_head=16, num_hidden_layers=24,\n            pretraining_tp=2, use_cache=True\n        )\n    elif name in [\"bloom-3b\", \"bloomz-3b\"]:\n        config = BloomConfig(\n            hidden_size=2560, n_head=32, num_hidden_layers=30,\n            pretraining_tp=4, use_cache=True\n        )\n    elif name in [\"bloom-7b1\", \"bloomz-7b1\"]:\n        config = BloomConfig(\n            hidden_size=4096, n_head=32, num_hidden_layers=30,\n            pretraining_tp=4, use_cache=True\n        )\n    elif name in [\"bloom\", \"bloomz\"]:\n        config = BloomConfig(\n            hidden_size=14336, n_head=112, num_hidden_layers=70,\n            pretraining_tp=4, use_cache=True\n        )\n    elif name == \"bloom-debug\":\n        config = BloomConfig(\n            hidden_size=1024, n_head=16, num_hidden_layers=8,\n            pretraining_tp=4, use_cache=True\n        )\n    else:\n        raise ValueError()\n\n    return dataclasses.replace(config, **kwargs)\n\n\ndef init_model_aval(config):\n    \"\"\"Initialize model with parameters with abstract values (shape-only arrays).\"\"\"\n    model = FlaxBloomForCausalLMModule(config, dtype=config.dtype)\n    rngkey = jax.core.ShapedArray((2,), jnp.uint32)\n    input_ids = jax.core.ShapedArray((1,2), jnp.int32)\n    attention_mask = jax.core.ShapedArray((1, 1, 1, 2), jnp.int32)\n    params = jax.eval_shape(model.init, rngkey, input_ids, attention_mask=attention_mask)\n    params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype),\n                          params)\n    return model, params\n\n\ndef load_params_np(params, path, config, dummy=False):\n    \"\"\"Load parameters with numpy arrays.\"\"\"\n    if dummy:\n        np_dtype = config.dtype\n        return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params)\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        param_dict = params\n  
      param_keys = param_key.split('.')\n        for i, key in enumerate(param_keys):\n            if i == len(param_keys) - 1:\n                if dummy:\n                    param_dict[key] = jax.core.ShapedArray(\n                        param_dict[key].shape, param_dict[key].dtype)\n                else:\n                    if not is_position_embedding:\n                        assert param_dict[key].shape == loaded_array.shape, (\n                                f\"{param_dict[key].shape} vs. {loaded_array.shape}\")\n                    else:\n                        shape = param_dict[key].shape\n                        if shape != loaded_array.shape:\n                            assert shape[1] == loaded_array.shape[1]\n                            loaded_array = loaded_array[:shape[0], :]\n                    param_dict[key] = loaded_array\n            else:\n                param_dict = param_dict[key]\n\n    params = params.unfreeze()\n    load_param(\"params.transformer.ln_f.scale\",\n               load_array(\"ln_f.weight\"))\n    load_param(\"params.transformer.ln_f.bias\",\n               load_array(\"ln_f.bias\"))\n    load_param(\"params.transformer.word_embeddings.embedding\",\n               load_array(\"word_embeddings.weight\"))\n    load_param(\"params.transformer.word_embeddings_layernorm.scale\",\n                load_array(\"word_embeddings_layernorm.weight\"))\n    load_param(\"params.transformer.word_embeddings_layernorm.bias\",\n                load_array(\"word_embeddings_layernorm.bias\"))\n    for i in tqdm(range(config.num_hidden_layers)):\n        param_prefix = f\"params.transformer.h.{i}.\"\n        load_prefix = f\"h.{i}.\"\n        # Attention weights\n        load_param(param_prefix + \"self_attention.query_key_value.kernel\",\n                   load_array(load_prefix + \"self_attention.query_key_value.weight\").transpose())\n        load_param(param_prefix + \"self_attention.query_key_value.bias\",\n                   
load_array(load_prefix + \"self_attention.query_key_value.bias\").transpose())\n        load_param(param_prefix + \"input_layernorm.scale\",\n                   load_array(load_prefix + \"input_layernorm.weight\"))\n        load_param(param_prefix + \"input_layernorm.bias\",\n                   load_array(load_prefix + \"input_layernorm.bias\"))\n        load_param(param_prefix + \"self_attention.dense.kernel\",\n                   load_array(load_prefix + \"self_attention.dense.weight\").transpose())\n        load_param(param_prefix + \"self_attention.dense.bias\",\n                   load_array(load_prefix + \"self_attention.dense.bias\"))\n        load_param(param_prefix + \"post_attention_layernorm.scale\",\n                   load_array(load_prefix + \"post_attention_layernorm.weight\"))\n        load_param(param_prefix + \"post_attention_layernorm.bias\",\n                   load_array(load_prefix + \"post_attention_layernorm.bias\"))\n        # MLP weights\n        load_param(param_prefix + \"mlp.dense_h_to_4h.kernel\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_h_to_4h.weight\")))\n        load_param(param_prefix + \"mlp.dense_h_to_4h.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_h_to_4h.bias\")))\n        load_param(param_prefix + \"mlp.dense_4h_to_h.kernel\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_4h_to_h.weight\")))\n        load_param(param_prefix + \"mlp.dense_4h_to_h.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_4h_to_h.bias\")))\n\n    return flax.core.freeze(params)\n\n\ndef get_jax_executable(config: BloomConfig,\n                       encoder_chunk_sizes: Sequence[int],\n                       output_attentions: bool = False,\n                       output_hidden_states:bool = False):\n    \"\"\"Get a single-gpu executable.\"\"\"\n    model, params = init_model_aval(config)\n\n    @jax.jit\n    def inference_step(params, 
batch):\n        output = model.apply(params,\n                             batch[\"input_ids\"],\n                             attention_cache=batch[\"cache\"],\n                             attention_mask=batch[\"mask\"],\n                             output_attentions=output_attentions,\n                             output_hidden_states=output_hidden_states)\n        return output\n\n    executables = {}\n    for length in encoder_chunk_sizes:\n        executables[length] = inference_step\n    return executables, params\n\n\ndef get_pipeshard_executable(config: BloomConfig,\n                             batch_size: int,\n                             encoder_chunk_sizes: Sequence[int],\n                             num_micro_batches: int = 1,\n                             output_attentions: bool = False,\n                             output_hidden_states: bool = False):\n    \"\"\"Get a parallel executable.\"\"\"\n    # Init model\n    model, params = init_model_aval(config)\n\n    # Parallelize\n    method = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            # Force operator model parallel\n            force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,\n            # Disabling all-to-all and all-gather generates better intra-op strategies.\n            allow_all_to_all=False,\n            allow_all_gather=False,\n        ))\n\n    def inference_step_with_cache(params, batch):\n        output = model.apply(\n            params,\n            batch[\"input_ids\"],\n            attention_cache=batch[\"cache\"],\n            attention_mask=batch[\"mask\"],\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states)\n        return output\n\n    alpa.global_config.always_donate_micro_batch_vars = False\n\n    cache = init_cache_aval(config, 
batch_size)\n    mask = init_mask_aval(config, batch_size)\n\n    executables = {}\n\n    # Compile an executable with sequence length 1\n    executable = alpa.parallelize(\n        inference_step_with_cache, batch_argnums=(1,),\n        method=method).get_executable(\n            params, {\n                \"input_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"cache\":\n                    cache,\n                \"mask\":\n                    mask,\n            })\n    executable.dump_debug_info(\"tmp_executable_1\")\n    executables[1] = executable\n\n    # Create another parallel method with assigned input sharding specs\n    method_with_input_sharding = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            enable_auto_sharding=False,\n        ),\n        stage_input_shardings=executable.stage_input_shard_specs)\n\n    # Compile other executables\n    for seq_len in encoder_chunk_sizes:\n        executable = alpa.parallelize(\n            inference_step_with_cache,\n            batch_argnums=(1,),\n            method=method_with_input_sharding).get_executable(\n                params, {\n                    \"input_ids\":\n                        jax.core.ShapedArray(\n                            (batch_size, seq_len), jnp.int32),\n                    \"cache\":\n                        cache,\n                    \"mask\":\n                        mask,\n                })\n        executable.dump_debug_info(\"tmp_executable_%d\" % seq_len)\n        executables[seq_len] = executable\n    return executables, params\n\n\ndef load_bloom_params_worker_func(self, path, prefix_to_idx, config, shapes,\n                                  uuids, indices, mesh_ids):\n    \"\"\"The worker function to load Bloom parameters.\"\"\"\n\n    def load_array(key):\n 
       return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        i = prefix_to_idx[param_key]\n\n        for j in range(len(mesh_ids[i])):\n            if self.mesh_id != mesh_ids[i][j]:\n                continue\n\n            if not is_position_embedding:\n                assert shapes[i][j] == loaded_array.shape\n            else:\n                if shapes[i][j] != loaded_array.shape:\n                    assert shapes[i][j][1] == loaded_array.shape[1]\n                    loaded_array = loaded_array[:shapes[i][j][0], :]\n            uuid = uuids[i][j]\n            datas = []\n            for k in range(len(self.local_devices)):\n                idx = self.host_id * len(self.local_devices) + k\n                datas.append(loaded_array[indices[i][j][idx]])\n            self.put_buffers(uuid, datas)\n    layers_per_stage = config.num_hidden_layers // config.num_pp_stages\n\n    load_param(\"params.transformer.ln_f.scale\",\n               load_array(\"ln_f.weight\"))\n    load_param(\"params.transformer.ln_f.bias\",\n               load_array(\"ln_f.bias\"))\n    load_param(\"params.transformer.word_embeddings.embedding\",\n               load_array(\"word_embeddings.weight\"))\n    load_param(\"params.transformer.word_embeddings_layernorm.scale\",\n                load_array(\"word_embeddings_layernorm.weight\"))\n    load_param(\"params.transformer.word_embeddings_layernorm.bias\",\n                load_array(\"word_embeddings_layernorm.bias\"))\n\n    for i in range(config.num_hidden_layers):\n        stage_id = i // layers_per_stage\n        if stage_id != self.mesh_id:\n            continue\n\n        param_prefix = f\"params.transformer.h.{i}.\"\n        load_prefix = f\"h.{i}.\"\n        # Attention weights\n        load_param(param_prefix + \"self_attention.query_key_value.kernel\",\n                   load_array(load_prefix + \"self_attention.query_key_value.weight\").transpose())\n     
   load_param(param_prefix + \"self_attention.query_key_value.bias\",\n                   load_array(load_prefix + \"self_attention.query_key_value.bias\").transpose())\n        load_param(param_prefix + \"input_layernorm.scale\",\n                   load_array(load_prefix + \"input_layernorm.weight\"))\n        load_param(param_prefix + \"input_layernorm.bias\",\n                   load_array(load_prefix + \"input_layernorm.bias\"))\n        load_param(param_prefix + \"self_attention.dense.kernel\",\n                   load_array(load_prefix + \"self_attention.dense.weight\").transpose())\n        load_param(param_prefix + \"self_attention.dense.bias\",\n                   load_array(load_prefix + \"self_attention.dense.bias\"))\n        load_param(param_prefix + \"post_attention_layernorm.scale\",\n                   load_array(load_prefix + \"post_attention_layernorm.weight\"))\n        load_param(param_prefix + \"post_attention_layernorm.bias\",\n                   load_array(load_prefix + \"post_attention_layernorm.bias\"))\n        # MLP weights\n        load_param(param_prefix + \"mlp.dense_h_to_4h.kernel\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_h_to_4h.weight\")))\n        load_param(param_prefix + \"mlp.dense_h_to_4h.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_h_to_4h.bias\")))\n        load_param(param_prefix + \"mlp.dense_4h_to_h.kernel\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_4h_to_h.weight\")))\n        load_param(param_prefix + \"mlp.dense_4h_to_h.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.dense_4h_to_h.bias\")))\n\n\nsetattr(MeshHostWorker, \"load_bloom_params_worker_func\",\n        load_bloom_params_worker_func)\n\n\ndef load_params_dis_array(path, executable, params_aval, config, dummy=False):\n    \"\"\"Load parameters with distributed arrays.\"\"\"\n    if dummy:\n        
alpa.global_config.use_dummy_value_for_benchmarking = True\n        params_info, _ = executable.get_input_placement_specs()\n        flat_args, in_tree = tree_flatten(params_aval)\n        flat_info = tree_leaves(params_info)\n        if hasattr(executable, \"mesh_group\"):\n            ret = executable.mesh_group.shard_args_to_arrays(\n                flat_info, flat_args)\n        else:\n            ret = executable.physical_mesh.shard_args_to_arrays_ps(\n                flat_info, flat_args)\n        alpa.global_config.use_dummy_value_for_benchmarking = False\n        return ret\n\n    params_info, _ = executable.get_input_placement_specs()\n\n    prefix_to_flat_idx = {}\n    ct = itertools.count()\n\n    def dfs(dict_tree, result_dict, cur_prefix):\n        if isinstance(dict_tree, (dict, flax.core.FrozenDict)):\n            for key in dict_tree.keys():\n                dfs(dict_tree[key], result_dict,\n                    cur_prefix + (\".\" if cur_prefix else \"\") + key)\n        else:\n            result_dict[cur_prefix] = next(ct)\n\n    dfs(params_aval, prefix_to_flat_idx, \"\")\n\n    flat_infos, in_tree = tree_flatten(params_info)\n\n    flat_shapes = []\n    flat_uuids = []\n    flat_indices = []\n    flat_mesh_ids = []\n    flat_arrays = []\n\n    mesh_group = executable.mesh_group\n\n    for info in flat_infos:\n        aval = info.aval\n        if len(info.mesh_ids) == 1:\n            mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0]\n            indices = pxla.spec_to_indices(aval.shape, spec)\n            ary_refs, ary_uuid = create_remote_array_refs(mesh)\n            flat_shapes.append([aval.shape])\n            flat_uuids.append([ary_uuid[0]])\n            flat_indices.append([indices])\n            flat_mesh_ids.append([mesh.mesh_id])\n            flat_arrays.append(\n                DistributedArray(mesh, aval, spec, ary_refs[0], indices))\n        else:\n            tmp_shapes = []\n            tmp_uuids = []\n            
tmp_indices = []\n            tmp_mesh_ids = []\n            tmp_arrays = []\n            tmp_meshes = []\n            for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):\n                mesh = mesh_group[mesh_id]\n                indices = pxla.spec_to_indices(aval.shape, spec)\n                ary_refs, ary_uuid = create_remote_array_refs(mesh)\n                array = DistributedArray(mesh, aval, spec, ary_refs[0], indices)\n                tmp_shapes.append(aval.shape)\n                tmp_uuids.append(ary_uuid[0])\n                tmp_indices.append(indices)\n                tmp_mesh_ids.append(mesh.mesh_id)\n                tmp_meshes.append(mesh)\n                tmp_arrays.append(array)\n            flat_shapes.append(tuple(tmp_shapes))\n            flat_uuids.append(tuple(tmp_uuids))\n            flat_indices.append(tuple(tmp_indices))\n            flat_mesh_ids.append(tuple(tmp_mesh_ids))\n            flat_arrays.append(\n                ReplicatedDistributedArray(tmp_meshes, tmp_arrays))\n\n    for m in executable.mesh_group.meshes:\n        for w in m.workers:\n            w.load_bloom_params_worker_func.remote(path, prefix_to_flat_idx,\n                                                 config, flat_shapes,\n                                                 flat_uuids, flat_indices,\n                                                 flat_mesh_ids)\n\n    return flat_arrays\n\n\ndef load_multi_executable_params_dis_array(path,\n                                           executables,\n                                           params_aval,\n                                           config,\n                                           dummy=False):\n    \"\"\"Load parameters to workers that will be used by all executables. 
Accordingly,\n    we need to make sure the parameter sharding specs are identical for all executables.\n    \"\"\"\n    shared_input_shard_specs = None\n    for executable in executables.values():\n        stage_input_shard_specs = executable.stage_input_shard_specs\n        if shared_input_shard_specs is not None:\n            assert shared_input_shard_specs == stage_input_shard_specs, \\\n                \"All executables must have the same input sharding specs.\"\n        else:\n            shared_input_shard_specs = stage_input_shard_specs\n    return load_params_dis_array(path,\n                                 list(executables.values())[0], params_aval,\n                                 config, dummy)\n"
  },
  {
    "path": "examples/llm_serving/model/codegen_model.py",
    "content": "\"\"\"CodeGen model implementation.\"\"\"\nimport dataclasses\nfrom dataclasses import dataclass\nfrom functools import partial\nimport itertools\nimport math\nimport os\nfrom typing import Callable, Optional, Tuple, Dict, Sequence\n\nimport alpa\nfrom alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,\n                              MeshHostWorker, create_remote_array_refs)\nfrom alpa.model.model_util import ModelOutput\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nimport flax.linen as nn\nfrom flax.linen import combine_masks, dot_product_attention_weights, make_causal_mask\nimport jax\nimport flax\nfrom jax import lax\nimport jax.numpy as jnp\nfrom jax.tree_util import tree_flatten, tree_unflatten, tree_leaves\nfrom jax.interpreters import pxla\nimport jaxlib.xla_extension as jax_xla\nimport numpy as np\nimport ray\nimport torch\nfrom tqdm import tqdm\nfrom warnings import warn\n\nfrom llm_serving.model.opt_model import init_cache_aval, init_mask_aval\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.swish,\n    \"swish\": nn.swish,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\n@flax.struct.dataclass\nclass CodeGenModelOutput(ModelOutput):\n    last_hidden_state: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\n@flax.struct.dataclass\nclass CodeGenLMOutput(ModelOutput):\n    logits: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\n@dataclass(frozen=True)\nclass CodeGenConfig:\n    pad: int = 1\n    vocab_size: int = 50400\n    max_seq_len: int = 2048\n    n_ctx: int = 2048\n    
hidden_size: int = 4096\n    num_hidden_layers: int = 28\n    n_head: int = 16\n    rotary_dim: int = 64\n    n_inner: int = None\n    activation_fn: str = 'gelu_new'\n    resid_pdrop: float = 0.0\n    embd_pdrop: float = 0.0\n    attn_pdrop: float = 0.0\n    layer_norm_eps: float = 1e-5\n    initializer_range: float = 0.02\n    scale_attn_weights: bool = True\n    bos_token_id: int = 50256\n    eos_token_id: int = 50256\n    # Added\n    decoder_input_dim: int = 4096\n    decoder_ffn_embed_dim: int = 16384\n    dtype: any = jnp.float16\n    num_pp_stages: int = None\n    tie_word_embeddings: bool = False\n    use_cache: bool = True\n    # parallelize\n    mark_boundary: bool = True\n\n\n# Copied from transformers.models.gptj.modeling_flax_gptj.create_sinusoidal_positions\ndef create_sinusoidal_positions(num_pos, dim):\n    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))\n    sinusoid_inp = np.einsum(\"i , j -> i j\", np.arange(num_pos), inv_freq).astype(\"float32\")\n    sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)\n\n    sentinel = dim // 2 + dim % 2\n    out = np.zeros((num_pos, dim))\n    out[:, 0:sentinel] = sin\n    out[:, sentinel:] = cos\n\n    return jnp.array(out, dtype=jnp.float16)\n\n# Copied from transformers.models.gptj.modeling_flax_gptj.rotate_every_two\ndef rotate_every_two(tensor):\n    rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)\n    rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))\n    return rotate_half_tensor\n\n# Copied from transformers.models.gptj.modeling_flax_gptj.apply_rotary_pos_emb\ndef apply_rotary_pos_emb(tensor, sincos):\n    sin_pos, cos_pos = sincos\n    sin_pos = sin_pos[:, :, None, :].repeat(2, 3)\n    cos_pos = cos_pos[:, :, None, :].repeat(2, 3)\n    return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)\n\nclass CodeGenAttention(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of 
the computation\n\n    def setup(self):\n        if self.config.hidden_size % self.config.n_head != 0:\n            raise ValueError(\n                f\"`hidden_size`: {self.config.hidden_size} has to be a \"\n                f\"multiple of `n_head`: {self.config.n_head}\"\n            )\n\n        self.embed_dim = self.config.hidden_size\n        self.head_dim = self.config.hidden_size // self.config.n_head\n        self.rotary_dim = self.config.rotary_dim\n\n        self.qkv_combined = nn.Dense(\n            self.config.hidden_size * 3,\n            dtype=self.dtype,\n            use_bias=False\n        )\n        \n        self.out_proj = nn.Dense(self.config.hidden_size, dtype=self.dtype, use_bias=False)\n        self.resid_dropout = nn.Dropout(rate=self.config.resid_pdrop)\n\n        pos_embd_dim = self.rotary_dim or self.embed_dim\n        self.embed_positions = create_sinusoidal_positions(self.config.max_seq_len, pos_embd_dim)\n\n    def _split_heads(self, hidden_states):\n        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.n_head, self.head_dim))\n\n    def _merge_heads(self, hidden_states):\n        return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))\n\n    def __call__(self,\n                 hidden_states,\n                 position_ids,\n                 output_attentions: bool = False,\n                 attention_cache=None,\n                 attention_mask=None,\n                 deterministic:bool = True):\n\n        batch_size = hidden_states.shape[0]\n        seq_length = hidden_states.shape[1]\n        fused_qkv = self.qkv_combined(hidden_states)\n        mp_num = 4 # number of cores on their TPU\n        qkv_split = fused_qkv.reshape(fused_qkv.shape[:-1] + (mp_num, -1))\n        query, value, key = jnp.split(qkv_split, 3, axis=-1)\n        query = self._split_heads(query)\n        key = self._split_heads(key)\n        value = self._split_heads(value)\n        key_length = 
attention_mask.shape[-1]\n        causal_attention_mask = make_causal_mask(jnp.ones((batch_size, key_length)), dtype=\"bool\")\n\n        expanded = jax.nn.one_hot(position_ids, self.embed_positions.shape[0], dtype=self.dtype)\n        sincos = expanded @ jnp.asarray(self.embed_positions, self.dtype)\n        sincos = jnp.split(sincos, 2, axis=-1)\n        if self.rotary_dim is not None:\n            k_rot = key[:, :, :, : self.rotary_dim]\n            k_pass = key[:, :, :, self.rotary_dim :]\n\n            q_rot = query[:, :, :, : self.rotary_dim]\n            q_pass = query[:, :, :, self.rotary_dim :]\n\n            k_rot = apply_rotary_pos_emb(k_rot, sincos)\n            q_rot = apply_rotary_pos_emb(q_rot, sincos)\n\n            key = jnp.concatenate([k_rot, k_pass], axis=-1)\n            query = jnp.concatenate([q_rot, q_pass], axis=-1)\n        else:\n            key = apply_rotary_pos_emb(key, sincos)\n            query = apply_rotary_pos_emb(query, sincos)\n            \n        # for fast decoding causal attention mask should be shifted\n        if attention_cache:\n            causal_attention_mask_shift = attention_cache[2][0]\n        else:\n            causal_attention_mask_shift = 0\n\n        if attention_cache:\n            max_decoder_length = attention_cache[0].shape[1]\n            causal_attention_mask = jax.lax.dynamic_slice(\n                causal_attention_mask,\n                (0, 0, causal_attention_mask_shift, 0),\n                (1, 1, seq_length, max_decoder_length)\n            )\n\n            # Handle a special kind of internal padding added by alpa.\n            # Note that this kind of internal padding is different from\n            # the padding added by the tokenizer. 
This internal padding\n            # should not update cache and step_ct\n            # shape: [B, 1, 1, S_max]\n            is_internal_padding = (attention_mask == 2)\n            num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1)\n            attention_mask = (attention_mask == 1)\n\n        attention_mask = combine_masks(attention_mask, causal_attention_mask)\n\n        if attention_cache:\n            cache_key, cache_value, cache_index = attention_cache\n            *batch_dims, max_length, num_heads, depth_per_head = cache_key.shape\n            # update key, value caches with our new 1d spatial slices\n            cur_index = cache_index[0]\n            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)\n            key = lax.dynamic_update_slice(cache_key, key, indices)\n            value = lax.dynamic_update_slice(cache_value, value, indices)\n            cache_key = key\n            cache_value = value\n            num_updated_cache_vectors = query.shape[1]\n            # A line added from bloom_model\n            attention_cache = key, value, cache_index + num_updated_cache_vectors - num_internal_pad\n            # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.\n            pad_mask = jnp.broadcast_to(\n                jnp.arange(max_length) < cur_index + num_updated_cache_vectors,\n                tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),\n            )\n            attention_mask = combine_masks(pad_mask, attention_mask)\n\n        dropout_rng = None\n        if not deterministic and self.config.attention_dropout > 0.0:\n            dropout_rng = self.make_rng(\"dropout\")\n\n        # transform boolean mask into float mask\n        mask_value = jnp.finfo(self.dtype).min\n        attention_bias = lax.select(\n            attention_mask > 0,\n            
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),\n            jnp.full(attention_mask.shape, mask_value).astype(self.dtype),\n        )\n\n        attn_weights = dot_product_attention_weights(\n            query,\n            key,\n            bias=attention_bias,\n            dropout_rng=dropout_rng,\n            dropout_rate=self.config.attn_pdrop,\n            deterministic=deterministic,\n            dtype=self.dtype,\n            precision=None,\n        )\n\n        attn_output = jnp.einsum(\"...hqk,...khd->...qhd\", attn_weights, value)\n        attn_output = self._merge_heads(attn_output)\n        attn_output = self.out_proj(attn_output)\n        attn_output = self.resid_dropout(attn_output, deterministic=deterministic)\n\n        outputs = (attn_output, attention_cache,\n                   attn_weights) if output_attentions else (attn_output,\n                                                            attention_cache)\n        return outputs\n\n\nclass CodeGenBlock(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        hidden_size = self.config.hidden_size\n\n        self.self = CodeGenAttention(self.config, dtype=self.dtype)\n        self.mlp = CodeGenMLP(self.config)\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 position_ids = None,\n                 deterministic: bool = True,\n                 output_attentions: bool = False,\n                 attention_cache=None,\n                 attention_mask=None):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        attn_outputs = self.self(hidden_states,\n                                 position_ids=position_ids,\n                                 output_attentions=output_attentions,\n                                 
attention_cache=attention_cache,\n                                 attention_mask=attention_mask)\n        attn_output = attn_outputs[0]\n        attention_cache = attn_outputs[1]\n        \n        feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)\n        hidden_states = attn_output + feed_forward_hidden_states + residual\n        outputs = (hidden_states, attention_cache)\n\n        if output_attentions:\n            outputs += (attn_outputs[2],)\n\n        return outputs\n\n\nclass CodeGenMLP(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)\n\n        self.fc_in = nn.Dense(\n            4 * self.config.hidden_size,\n            dtype=self.dtype,\n            kernel_init=kernel_init\n        )\n        self.fc_out = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n            kernel_init=kernel_init\n        )\n        self.act = ACT2FN[self.config.activation_fn]\n        self.dropout = nn.Dropout(self.config.resid_pdrop)\n\n    def __call__(self,\n                 hidden_states,\n                 deterministic: bool = True):\n        hidden_states = self.fc_in(hidden_states)\n        hidden_states = self.act(hidden_states)\n        hidden_states = self.fc_out(hidden_states)\n        hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n        return hidden_states\n\nclass CodeGenTransformerLayerCollection(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.layers = [ \n            CodeGenBlock(self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(\n        self,\n        hidden_states,\n        position_ids,\n        output_attentions: bool = 
False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        new_attention_cache = () if attention_cache is not None else None\n\n        if self.config.num_pp_stages is not None:\n            if self.config.num_hidden_layers % self.config.num_pp_stages != 0:\n                warn(\"The number of hidden layers is not divisible by the number of stages\")\n            layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages\n\n        for i, layer in enumerate(self.layers):\n            if self.config.num_pp_stages is not None:\n                if i % layers_per_stage == 0 and i != 0:\n                    stage_id = i // layers_per_stage\n                    if self.config.mark_boundary and i // layers_per_stage < self.config.num_pp_stages:\n                        mark_pipeline_boundary()\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n            layer_attention_cache = None\n            if attention_cache is not None:\n                layer_attention_cache = attention_cache[i]\n            layer_outputs = layer(hidden_states,\n                                  position_ids=position_ids,\n                                  output_attentions=output_attentions,\n                                  attention_cache=layer_attention_cache,\n                                  attention_mask=attention_mask)\n            hidden_states = layer_outputs[0]\n            if attention_cache is not None:\n                new_attention_cache += (layer_outputs[1],)\n            \n            if output_attentions:\n                all_attentions += (layer_outputs[2],)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n\n  
      if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return CodeGenModelOutput(last_hidden_state=hidden_states,\n                              hidden_states=all_hidden_states,\n                              attentions=all_attentions,\n                              attention_cache=new_attention_cache)\n\n\nclass CodeGenTransformerModule(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.wte = nn.Embed(\n            self.config.vocab_size,\n            self.config.hidden_size,\n            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n            dtype=self.dtype\n        )\n\n        self.drop = nn.Dropout(rate=self.config.embd_pdrop)\n\n        self.encoder = CodeGenTransformerLayerCollection(self.config,\n                                                     dtype=self.dtype)\n\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                           dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        deterministic:bool = True,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        input_embeds = self.wte(input_ids.astype(\"i4\"))\n        \n        hidden_states = self.drop(input_embeds, deterministic=deterministic)\n\n        outputs = self.encoder(\n            hidden_states,\n            position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            attention_cache=attention_cache,\n            attention_mask=attention_mask\n        )\n        hidden_states = outputs[0]\n        hidden_states = self.layer_norm(hidden_states)\n\n  
      if output_hidden_states:\n            all_hidden_states = outputs.hidden_states + (hidden_states,)\n            outputs = CodeGenModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache)\n        else:\n            outputs = CodeGenModelOutput(last_hidden_state=hidden_states, hidden_states=outputs.hidden_states, attentions=outputs.attentions, attention_cache=outputs.attention_cache)\n\n        if not return_dict:\n            return (hidden_states,) + outputs[1:]\n\n        return outputs\n\n\nclass CodeGenForLMModule(nn.Module):\n    config: CodeGenConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        self.transformers = CodeGenTransformerModule(config=self.config,\n                                                 dtype=self.dtype)\n\n        self.lm_head = nn.Dense(\n            self.config.vocab_size,\n            dtype=jnp.float32,\n            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n        )\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        # Model\n        outputs = self.transformers(\n            input_ids=input_ids,\n            position_ids=position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            attention_cache=attention_cache,\n            attention_mask=attention_mask\n        )\n\n        hidden_states = outputs[0]\n\n        if self.config.tie_word_embeddings:\n            shared_kernel = self.transformers.variables[\"params\"][\"wte\"][\"embedding\"].T\n            logits = self.lm_head.apply({\"params\": {\"kernel\": shared_kernel}}, hidden_states)\n        
else:\n            logits = self.lm_head(hidden_states)\n        \n        # Compute the prediction scores\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return CodeGenLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            attention_cache=outputs.attention_cache,\n        )\n\ndef get_config(name, **kwargs):\n    if name in [\"codegen-350m-mono\", \"codegen-350m-multi\", \"codegen-350m-nl\"]:\n        config = CodeGenConfig(\n            max_seq_len=2048, num_hidden_layers=20, n_head=16,\n            hidden_size=1024, decoder_input_dim=1024, decoder_ffn_embed_dim=1024 * 4,\n            rotary_dim=32, bos_token_id=1, vocab_size=51200\n        )\n    elif name in [\"codegen-2b-mono\", \"codegen-2b-multi\", \"codegen-2b-nl\"]:\n        config = CodeGenConfig(\n            max_seq_len=2048, num_hidden_layers=32, n_head=32,\n            hidden_size=2560, decoder_input_dim=2560, decoder_ffn_embed_dim=2560 * 4,\n            rotary_dim=64, bos_token_id=1, vocab_size=51200\n        )\n    elif name in [\"codegen-6b-mono\", \"codegen-6b-multi\", \"codegen-6b-nl\"]:\n        config = CodeGenConfig(\n            max_seq_len=2048, num_hidden_layers=33, n_head=16,\n            hidden_size=4096, decoder_input_dim=4096, decoder_ffn_embed_dim=4096 * 4,\n            rotary_dim=64, bos_token_id=1, vocab_size=51200\n        )\n    elif name in [\"codegen-16b-mono\", \"codegen-16b-multi\", \"codegen-16b-nl\"]:\n        config = CodeGenConfig(\n            max_seq_len=2048, num_hidden_layers=34, n_head=24,\n            hidden_size=6144, decoder_input_dim=6144, decoder_ffn_embed_dim=6144 * 4,\n            rotary_dim=64, bos_token_id=1, vocab_size=51200\n        )\n    else:\n        raise ValueError(f\"Invalid model name: {name}\")\n\n    return dataclasses.replace(config, **kwargs)\n\ndef init_model_aval(config):\n    \"\"\"Initialize model with parameters with 
abstract values (shape-only arrays).\"\"\"\n    model = CodeGenForLMModule(config, dtype=config.dtype)\n    rngkey = jax.core.ShapedArray((2,), jnp.uint32)\n    input_ids = jax.core.ShapedArray((1, 2), jnp.int32)\n    position_ids = jax.core.ShapedArray((1, 2), jnp.int32)\n    attention_mask = jax.core.ShapedArray((1, 1, 1, 2), jnp.int32)\n    params = jax.eval_shape(model.init, rngkey, input_ids, position_ids, attention_mask=attention_mask)\n    params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype),\n                          params)\n    return model, params\n\ndef init_cache_np(config, batch_size):\n    \"\"\"Init cache with numpy arrays.\"\"\"\n    np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n    head_dim = config.hidden_size // config.n_head\n\n    all_cache = []\n    for i in range(config.num_hidden_layers):\n        layer_cache = (\n            np.zeros((batch_size, config.max_seq_len,\n                      config.n_head, head_dim),\n                     dtype=np_dtype),\n            np.zeros((batch_size, config.max_seq_len,\n                      config.n_head, head_dim),\n                     dtype=np_dtype),\n            np.zeros((batch_size,), np.int32),\n        )\n        all_cache.append(layer_cache)\n    return tuple(all_cache)\n\ndef inference_step_no_cache(params, batch, apply_func):\n    logits = apply_func(params, batch[\"input_ids\"], batch[\"position_ids\"])[0]\n    return logits\n\n\ndef load_params_np(params, path, config, dummy=False):\n    \"\"\"Load parameters with numpy arrays.\"\"\"\n    if dummy:\n        np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n        return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params)\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        param_dict = params\n        param_keys = param_key.split('.')\n        for i, 
key in enumerate(param_keys):\n            if i == len(param_keys) - 1:\n                if dummy:\n                    param_dict[key] = jax.core.ShapedArray(\n                        param_dict[key].shape, param_dict[key].dtype)\n                else:\n                    if not is_position_embedding:\n                        assert param_dict[key].shape == loaded_array.shape, (\n                                f\"{param_dict[key].shape} vs. {loaded_array.shape}\")\n                    else:\n                        shape = param_dict[key].shape\n                        if shape != loaded_array.shape:\n                            assert shape[1] == loaded_array.shape[1]\n                            loaded_array = loaded_array[:shape[0], :]\n                    param_dict[key] = loaded_array\n            else:\n                param_dict = param_dict[key]\n\n    params = params.unfreeze()\n    load_param(\"params.transformers.layer_norm.scale\", load_array(\"ln_f.weight\"))\n    load_param(\"params.transformers.layer_norm.bias\", load_array(\"ln_f.bias\"))\n    load_param(\"params.transformers.wte.embedding\", load_array(\"wte.weight\"))\n    load_param(\"params.lm_head.bias\", load_array(\"lm_head.bias\"))\n    load_param(\"params.lm_head.kernel\", load_array(\"lm_head.weight\").transpose())\n\n    for i in tqdm(range(config.num_hidden_layers)):\n        param_prefix = f\"params.transformers.encoder.{i}.\"\n        load_prefix = f\"h.{i}.\"\n        # Attention weights\n        load_param(\n            param_prefix + \"self.out_proj.kernel\",\n            load_array(load_prefix + \"attn.out_proj.weight\").transpose())\n        load_param(\n            param_prefix + \"self.qkv_combined.kernel\",\n            load_array(load_prefix + \"attn.qkv_proj.weight\").transpose())\n\n        load_param(param_prefix + \"layer_norm.scale\",\n                   load_array(load_prefix + \"ln_1.weight\"))\n        load_param(param_prefix + \"layer_norm.bias\",\n                 
  load_array(load_prefix + \"ln_1.bias\"))\n\n        # MLP weights\n        load_param(param_prefix + \"mlp.fc_in.kernel\",\n                   load_array(load_prefix + \"mlp.fc_in.weight\").transpose())\n        load_param(param_prefix + \"mlp.fc_in.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.fc_in.bias\")))\n        load_param(param_prefix + \"mlp.fc_out.bias\",\n                   load_array(load_prefix + \"mlp.fc_out.bias\"))\n        load_param(param_prefix + \"mlp.fc_out.kernel\",\n                   load_array(load_prefix + \"mlp.fc_out.weight\").transpose())\n\n    return flax.core.freeze(params)\n\ndef get_jax_executable(config: CodeGenConfig,\n                       encoder_chunk_sizes: Sequence[int],\n                       output_attentions: bool = False,\n                       output_hidden_states:bool = False):\n    \"\"\"Get a single-gpu executable.\"\"\"\n    model, params = init_model_aval(config)\n\n    @jax.jit\n    def inference_step(params, batch):\n        output = model.apply(params,\n                             input_ids=batch[\"input_ids\"],\n                             position_ids=batch[\"position_ids\"],\n                             attention_cache=batch[\"cache\"],\n                             attention_mask=batch[\"mask\"],\n                             output_attentions=output_attentions,\n                             output_hidden_states=output_hidden_states)\n        return output\n\n    executables = {}\n    for length in encoder_chunk_sizes:\n        executables[length] = inference_step\n    return executables, params\n\n\ndef get_pipeshard_executable(config: CodeGenConfig,\n                             batch_size: int,\n                             encoder_chunk_sizes: Sequence[int],\n                             num_micro_batches: int = 1,\n                             output_attentions: bool = False,\n                             output_hidden_states: bool = False,\n                             
autoregressive: bool = True):\n    \"\"\"Get a parallel executable.\"\"\"\n    # Init model\n    model, params = init_model_aval(config)\n\n    # Parallelize\n    method = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            # Force operator model parallel\n            force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,\n            # Disabling all-to-all and all-gather generates better intra-op strategies.\n            allow_all_to_all=False,\n            allow_all_gather=False,\n        ))\n    \n    def inference_step_with_cache(params, batch):\n        output = model.apply(\n            params,\n            batch[\"input_ids\"],\n            batch[\"position_ids\"],\n            attention_cache=batch[\"cache\"],\n            attention_mask=batch[\"mask\"],\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states)\n        return output\n\n    alpa.global_config.always_donate_micro_batch_vars = False\n\n    cache = init_cache_aval(config, batch_size)\n    mask = init_mask_aval(config, batch_size)\n\n    executables = {}\n\n    # Compile an executable with sequence length 1\n    executable = alpa.parallelize(\n        inference_step_with_cache, batch_argnums=(1,),\n        method=method).get_executable(\n            params, {\n                \"input_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"position_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"cache\":\n                    cache,\n                \"mask\":\n                    mask,\n            })\n    executable.dump_debug_info(\"tmp_executable_1\")\n    executables[1] = executable\n\n    # Create another parallel method with assigned input sharding specs\n    
method_with_input_sharding = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            enable_auto_sharding=False,\n        ),\n        stage_input_shardings=executable.stage_input_shard_specs)\n\n    # Compile other executables\n    for seq_len in encoder_chunk_sizes:\n        executable = alpa.parallelize(\n            inference_step_with_cache,\n            batch_argnums=(1,),\n            method=method_with_input_sharding).get_executable(\n                params, {\n                    \"input_ids\":\n                        jax.core.ShapedArray(\n                            (batch_size, seq_len), jnp.int32),\n                    \"position_ids\":\n                        jax.core.ShapedArray(\n                            (batch_size, seq_len), jnp.int32),\n                    \"cache\":\n                        cache,\n                    \"mask\":\n                        mask,\n                })\n        executable.dump_debug_info(\"tmp_executable_%d\" % seq_len)\n        executables[seq_len] = executable\n    return executables, params\n\n\ndef load_codegen_params_worker_func(self, path, prefix_to_idx, config, shapes,\n                                uuids, indices, mesh_ids):\n    \"\"\"The worker function to load CodeGen parameters.\"\"\"\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        i = prefix_to_idx[param_key]\n\n        for j in range(len(mesh_ids[i])):\n            if self.mesh_id != mesh_ids[i][j]:\n                # print(f\"skipping {param_key} on mesh {self.mesh_id} which is on  {mesh_ids[i][j]} and {uuids[i][j]}\")\n                continue\n            \n            if not is_position_embedding:\n                assert shapes[i][j] == loaded_array.shape, (\n            
        f\"{shapes[i][j]} vs. {loaded_array.shape}\")\n            else:\n                if shapes[i][j] != loaded_array.shape:\n                    assert shapes[i][j][1] == loaded_array.shape[1]\n                    loaded_array = loaded_array[:shapes[i][j][0], :]\n            uuid = uuids[i][j]\n            datas = []\n            for k in range(len(self.local_devices)):\n                idx = self.host_id * len(self.local_devices) + k\n                datas.append(loaded_array[indices[i][j][idx]])\n            self.put_buffers(uuid, datas)\n\n    layers_per_stage = config.num_hidden_layers // config.num_pp_stages\n\n    load_param(\"params.transformers.layer_norm.scale\", load_array(\"ln_f.weight\"))\n    load_param(\"params.transformers.layer_norm.bias\", load_array(\"ln_f.bias\"))\n    load_param(\"params.transformers.wte.embedding\", load_array(\"wte.weight\"))\n    load_param(\"params.lm_head.bias\", load_array(\"lm_head.bias\"))\n    load_param(\"params.lm_head.kernel\", load_array(\"lm_head.weight\").transpose())\n    \n    for i in range(config.num_hidden_layers):\n        stage_id = i // layers_per_stage\n        if i // layers_per_stage  == config.num_pp_stages: # special case for codegen-6b\n            stage_id = config.num_pp_stages - 1\n        if stage_id != self.mesh_id:\n            continue\n\n        param_prefix = f\"params.transformers.encoder.{i}.\"\n        load_prefix = f\"h.{i}.\"\n        # Attention weights\n        load_param(\n            param_prefix + \"self.out_proj.kernel\",\n            load_array(load_prefix + \"attn.out_proj.weight\").transpose())\n        load_param(\n            param_prefix + \"self.qkv_combined.kernel\",\n            load_array(load_prefix + \"attn.qkv_proj.weight\").transpose())\n\n        load_param(param_prefix + \"layer_norm.scale\",\n                   load_array(load_prefix + \"ln_1.weight\"))\n        load_param(param_prefix + \"layer_norm.bias\",\n                   load_array(load_prefix + 
\"ln_1.bias\"))\n\n        # MLP weights\n        load_param(param_prefix + \"mlp.fc_in.kernel\",\n                   load_array(load_prefix + \"mlp.fc_in.weight\").transpose())\n        load_param(param_prefix + \"mlp.fc_in.bias\",\n                   np.transpose(load_array(load_prefix + \"mlp.fc_in.bias\")))\n        load_param(param_prefix + \"mlp.fc_out.bias\",\n                   load_array(load_prefix + \"mlp.fc_out.bias\"))\n        load_param(param_prefix + \"mlp.fc_out.kernel\",\n                   load_array(load_prefix + \"mlp.fc_out.weight\").transpose())\n\n\nsetattr(MeshHostWorker, \"load_codegen_params_worker_func\",\n        load_codegen_params_worker_func)\n\n\ndef load_params_dis_array(path, executable, params_aval, config, dummy=False):\n    \"\"\"Load parameters with distributed arrays.\"\"\"\n    if dummy:\n        alpa.global_config.use_dummy_value_for_benchmarking = True\n        params_info, _ = executable.get_input_placement_specs()\n        flat_args, in_tree = tree_flatten(params_aval)\n        flat_info = tree_leaves(params_info)\n        if hasattr(executable, \"mesh_group\"):\n            ret = executable.mesh_group.shard_args_to_arrays(\n                flat_info, flat_args)\n        else:\n            ret = executable.physical_mesh.shard_args_to_arrays_ps(\n                flat_info, flat_args)\n        alpa.global_config.use_dummy_value_for_benchmarking = False\n        return ret\n\n    params_info, _ = executable.get_input_placement_specs()\n\n    prefix_to_flat_idx = {}\n    ct = itertools.count()\n\n    def dfs(dict_tree, result_dict, cur_prefix):\n        if isinstance(dict_tree, (dict, flax.core.FrozenDict)):\n            for key in dict_tree.keys():\n                dfs(dict_tree[key], result_dict,\n                    cur_prefix + (\".\" if cur_prefix else \"\") + key)\n        else:\n            result_dict[cur_prefix] = next(ct)\n\n    dfs(params_aval, prefix_to_flat_idx, \"\")\n\n    flat_infos, in_tree = 
tree_flatten(params_info)\n\n    flat_shapes = []\n    flat_uuids = []\n    flat_indices = []\n    flat_mesh_ids = []\n    flat_arrays = []\n\n    mesh_group = executable.mesh_group\n\n    for info in flat_infos:\n        aval = info.aval\n        if len(info.mesh_ids) == 1:\n            mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0]\n            indices = pxla.spec_to_indices(aval.shape, spec)\n            ary_refs, ary_uuid = create_remote_array_refs(mesh)\n            flat_shapes.append([aval.shape])\n            flat_uuids.append([ary_uuid[0]])\n            flat_indices.append([indices])\n            flat_mesh_ids.append([mesh.mesh_id])\n            flat_arrays.append(\n                DistributedArray(mesh, aval, spec, ary_refs[0], indices))\n        else:\n            tmp_shapes = []\n            tmp_uuids = []\n            tmp_indices = []\n            tmp_mesh_ids = []\n            tmp_arrays = []\n            tmp_meshes = []\n            for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):\n                mesh = mesh_group[mesh_id]\n                indices = pxla.spec_to_indices(aval.shape, spec)\n                ary_refs, ary_uuid = create_remote_array_refs(mesh)\n                array = DistributedArray(mesh, aval, spec, ary_refs[0], indices)\n                tmp_shapes.append(aval.shape)\n                tmp_uuids.append(ary_uuid[0])\n                tmp_indices.append(indices)\n                tmp_mesh_ids.append(mesh.mesh_id)\n                tmp_meshes.append(mesh)\n                tmp_arrays.append(array)\n            flat_shapes.append(tuple(tmp_shapes))\n            flat_uuids.append(tuple(tmp_uuids))\n            flat_indices.append(tuple(tmp_indices))\n            flat_mesh_ids.append(tuple(tmp_mesh_ids))\n            flat_arrays.append(\n                ReplicatedDistributedArray(tmp_meshes, tmp_arrays))\n\n    for m in executable.mesh_group.meshes:\n        for w in m.workers:\n            
w.load_codegen_params_worker_func.remote(path, prefix_to_flat_idx,\n                                                 config, flat_shapes,\n                                                 flat_uuids, flat_indices,\n                                                 flat_mesh_ids)\n\n    return flat_arrays\n\n\ndef init_cache_dis_array(executable, config, batch_size, dummy=False):\n    \"\"\"Initialize cache with distributed arrays.\"\"\"\n    cache = init_cache_np(config, batch_size)\n    alpa.global_config.use_dummy_value_for_benchmarking = dummy\n    _, batch_info = executable.get_input_placement_specs()\n    flat_args, in_tree = tree_flatten(cache)\n    flat_info = tree_leaves(batch_info[\"cache\"])\n    if hasattr(executable, \"mesh_group\"):\n        ret = executable.mesh_group.shard_args_to_arrays(flat_info, flat_args)\n    else:\n        ret = executable.physical_mesh.shard_args_to_arrays_ps(\n            flat_info, flat_args)\n    alpa.global_config.use_dummy_value_for_benchmarking = False\n    return ret\n\n\ndef load_multi_executable_params_dis_array(path,\n                                           executables,\n                                           params_aval,\n                                           config,\n                                           dummy=False):\n    \"\"\"Load parameters to workers that will be used by all executables. 
Accordingly,\n    we need to make sure the parameter sharding specs are identical for all executables.\n    \"\"\"\n    shared_input_shard_specs = None\n    for executable in executables.values():\n        stage_input_shard_specs = executable.stage_input_shard_specs\n        if shared_input_shard_specs is not None:\n            assert shared_input_shard_specs == stage_input_shard_specs, \\\n                \"All executables must have the same input sharding specs.\"\n        else:\n            shared_input_shard_specs = stage_input_shard_specs\n    return load_params_dis_array(path,\n                                 list(executables.values())[0], params_aval,\n                                 config, dummy)\n\n\ndef init_multi_executable_cache_dis_array(executables,\n                                          config,\n                                          batch_size,\n                                          dummy=False):\n    \"\"\"Initialize cache to workers that will be used by all executables. Accordingly,\n    we need to make sure all executables are using the same cache.\n    \"\"\"\n    cache_info = None\n    for executable in executables.values():\n        _, batch_info = executable.get_input_placement_specs()\n        if cache_info is not None:\n            assert cache_info == batch_info[\"cache\"], \\\n                \"All executables must share the same cache\"\n        else:\n            cache_info = batch_info[\"cache\"]\n    return init_cache_dis_array(\n        list(executables.values())[0], config, batch_size, dummy)\n"
  },
  {
    "path": "examples/llm_serving/model/opt_model.py",
    "content": "\"\"\"OPT model implementation.\"\"\"\nimport dataclasses\nfrom dataclasses import dataclass\nfrom functools import partial\nimport itertools\nimport math\nimport os\nfrom typing import Callable, Optional, Tuple, Dict, Sequence\n\nimport alpa\nfrom alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,\n                              MeshHostWorker, create_remote_array_refs)\nfrom alpa.model.model_util import ModelOutput\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nimport flax.linen as nn\nimport jax\nimport flax\nfrom jax import lax\nimport jax.numpy as jnp\nfrom jax.tree_util import tree_flatten, tree_unflatten, tree_leaves\nfrom jax.interpreters import pxla\nimport jaxlib.xla_extension as jax_xla\nimport numpy as np\nimport ray\nfrom tqdm import tqdm\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.swish,\n    \"swish\": nn.swish,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\n@flax.struct.dataclass\nclass OPTModelOutput(ModelOutput):\n    last_hidden_state: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\n@flax.struct.dataclass\nclass OPTLMOutput(ModelOutput):\n    logits: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attentions: Optional[Tuple[jax_xla.DeviceArray]] = None\n    attention_cache: Optional[Tuple[Tuple[jax_xla.DeviceArray]]] = None\n\n\n@dataclass(frozen=True)\nclass OPTConfig:\n    # Inherited from OPT\n    num_hidden_layers: int = 12\n    max_seq_len: int = 2048\n    hidden_size: int = 768\n    n_head: int = 12\n    input_dim: int = 768\n    ffn_embed_dim: int = 3072\n    pad: int = 1\n    activation_fn: str = 'relu'\n    dtype: any = jnp.float16\n    use_stable_embedding: bool = False\n    
no_scale_embedding: bool = True\n    decoder_learned_pos: bool = True\n    decoder_normalize_before: bool = True\n    share_decoder_input_output_embed: bool = True\n    # Added\n    version: int = 1\n    vocab_size: int = 50272\n    layer_norm_eps: float = 0.00001\n    num_pp_stages: int = None\n    # parallelize\n    mark_boundary: bool = True\n\n\nclass OPTEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert not self.config.use_stable_embedding\n        self.embed_scale = 1.0 if self.config.no_scale_embedding else math.sqrt(\n            self.config.hidden_size)\n        self.word_embeddings = nn.Embed(\n            self.config.vocab_size,\n            self.config.input_dim,\n            dtype=self.dtype,\n        )\n        assert self.config.max_seq_len is not None\n        assert self.config.decoder_learned_pos\n        self.position_embeddings = nn.Embed(\n            self.config.max_seq_len + self.config.pad + 1,\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.project_in_dim = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        ) if self.config.input_dim != self.config.hidden_size else None\n\n    def __call__(self, input_ids, position_ids):\n        # Embed\n        inputs_embeds = self.embed_scale * self.word_embeddings(\n            input_ids.astype(\"i4\"))\n        if self.project_in_dim is not None:\n            inputs_embeds = self.project_in_dim(inputs_embeds)\n        position_embeds = self.position_embeddings(position_ids.astype(\"i4\"))\n\n        # Sum all embeddings\n        hidden_states = inputs_embeds + position_embeds\n        return hidden_states\n\n\nclass OPTSelfAttention(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the 
computation\n\n    def setup(self):\n        if self.config.hidden_size % self.config.n_head != 0:\n            raise ValueError(\n                f\"`hidden_size`: {self.config.hidden_size} has to be a \"\n                f\"multiple of `n_head`: {self.config.decoder_attention_heads}\"\n            )\n\n        self.qkv_combined = nn.Dense(\n            self.config.hidden_size * 3,\n            dtype=self.dtype,\n        )\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None,\n                 attention_mask=None):\n        head_dim = self.config.hidden_size // self.config.n_head\n\n        qkv_combined_states = self.qkv_combined(hidden_states)\n        qkv_combined_states = qkv_combined_states.reshape(\n            qkv_combined_states.shape[:2] + (-1, 3))\n        query_states, key_states, value_states = jnp.split(qkv_combined_states,\n                                                           3,\n                                                           axis=3)\n        # shape: [B, S, #head, head_dim]\n        query_states = query_states.reshape(hidden_states.shape[:2] + (\n            self.config.n_head, head_dim))\n        # shape: [B, S, #head, head_dim]\n        value_states = value_states.reshape(hidden_states.shape[:2] + (\n            self.config.n_head, head_dim))\n        # shape: [B, S, #head, head_dim]\n        key_states = key_states.reshape(hidden_states.shape[:2] +\n                                        (self.config.n_head,\n                                         head_dim))\n\n        batch_size = hidden_states.shape[0]\n        if attention_cache is None:\n            query_len, key_len = query_states.shape[1], key_states.shape[1]\n            assert query_len == key_len\n            # shape: [B, 1, S_max, S_max]\n            causal_mask = nn.make_causal_mask(\n                jnp.ones((batch_size, key_len)), dtype=\"bool\")\n            # shape: 
[B, 1, 1, S_max]\n            input_mask = attention_mask\n            # shape: [B, 1, S_max, S_max]\n            mask = nn.combine_masks(causal_mask, input_mask, dtype=\"bool\")\n        else:\n            cache_key, cache_value, cache_index = attention_cache\n            cache_index_ = cache_index[0]\n            update_indices = (0, cache_index_, 0, 0)\n            # shape: [B, S_max, #head, head_dim]\n            key_states = lax.dynamic_update_slice(cache_key, key_states, update_indices)\n            # shape: [B, S_max, #head, head_dim]\n            value_states = lax.dynamic_update_slice(cache_value, value_states, update_indices)\n            query_len, key_len = query_states.shape[1], key_states.shape[1]\n\n            if attention_mask is not None:\n                # Handle a special kind of internal padding added by alpa.\n                # Note that this kind of internal padding is different from\n                # the padding added by the tokenizer. This internal padding\n                # should not update cache and step_ct\n                # shape: [B, 1, 1, S_max]\n                is_internal_padding = (attention_mask == 2)\n                num_internal_pad = jnp.sum(is_internal_padding, axis=3).reshape(-1)\n                attention_mask = (attention_mask == 1)\n            else:\n                num_internal_pad = 0\n            attention_cache = key_states, value_states, cache_index + query_len - num_internal_pad\n\n            # shape: [B, 1, S_max, S_max]\n            causal_mask = nn.make_causal_mask(\n                jnp.ones((batch_size, key_len)), dtype=\"bool\")\n            # shape: [B, 1, S, S_max]\n            causal_mask = lax.dynamic_slice(causal_mask,\n                (0, 0, cache_index_, 0), (batch_size, 1, query_len, key_len))\n            # shape: [B, 1, 1, S_max]\n            input_mask = attention_mask\n            # shape: [B, 1, S, S_max]\n            mask = nn.combine_masks(causal_mask, input_mask, dtype=\"bool\")\n\n        
attn_weights = nn.attention.dot_product_attention_weights(\n            query_states,\n            key_states,\n            mask=mask,\n            dtype=self.dtype,\n            precision=None,\n        )\n\n        attn_output = jnp.einsum(\"...hqk,...khd->...qhd\", attn_weights,\n                                 value_states)\n        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))\n\n        outputs = (attn_output, attention_cache,\n                   attn_weights) if output_attentions else (attn_output,\n                                                            attention_cache)\n        return outputs\n\n\nclass OPTAttention(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        self.self = OPTSelfAttention(self.config, dtype=self.dtype)\n        self.dense = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None,\n                 attention_mask=None):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        attn_outputs = self.self(hidden_states,\n                                 output_attentions=output_attentions,\n                                 attention_cache=attention_cache,\n                                 attention_mask=attention_mask)\n        attn_output = attn_outputs[0]\n        attention_cache = attn_outputs[1]\n        hidden_states = self.dense(attn_output)\n        hidden_states = hidden_states + residual\n        outputs = (hidden_states, attention_cache)\n\n        if output_attentions:\n            outputs += (attn_outputs[2],)\n\n        return outputs\n\n\nclass 
OPTFFN(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.fc1 = nn.Dense(\n            self.config.ffn_embed_dim,\n            dtype=self.dtype,\n        )\n        self.activation = ACT2FN[self.config.activation_fn]\n        self.fc2 = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(self, hidden_states):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        hidden_states = self.activation(self.fc1(hidden_states))\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = hidden_states + residual\n        return hidden_states\n\n\nclass OPTTransformerLayer(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        assert not getattr(self.config, \"cross_self_attention\", False)\n        assert not getattr(self.config, \"scale_heads\", False)\n        assert not getattr(self.config, \"scale_attn\", False)\n        assert not getattr(self.config, \"scale_fc\", False)\n        self.attention = OPTAttention(self.config, dtype=self.dtype)\n        self.ffn = OPTFFN(self.config, dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None,\n                 attention_mask=None):\n\n        attention_outputs = self.attention(hidden_states,\n                                           output_attentions=output_attentions,\n                                           attention_cache=attention_cache,\n                                           attention_mask=attention_mask)\n        attention_output = 
attention_outputs[0]\n        attention_cache = attention_outputs[1]\n\n        hidden_states = self.ffn(attention_output)\n\n        outputs = (hidden_states, attention_cache)\n\n        if output_attentions:\n            outputs += (attention_outputs[2],)\n        return outputs\n\n\nclass OPTTransformerLayerCollection(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.layers = [\n            OPTTransformerLayer(self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(\n        self,\n        hidden_states,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        all_attentions = () if output_attentions else None\n        all_hidden_states = () if output_hidden_states else None\n        new_attention_cache = () if attention_cache is not None else None\n\n        if self.config.num_pp_stages is not None:\n            assert self.config.num_hidden_layers % self.config.num_pp_stages == 0\n            layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages\n\n        for i, layer in enumerate(self.layers):\n            if self.config.num_pp_stages is not None:\n                if i % layers_per_stage == 0 and i != 0:\n                    if self.config.mark_boundary:\n                        mark_pipeline_boundary()\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n            layer_attention_cache = None\n            if attention_cache is not None:\n                layer_attention_cache = attention_cache[i]\n            layer_outputs = layer(hidden_states,\n                                  output_attentions=output_attentions,\n                                  
attention_cache=layer_attention_cache,\n                                  attention_mask=attention_mask)\n            hidden_states = layer_outputs[0]\n            if attention_cache is not None:\n                new_attention_cache += (layer_outputs[1],)\n            if output_attentions:\n                all_attentions += (layer_outputs[2],)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return OPTModelOutput(last_hidden_state=hidden_states,\n                              hidden_states=all_hidden_states,\n                              attentions=all_attentions,\n                              attention_cache=new_attention_cache)\n\n\nclass OPTTransformerModule(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        self.embeddings = OPTEmbeddings(self.config, dtype=self.dtype)\n        self.encoder = OPTTransformerLayerCollection(self.config,\n                                                     dtype=self.dtype)\n        if self.config.version > 2:\n            self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                           dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        hidden_states = self.embeddings(input_ids, position_ids)\n        outputs = self.encoder(\n            hidden_states,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            
attention_cache=attention_cache,\n            attention_mask=attention_mask\n        )\n        hidden_states = outputs[0]\n        if self.config.version > 2:\n            hidden_states = self.layer_norm(hidden_states)\n\n        if not return_dict:\n            # if pooled is None, don't return it\n            return (hidden_states,) + outputs[1:]\n\n        return OPTModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            attention_cache=outputs.attention_cache,\n        )\n\n\nclass OPTForLMModule(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16\n    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros\n\n    def setup(self):\n        self.transformers = OPTTransformerModule(config=self.config,\n                                                 dtype=self.dtype)\n\n        self.project_out_dim = nn.Dense(\n            self.config.input_dim,\n            dtype=self.dtype,\n        ) if self.config.input_dim != self.config.hidden_size else None\n\n        if self.config.share_decoder_input_output_embed:\n            self.decoder = None\n        else:\n            self.decoder = nn.Dense(self.config.vocab_size,\n                                    dtype=self.dtype,\n                                    use_bias=False)\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n        attention_mask=None\n    ):\n        # Model\n        outputs = self.transformers(\n            input_ids,\n            position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            attention_cache=attention_cache,\n            attention_mask=attention_mask\n        
)\n\n        hidden_states = outputs[0]\n\n        if self.project_out_dim is not None:\n            hidden_states = self.project_out_dim(hidden_states)\n\n        if self.config.share_decoder_input_output_embed:\n            if self.dtype == jnp.float16:\n                shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16\n            else:\n                shared_embedding = self.transformers.variables[\"params\"][\n                    \"embeddings\"][\"word_embeddings\"][\"embedding\"]\n            assert self.decoder is None\n            logits = hidden_states @ shared_embedding.T\n        else:\n            assert self.decoder is not None\n            logits = self.decoder(hidden_states)\n\n        # Compute the prediction scores\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return OPTLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states,\n            attentions=outputs.attentions,\n            attention_cache=outputs.attention_cache,\n        )\n\n\ndef get_config(name, **kwargs):\n    if name == \"opt-125m\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=12, n_head=12,\n            hidden_size=768, input_dim=768, ffn_embed_dim=768 * 4,\n            version=3,\n        )\n    elif name == \"opt-350m\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=24, n_head=16,\n            hidden_size=1024, input_dim=1024, ffn_embed_dim=1024 * 4,\n            version=2,\n        )\n        raise NotImplementedError(\"Not implemented because this model \"\n                                  \"has a different architecture\")\n    elif name == \"opt-1.3b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=24, n_head=32,\n            hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4,\n            version=3,\n        )\n    elif name == \"opt-2.7b\":\n        config = OPTConfig(\n      
      max_seq_len=2048, num_hidden_layers=32, n_head=32,\n            hidden_size=2560, input_dim=2560, ffn_embed_dim=2560 * 4,\n            version=3,\n        )\n    elif name == \"opt-6.7b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=32, n_head=32,\n            hidden_size=4096, input_dim=4096, ffn_embed_dim=4096 * 4,\n            version=3,\n        )\n    elif name == \"opt-30b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=48, n_head=56,\n            hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4,\n            version=3,\n        )\n    elif name == \"opt-66b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=64, n_head=72,\n            hidden_size=9216, input_dim=9216, ffn_embed_dim=9216 * 4,\n            version=3,\n        )\n    elif name == \"opt-175b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=96, n_head=96,\n            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,\n            version=3,\n        )\n    elif name == \"opt-iml-1.3b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=24, n_head=32,\n            hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 4,\n            version=3,\n        )\n    elif name == \"opt-iml-30b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=48, n_head=56,\n            hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4,\n            version=3,\n        )\n    elif name == \"opt-iml-175b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=96, n_head=96,\n            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,\n            version=3,\n        )\n    elif name == \"opt-iml-max-1.3b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=24, n_head=32,\n            hidden_size=2048, input_dim=2048, ffn_embed_dim=2048 * 
4,\n            version=3,\n        )\n    elif name == \"opt-iml-max-30b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=48, n_head=56,\n            hidden_size=7168, input_dim=7168, ffn_embed_dim=7168 * 4,\n            version=3,\n        )\n    elif name == \"opt-iml-max-175b\":\n        config = OPTConfig(\n            max_seq_len=2048, num_hidden_layers=96, n_head=96,\n            hidden_size=12288, input_dim=12288, ffn_embed_dim=12288 * 4,\n            version=3,\n        )\n    else:\n        raise ValueError(f\"Invalid model name: {name}\")\n\n    return dataclasses.replace(config, **kwargs)\n\n\ndef init_model_aval(config):\n    \"\"\"Initialize model with parameters with abstract values (shape-only arrays).\"\"\"\n    model = OPTForLMModule(config, dtype=config.dtype)\n    rngkey = jax.core.ShapedArray((2,), jnp.uint32)\n    input_ids = jax.core.ShapedArray((1, 128), jnp.int32)\n    position_ids = jax.core.ShapedArray((1, 128), jnp.int32)\n    params = jax.eval_shape(model.init, rngkey, input_ids, position_ids)\n    params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype),\n                          params)\n    return model, params\n\n\ndef init_cache_aval(config, batch_size):\n    \"\"\"Initialize cache with abstract values (shape-only arrays).\"\"\"\n    dtype = config.dtype\n    head_dim = config.hidden_size // config.n_head\n\n    all_cache = []\n    for _ in range(config.num_hidden_layers):\n        layer_cache = (\n            jax.core.ShapedArray((batch_size, config.max_seq_len,\n                                  config.n_head, head_dim),\n                                 dtype),\n            jax.core.ShapedArray((batch_size, config.max_seq_len,\n                                  config.n_head, head_dim),\n                                 dtype),\n            jax.core.ShapedArray((batch_size,), jnp.int32),\n        )\n        all_cache.append(layer_cache)\n    return tuple(all_cache)\n\n\ndef 
init_mask_aval(config, batch_size):\n    \"\"\"Initialize attention mask with abstract values (shape-only arrays).\"\"\"\n    mask = jax.core.ShapedArray((batch_size, 1, 1, config.max_seq_len), dtype=np.int8)\n    return mask\n\n\ndef init_cache_np(config, batch_size):\n    \"\"\"Init cache with numpy arrays.\"\"\"\n    np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n    head_dim = config.hidden_size // config.n_head\n\n    all_cache = []\n    for i in range(config.num_hidden_layers):\n        layer_cache = (\n            np.zeros((batch_size, config.max_seq_len,\n                      config.n_head, head_dim),\n                     dtype=np_dtype),\n            np.zeros((batch_size, config.max_seq_len,\n                      config.n_head, head_dim),\n                     dtype=np_dtype),\n            np.zeros((batch_size,), np.int32),\n        )\n        all_cache.append(layer_cache)\n    return tuple(all_cache)\n\n\ndef build_position_ids(input_ids, padding_idx):\n    mask = (input_ids != padding_idx).astype(np.int32)\n    position_ids = np.cumsum(mask, axis=1).astype(np.int32) * mask + padding_idx\n    return position_ids\n\n\ndef inference_step_no_cache(params, batch, apply_func):\n    logits = apply_func(params, batch[\"input_ids\"], batch[\"position_ids\"])[0]\n    return logits\n\n\ndef load_params_np(params, path, config, dummy=False):\n    \"\"\"Load parameters with numpy arrays.\"\"\"\n    if dummy:\n        np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n        return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params)\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        param_dict = params\n        param_keys = param_key.split('.')\n        for i, key in enumerate(param_keys):\n            if i == len(param_keys) - 1:\n                if dummy:\n                    param_dict[key] = 
jax.core.ShapedArray(\n                        param_dict[key].shape, param_dict[key].dtype)\n                else:\n                    if not is_position_embedding:\n                        assert param_dict[key].shape == loaded_array.shape, (\n                                f\"{param_dict[key].shape} vs. {loaded_array.shape}\")\n                    else:\n                        shape = param_dict[key].shape\n                        if shape != loaded_array.shape:\n                            assert shape[1] == loaded_array.shape[1]\n                            loaded_array = loaded_array[:shape[0], :]\n                    param_dict[key] = loaded_array\n            else:\n                param_dict = param_dict[key]\n\n    params = params.unfreeze()\n    load_param(\"params.transformers.embeddings.word_embeddings.embedding\",\n               load_array(\"decoder.embed_tokens.weight\"))\n    load_param(\"params.transformers.embeddings.position_embeddings.embedding\",\n               load_array(\"decoder.embed_positions.weight\"),\n               is_position_embedding=True)\n    if config.version > 2:\n        load_param(\"params.transformers.layer_norm.scale\",\n                   load_array(\"decoder.layer_norm.weight\"))\n        load_param(\"params.transformers.layer_norm.bias\",\n                   load_array(\"decoder.layer_norm.bias\"))\n    for i in tqdm(range(config.num_hidden_layers)):\n        param_prefix = f\"params.transformers.encoder.{i}.\"\n        load_prefix = f\"decoder.layers.{i}.\"\n        # Attention weights\n        wq = load_array(load_prefix + \"self_attn.q_proj.weight\")\n        wk = load_array(load_prefix + \"self_attn.k_proj.weight\")\n        wv = load_array(load_prefix + \"self_attn.v_proj.weight\")\n        dim = wq.shape[-1]\n        w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape(\n            (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1))\n        load_param(param_prefix + \"attention.self.qkv_combined.kernel\", 
w_qkv)\n        bq = load_array(load_prefix + \"self_attn.q_proj.bias\")\n        bk = load_array(load_prefix + \"self_attn.k_proj.bias\")\n        bv = load_array(load_prefix + \"self_attn.v_proj.bias\")\n        b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape(\n            (3, dim)).transpose([1, 0]).reshape((-1,))\n        load_param(param_prefix + \"attention.self.qkv_combined.bias\", b_qkv)\n        load_param(\n            param_prefix + \"attention.dense.kernel\",\n            np.transpose(load_array(load_prefix + \"self_attn.out_proj.weight\")))\n        load_param(param_prefix + \"attention.dense.bias\",\n                   load_array(load_prefix + \"self_attn.out_proj.bias\"))\n        load_param(param_prefix + \"attention.layer_norm.scale\",\n                   load_array(load_prefix + \"self_attn_layer_norm.weight\"))\n        load_param(param_prefix + \"attention.layer_norm.bias\",\n                   load_array(load_prefix + \"self_attn_layer_norm.bias\"))\n        # FFN weights\n        load_param(param_prefix + \"ffn.fc1.bias\",\n                   load_array(load_prefix + \"fc1.bias\"))\n        load_param(param_prefix + \"ffn.fc1.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc1.weight\")))\n        load_param(param_prefix + \"ffn.fc2.bias\",\n                   load_array(load_prefix + \"fc2.bias\"))\n        load_param(param_prefix + \"ffn.fc2.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc2.weight\")))\n        load_param(param_prefix + \"ffn.layer_norm.scale\",\n                   load_array(load_prefix + \"final_layer_norm.weight\"))\n        load_param(param_prefix + \"ffn.layer_norm.bias\",\n                   load_array(load_prefix + \"final_layer_norm.bias\"))\n\n    return flax.core.freeze(params)\n\n\ndef get_jax_executable(config: OPTConfig,\n                       encoder_chunk_sizes: Sequence[int],\n                       output_attentions: bool = False,\n                       
output_hidden_states: bool = False):\n    \"\"\"Get a single-gpu executable.\"\"\"\n    model, params = init_model_aval(config)\n\n    @jax.jit\n    def inference_step(params, batch):\n        output = model.apply(params,\n                             batch[\"input_ids\"],\n                             batch[\"position_ids\"],\n                             attention_cache=batch[\"cache\"],\n                             attention_mask=batch[\"mask\"],\n                             output_attentions=output_attentions,\n                             output_hidden_states=output_hidden_states)\n        return output\n\n    executables = {}\n    for length in encoder_chunk_sizes:\n        executables[length] = inference_step\n    return executables, params\n\n\ndef get_pipeshard_executable(config: OPTConfig,\n                             batch_size: int,\n                             encoder_chunk_sizes: Sequence[int],\n                             num_micro_batches: int = 1,\n                             output_attentions: bool = False,\n                             output_hidden_states: bool = False):\n    \"\"\"Get a parallel executable.\"\"\"\n    # Init model\n    model, params = init_model_aval(config)\n\n    # Parallelize\n    method = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            # Force operator model parallel\n            force_batch_dim_to_mesh_dim=None if batch_size == 1 else 0,\n            # Disabling all-to-all and all-gather generates better intra-op strategies.\n            allow_all_to_all=False,\n            allow_all_gather=False,\n        ))\n    #method = alpa.ShardParallel()\n\n    def inference_step_with_cache(params, batch):\n        output = model.apply(\n            params,\n            batch[\"input_ids\"],\n            batch[\"position_ids\"],\n            
attention_cache=batch[\"cache\"],\n            attention_mask=batch[\"mask\"],\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states)\n        return output\n\n    alpa.global_config.always_donate_micro_batch_vars = False\n\n    cache = init_cache_aval(config, batch_size)\n    mask = init_mask_aval(config, batch_size)\n\n    executables = {}\n\n    # Compile an executable with sequence length 1\n    executable = alpa.parallelize(\n        inference_step_with_cache, batch_argnums=(1,),\n        method=method).get_executable(\n            params, {\n                \"input_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"position_ids\":\n                    jax.core.ShapedArray((batch_size, 1), jnp.int32),\n                \"cache\":\n                    cache,\n                \"mask\":\n                    mask,\n            })\n    executable.dump_debug_info(\"tmp_executable_1\")\n    executables[1] = executable\n\n    # Create another parallel method with assigned input sharding specs\n    method_with_input_sharding = alpa.PipeshardParallel(\n        num_micro_batches=num_micro_batches,\n        pipeline_schedule=\"inference\",\n        layer_option=\"manual\",\n        default_auto_sharding_option=alpa.AutoShardingOption(\n            enable_auto_sharding=False,\n        ),\n        stage_input_shardings=executable.stage_input_shard_specs)\n\n    # Compile other executables\n    for seq_len in encoder_chunk_sizes:\n        executable = alpa.parallelize(\n            inference_step_with_cache,\n            batch_argnums=(1,),\n            method=method_with_input_sharding).get_executable(\n                params, {\n                    \"input_ids\":\n                        jax.core.ShapedArray(\n                            (batch_size, seq_len), jnp.int32),\n                    \"position_ids\":\n                        jax.core.ShapedArray(\n                      
      (batch_size, seq_len), jnp.int32),\n                    \"cache\":\n                        cache,\n                    \"mask\":\n                        mask,\n                })\n        executable.dump_debug_info(\"tmp_executable_%d\" % seq_len)\n        executables[seq_len] = executable\n    return executables, params\n\n    executable.dump_debug_info(\"tmp\")\n    return {seq_len: executable}, params\n\n\ndef load_opt_params_worker_func(self, path, prefix_to_idx, config, shapes,\n                                uuids, indices, mesh_ids):\n    \"\"\"The worker function to load OPT parameters.\"\"\"\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array, is_position_embedding=False):\n        i = prefix_to_idx[param_key]\n\n        for j in range(len(mesh_ids[i])):\n            if self.mesh_id != mesh_ids[i][j]:\n                continue\n\n            if not is_position_embedding:\n                assert shapes[i][j] == loaded_array.shape, (\n                    f\"{shapes[i][j]} vs. 
{loaded_array.shape}\")\n            else:\n                if shapes[i][j] != loaded_array.shape:\n                    assert shapes[i][j][1] == loaded_array.shape[1]\n                    loaded_array = loaded_array[:shapes[i][j][0], :]\n            uuid = uuids[i][j]\n            datas = []\n            for k in range(len(self.local_devices)):\n                idx = self.host_id * len(self.local_devices) + k\n                datas.append(loaded_array[indices[i][j][idx]])\n            self.put_buffers(uuid, datas)\n\n    load_param(\"params.transformers.embeddings.word_embeddings.embedding\",\n               load_array(\"decoder.embed_tokens.weight\"))\n    load_param(\"params.transformers.embeddings.position_embeddings.embedding\",\n               load_array(\"decoder.embed_positions.weight\"),\n               is_position_embedding=True)\n\n    if config.version > 2:\n        load_param(\"params.transformers.layer_norm.scale\",\n                   load_array(\"decoder.layer_norm.weight\"))\n        load_param(\"params.transformers.layer_norm.bias\",\n                   load_array(\"decoder.layer_norm.bias\"))\n\n    layers_per_stage = config.num_hidden_layers // config.num_pp_stages\n\n    for i in range(config.num_hidden_layers):\n        stage_id = i // layers_per_stage\n        if stage_id != self.mesh_id:\n            continue\n\n        param_prefix = f\"params.transformers.encoder.{i}.\"\n        load_prefix = f\"decoder.layers.{i}.\"\n        # Attention weights\n        wq = load_array(load_prefix + \"self_attn.q_proj.weight\")\n        wk = load_array(load_prefix + \"self_attn.k_proj.weight\")\n        wv = load_array(load_prefix + \"self_attn.v_proj.weight\")\n        dim = wq.shape[-1]\n        w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape(\n            (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1))\n        load_param(param_prefix + \"attention.self.qkv_combined.kernel\", w_qkv)\n        bq = load_array(load_prefix + 
\"self_attn.q_proj.bias\")\n        bk = load_array(load_prefix + \"self_attn.k_proj.bias\")\n        bv = load_array(load_prefix + \"self_attn.v_proj.bias\")\n        b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape(\n            (3, dim)).transpose([1, 0]).reshape((-1,))\n        load_param(param_prefix + \"attention.self.qkv_combined.bias\", b_qkv)\n        load_param(\n            param_prefix + \"attention.dense.kernel\",\n            np.transpose(load_array(load_prefix + \"self_attn.out_proj.weight\")))\n        load_param(param_prefix + \"attention.dense.bias\",\n                   load_array(load_prefix + \"self_attn.out_proj.bias\"))\n        load_param(param_prefix + \"attention.layer_norm.scale\",\n                   load_array(load_prefix + \"self_attn_layer_norm.weight\"))\n        load_param(param_prefix + \"attention.layer_norm.bias\",\n                   load_array(load_prefix + \"self_attn_layer_norm.bias\"))\n        # FFN weights\n        load_param(param_prefix + \"ffn.fc1.bias\",\n                   load_array(load_prefix + \"fc1.bias\"))\n        load_param(param_prefix + \"ffn.fc1.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc1.weight\")))\n        load_param(param_prefix + \"ffn.fc2.bias\",\n                   load_array(load_prefix + \"fc2.bias\"))\n        load_param(param_prefix + \"ffn.fc2.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc2.weight\")))\n        load_param(param_prefix + \"ffn.layer_norm.scale\",\n                   load_array(load_prefix + \"final_layer_norm.weight\"))\n        load_param(param_prefix + \"ffn.layer_norm.bias\",\n                   load_array(load_prefix + \"final_layer_norm.bias\"))\n\n\nsetattr(MeshHostWorker, \"load_opt_params_worker_func\",\n        load_opt_params_worker_func)\n\n\ndef load_params_dis_array(path, executable, params_aval, config, dummy=False):\n    \"\"\"Load parameters with distributed arrays.\"\"\"\n    if dummy:\n        
alpa.global_config.use_dummy_value_for_benchmarking = True\n        params_info, _ = executable.get_input_placement_specs()\n        flat_args, in_tree = tree_flatten(params_aval)\n        flat_info = tree_leaves(params_info)\n        if hasattr(executable, \"mesh_group\"):\n            ret = executable.mesh_group.shard_args_to_arrays(\n                flat_info, flat_args)\n        else:\n            ret = executable.physical_mesh.shard_args_to_arrays_ps(\n                flat_info, flat_args)\n        alpa.global_config.use_dummy_value_for_benchmarking = False\n        return ret\n\n    params_info, _ = executable.get_input_placement_specs()\n\n    prefix_to_flat_idx = {}\n    ct = itertools.count()\n\n    def dfs(dict_tree, result_dict, cur_prefix):\n        if isinstance(dict_tree, (dict, flax.core.FrozenDict)):\n            for key in dict_tree.keys():\n                dfs(dict_tree[key], result_dict,\n                    cur_prefix + (\".\" if cur_prefix else \"\") + key)\n        else:\n            result_dict[cur_prefix] = next(ct)\n\n    dfs(params_aval, prefix_to_flat_idx, \"\")\n\n    flat_infos, in_tree = tree_flatten(params_info)\n\n    flat_shapes = []\n    flat_uuids = []\n    flat_indices = []\n    flat_mesh_ids = []\n    flat_arrays = []\n\n    mesh_group = executable.mesh_group\n\n    for info in flat_infos:\n        aval = info.aval\n        if len(info.mesh_ids) == 1:\n            mesh, spec = mesh_group[info.mesh_ids[0]], info.sharding_specs[0]\n            indices = pxla.spec_to_indices(aval.shape, spec)\n            ary_refs, ary_uuid = create_remote_array_refs(mesh)\n            flat_shapes.append([aval.shape])\n            flat_uuids.append([ary_uuid[0]])\n            flat_indices.append([indices])\n            flat_mesh_ids.append([mesh.mesh_id])\n            flat_arrays.append(\n                DistributedArray(mesh, aval, spec, ary_refs[0], indices))\n        else:\n            tmp_shapes = []\n            tmp_uuids = []\n            
tmp_indices = []\n            tmp_mesh_ids = []\n            tmp_arrays = []\n            tmp_meshes = []\n            for mesh_id, spec in zip(info.mesh_ids, info.sharding_specs):\n                mesh = mesh_group[mesh_id]\n                indices = pxla.spec_to_indices(aval.shape, spec)\n                ary_refs, ary_uuid = create_remote_array_refs(mesh)\n                array = DistributedArray(mesh, aval, spec, ary_refs[0], indices)\n                tmp_shapes.append(aval.shape)\n                tmp_uuids.append(ary_uuid[0])\n                tmp_indices.append(indices)\n                tmp_mesh_ids.append(mesh.mesh_id)\n                tmp_meshes.append(mesh)\n                tmp_arrays.append(array)\n            flat_shapes.append(tuple(tmp_shapes))\n            flat_uuids.append(tuple(tmp_uuids))\n            flat_indices.append(tuple(tmp_indices))\n            flat_mesh_ids.append(tuple(tmp_mesh_ids))\n            flat_arrays.append(\n                ReplicatedDistributedArray(tmp_meshes, tmp_arrays))\n\n    for m in executable.mesh_group.meshes:\n        for w in m.workers:\n            w.load_opt_params_worker_func.remote(path, prefix_to_flat_idx,\n                                                 config, flat_shapes,\n                                                 flat_uuids, flat_indices,\n                                                 flat_mesh_ids)\n\n    return flat_arrays\n\n\ndef init_cache_dis_array(executable, config, batch_size, dummy=False):\n    \"\"\"Initialize cache with distributed arrays.\"\"\"\n    cache = init_cache_np(config, batch_size)\n    alpa.global_config.use_dummy_value_for_benchmarking = dummy\n    _, batch_info = executable.get_input_placement_specs()\n    flat_args, in_tree = tree_flatten(cache)\n    flat_info = tree_leaves(batch_info[\"cache\"])\n    if hasattr(executable, \"mesh_group\"):\n        ret = executable.mesh_group.shard_args_to_arrays(flat_info, flat_args)\n    else:\n        ret = 
executable.physical_mesh.shard_args_to_arrays_ps(\n            flat_info, flat_args)\n    alpa.global_config.use_dummy_value_for_benchmarking = False\n    return ret\n\n\ndef load_multi_executable_params_dis_array(path,\n                                           executables,\n                                           params_aval,\n                                           config,\n                                           dummy=False):\n    \"\"\"Load parameters to workers that will be used by all executables. Accordingly,\n    we need to make sure the parameter sharding specs are identical for all executables.\n    \"\"\"\n    shared_input_shard_specs = None\n    for executable in executables.values():\n        stage_input_shard_specs = executable.stage_input_shard_specs\n        if shared_input_shard_specs is not None:\n            assert shared_input_shard_specs == stage_input_shard_specs, \\\n                \"All executables must have the same input sharding specs.\"\n        else:\n            shared_input_shard_specs = stage_input_shard_specs\n    return load_params_dis_array(path,\n                                 list(executables.values())[0], params_aval,\n                                 config, dummy)\n\n\ndef init_multi_executable_cache_dis_array(executables,\n                                          config,\n                                          batch_size,\n                                          dummy=False):\n    \"\"\"Initialize cache to workers that will be used by all executables. 
Accordingly,\n    we need to make sure all executables are using the same cache.\n    \"\"\"\n    cache_info = None\n    for executable in executables.values():\n        _, batch_info = executable.get_input_placement_specs()\n        if cache_info is not None:\n            assert cache_info == batch_info[\"cache\"], \\\n                \"All executables must share the same cache\"\n        else:\n            cache_info = batch_info[\"cache\"]\n    return init_cache_dis_array(\n        list(executables.values())[0], config, batch_size, dummy)\n"
  },
  {
    "path": "examples/llm_serving/model/opt_model_1d.py",
    "content": "import heapq\nimport math\nimport queue\nimport time\nimport logging\n\nimport torch\nfrom dataclasses import dataclass\nfrom typing import Callable, Optional, Tuple, List, Union\n\nimport flax\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport jaxlib.xla_extension as jax_xla\nimport numpy as np\nimport os\nfrom enum import Enum\nfrom functools import partial\n\nfrom alpa.model.model_util import ModelOutput\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nfrom alpa.util import OrderedSet\nfrom alpa.timer import timers\nfrom examples.llm_serving.model.opt_utils import sync\n\n\ntry:\n    from ft_mha import fused_mmha, init_cache_manager, \\\n        prepare_inputs, free_cache, can_allocate\n    from ft_mha import Prompt as PromptInternal, DecodingToken as DecodingTokenInternal\nexcept ImportError:\n    raise RuntimeError(\"Please install ft_mha to use 1D OPT model.\")\n\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\nACT2FN = {\n    \"gelu\": partial(nn.gelu, approximate=False),\n    \"relu\": nn.relu,\n    \"silu\": nn.swish,\n    \"swish\": nn.swish,\n    \"gelu_new\": partial(nn.gelu, approximate=True),\n}\n\n\n@flax.struct.dataclass\nclass OPTModelOutput(ModelOutput):\n    last_hidden_state: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@flax.struct.dataclass\nclass OPTLMOutput(ModelOutput):\n    logits: jax_xla.DeviceArray\n    hidden_states: Optional[Tuple[jax_xla.DeviceArray]] = None\n\n\n@dataclass(frozen=True)\nclass OPTConfig:\n    # Inherited from OPT\n    num_hidden_layers: int = 12\n    max_seq_len: int = 2048\n    hidden_size: int = 768\n    n_head: int = 12\n    input_dim: int = 768\n    ffn_embed_dim: int = 3072\n    pad: int = 1\n    activation_fn: str = 'relu'\n    dtype: any = jnp.float16\n    use_stable_embedding: bool = False\n    no_scale_embedding: bool = True\n    decoder_learned_pos: bool = True\n    
decoder_normalize_before: bool = True\n    share_decoder_input_output_embed: bool = True\n    # Added\n    version: int = 1\n    vocab_size: int = 50272\n    layer_norm_eps: float = 0.00001\n    num_pp_stages: int = None\n    # parallelize\n    mark_boundary: bool = True\n\n\nclass OPTEmbeddings(nn.Module):\n    \"\"\"Construct the embeddings from word, position and token_type embeddings.\"\"\"\n\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert not self.config.use_stable_embedding\n        self.embed_scale = 1.0 if self.config.no_scale_embedding else math.sqrt(\n            self.config.hidden_size)\n        self.word_embeddings = nn.Embed(\n            self.config.vocab_size,\n            self.config.input_dim,\n            dtype=self.dtype,\n        )\n        assert self.config.max_seq_len is not None\n        assert self.config.decoder_learned_pos\n        self.position_embeddings = nn.Embed(\n            self.config.max_seq_len + self.config.pad + 1,\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.project_in_dim = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        ) if self.config.input_dim != self.config.hidden_size else None\n\n    def __call__(self, input_ids, position_ids):\n        # Embed\n        inputs_embeds = self.embed_scale * self.word_embeddings(\n            input_ids.astype(\"i4\"))\n        if self.project_in_dim is not None:\n            inputs_embeds = self.project_in_dim(inputs_embeds)\n        position_embeds = self.position_embeddings(position_ids.astype(\"i4\"))\n\n        # Sum all embeddings\n        hidden_states = inputs_embeds + position_embeds\n        return hidden_states\n\n\nclass OPTSelfAttention(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        if self.config.hidden_size % 
self.config.n_head != 0:\n            raise ValueError(\n                f\"`hidden_size`: {self.config.hidden_size} has to be a \"\n                f\"multiple of `n_head`: {self.config.n_head}\"\n            )\n\n        self.qkv_combined = nn.Dense(\n            self.config.hidden_size * 3,\n            dtype=self.dtype,\n            use_bias=False,\n        )\n\n        # The fused_mmha kernel fuses the bias add, so we do not load the bias in Dense and\n        # instead feed it into the kernel.\n        head_dim = self.config.hidden_size // self.config.n_head\n        self.qkv_combined_bias = self.param(\n            'qkv_combined_bias', flax.linen.initializers.zeros,\n            (3, self.config.n_head, head_dim), self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None):\n        head_dim = self.config.hidden_size // self.config.n_head\n        assert attention_cache is not None, \"Attention cache must be provided for now\"\n\n        # Shape: [1D seq, heads, head_dim, 3]\n        qkv_combined_states = self.qkv_combined(hidden_states)\n        qkv_combined_states = qkv_combined_states.reshape(\n            qkv_combined_states.shape[:1] +\n            (self.config.n_head, head_dim, 3))\n\n        # Shape: [1D seq, 3, heads, head_dim]\n        qkv_combined_states = qkv_combined_states.transpose((0, 3, 1, 2))\n\n        # Shape of cache_key and cache_value: [batch * max_length, heads, head_dim]\n        # Shape of cache_index: [batch * max_length]\n        cache_key, cache_value = attention_cache\n\n        attn_output = fused_mmha(qkv_combined_states, self.qkv_combined_bias,\n                                 cache_key, cache_value)\n\n        attn_output = attn_output.reshape(attn_output.shape[:1] + (-1,))\n\n        if output_attentions:\n            print(\"Do not support output_attentions\")\n        return attn_output\n\n\nclass 
OPTAttention(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        self.self = OPTSelfAttention(self.config, dtype=self.dtype)\n        self.dense = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        attn_outputs = self.self(hidden_states,\n                                 output_attentions=output_attentions,\n                                 attention_cache=attention_cache)\n        hidden_states = self.dense(attn_outputs)\n        hidden_states = hidden_states + residual\n\n        return hidden_states\n\n\nclass OPTFFN(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.fc1 = nn.Dense(\n            self.config.ffn_embed_dim,\n            dtype=self.dtype,\n        )\n        self.activation = ACT2FN[self.config.activation_fn]\n        self.fc2 = nn.Dense(\n            self.config.hidden_size,\n            dtype=self.dtype,\n        )\n        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                       dtype=self.dtype)\n\n    def __call__(self, hidden_states):\n        residual = hidden_states\n        hidden_states = self.layer_norm(hidden_states)\n        hidden_states = self.activation(self.fc1(hidden_states))\n        hidden_states = self.fc2(hidden_states)\n        hidden_states = hidden_states + residual\n        return hidden_states\n\n\nclass OPTTransformerLayer(nn.Module):\n    config: OPTConfig\n    
dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        assert not getattr(self.config, \"cross_self_attention\", False)\n        assert not getattr(self.config, \"scale_heads\", False)\n        assert not getattr(self.config, \"scale_attn\", False)\n        assert not getattr(self.config, \"scale_fc\", False)\n        self.attention = OPTAttention(self.config, dtype=self.dtype)\n        self.ffn = OPTFFN(self.config, dtype=self.dtype)\n\n    def __call__(self,\n                 hidden_states,\n                 output_attentions: bool = False,\n                 attention_cache=None):\n\n        attention_outputs = self.attention(hidden_states,\n                                           output_attentions=output_attentions,\n                                           attention_cache=attention_cache)\n\n        hidden_states = self.ffn(attention_outputs)\n        return hidden_states\n\n\nclass OPTTransformerLayerCollection(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        self.layers = [\n            OPTTransformerLayer(self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n\n    def __call__(\n        self,\n        hidden_states,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n    ):\n        all_hidden_states = () if output_hidden_states else None\n\n        if self.config.num_pp_stages is not None:\n            assert self.config.num_hidden_layers % self.config.num_pp_stages == 0\n            layers_per_stage = self.config.num_hidden_layers // self.config.num_pp_stages\n\n        for i, layer in enumerate(self.layers):\n            if self.config.num_pp_stages is not None:\n                if i % layers_per_stage == 0 and i 
!= 0:\n                    stage_id = i // layers_per_stage\n                    if self.config.mark_boundary:\n                        mark_pipeline_boundary()\n\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n            layer_attention_cache = None\n            if attention_cache is not None:\n                layer_attention_cache = attention_cache[i]\n            hidden_states = layer(hidden_states,\n                                  output_attentions=output_attentions,\n                                  attention_cache=layer_attention_cache)\n\n        if output_hidden_states:\n            all_hidden_states += (hidden_states,)\n\n        outputs = (hidden_states,)\n\n        if not return_dict:\n            return tuple(v for v in outputs if v is not None)\n\n        return OPTModelOutput(last_hidden_state=hidden_states,\n                              hidden_states=all_hidden_states)\n\n\nclass OPTTransformerModule(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16  # the dtype of the computation\n\n    def setup(self):\n        assert self.config.decoder_normalize_before\n        self.embeddings = OPTEmbeddings(self.config, dtype=self.dtype)\n        self.encoder = OPTTransformerLayerCollection(self.config,\n                                                     dtype=self.dtype)\n        if self.config.version > 2:\n            self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,\n                                           dtype=self.dtype)\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n    ):\n        hidden_states = self.embeddings(input_ids, position_ids)\n        outputs = self.encoder(\n            hidden_states,\n            output_attentions=output_attentions,\n            
output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            attention_cache=attention_cache,\n        )\n        hidden_states = outputs[0]\n        if self.config.version > 2:\n            hidden_states = self.layer_norm(hidden_states)\n\n        if not return_dict:\n            # if pooled is None, don't return it\n            return (hidden_states,) + outputs[1:]\n\n        return OPTModelOutput(\n            last_hidden_state=hidden_states,\n            hidden_states=outputs.hidden_states)\n\n\nclass OPTForLMModule(nn.Module):\n    config: OPTConfig\n    dtype: jnp.dtype = jnp.float16\n    bias_init: Callable[..., jnp.ndarray] = jax.nn.initializers.zeros\n\n    def setup(self):\n        self.transformers = OPTTransformerModule(config=self.config,\n                                                 dtype=self.dtype)\n\n        self.project_out_dim = nn.Dense(\n            self.config.input_dim,\n            dtype=self.dtype,\n        ) if self.config.input_dim != self.config.hidden_size else None\n\n        if self.config.share_decoder_input_output_embed:\n            self.decoder = None\n        else:\n            self.decoder = nn.Dense(self.config.vocab_size,\n                                    dtype=self.dtype,\n                                    use_bias=False)\n\n    def __call__(\n        self,\n        input_ids,\n        position_ids,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n        return_dict: bool = True,\n        attention_cache=None,\n    ):\n        # Model\n        outputs = self.transformers(\n            input_ids,\n            position_ids,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states,\n            return_dict=return_dict,\n            attention_cache=attention_cache,\n        )\n\n        hidden_states = outputs[0]\n\n        if self.project_out_dim is not None:\n            hidden_states = 
self.project_out_dim(hidden_states)\n\n        if self.config.share_decoder_input_output_embed:\n            if self.dtype == jnp.float16:\n                shared_embedding = self.transformers.embeddings.word_embeddings.embedding_fp16\n            else:\n                shared_embedding = self.transformers.variables[\"params\"][\n                    \"embeddings\"][\"word_embeddings\"][\"embedding\"]\n            assert self.decoder is None\n            logits = hidden_states @ shared_embedding.T\n        else:\n            assert self.decoder is not None\n            logits = self.decoder(hidden_states)\n\n        # Compute the prediction scores\n        if not return_dict:\n            return (logits,) + outputs[1:]\n\n        return OPTLMOutput(\n            logits=logits,\n            hidden_states=outputs.hidden_states)\n\n\ndef init_model_aval(config, total_input_len, total_cache_len):\n    \"\"\"In 1D: we specify total_input_len and total_cache_len in advance.\"\"\"\n    model = OPTForLMModule(config, dtype=config.dtype)\n    rngkey = jax.core.ShapedArray((2,), jnp.uint32)\n    input_ids = jax.core.ShapedArray((total_input_len,), jnp.int32)\n    position_ids = jax.core.ShapedArray((total_input_len,), jnp.int32)\n    cache = init_cache_aval(config, total_cache_len)\n\n    params = jax.eval_shape(model.init,\n                            rngkey,\n                            input_ids,\n                            position_ids,\n                            attention_cache=cache)\n    params = jax.tree_map(lambda x: jax.ShapeDtypeStruct(x.shape, config.dtype),\n                          params)\n    return model, params\n\n\ndef init_cache_aval(config, total_cache_len):\n    dtype = config.dtype\n    head_dim = config.hidden_size // config.n_head\n\n    all_cache = []\n    for i in range(config.num_hidden_layers):\n        layer_cache = (\n            jax.core.ShapedArray((total_cache_len * config.n_head * head_dim,),\n                                 dtype),\n   
         jax.core.ShapedArray((total_cache_len * config.n_head * head_dim,),\n                                 dtype),\n        )\n        all_cache.append(layer_cache)\n    return tuple(all_cache)\n\n\ndef init_cache_np(config, total_cache_len):\n    \"\"\"Init cache per sequence with numpy arrays.\"\"\"\n    np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n    head_dim = config.hidden_size // config.n_head\n\n    all_cache = []\n    for i in range(config.num_hidden_layers):\n        layer_cache = (\n            np.zeros((total_cache_len * config.n_head * head_dim),\n                     dtype=np_dtype),\n            np.zeros((total_cache_len * config.n_head * head_dim),\n                     dtype=np_dtype),\n        )\n        all_cache.append(layer_cache)\n    return tuple(all_cache)\n\n\ndef build_position_ids(input_ids, padding_idx):\n    mask = (input_ids != padding_idx).astype(np.int32)\n    position_ids = np.cumsum(mask).astype(np.int32) * mask + padding_idx\n    return position_ids\n\n\nclass PromptStatus(Enum):\n    PROMPT = 1\n    DECODING = 2\n    FINISHED = 3\n\n\nclass Prompt:\n    def __init__(self, input_ids, sentence_id, max_length=2048):\n        self.input_ids = input_ids\n        self.sentence_id = sentence_id\n        self.status = PromptStatus.PROMPT\n        # states to be filled during generation\n        self.generated_ids = []\n        self.last_generated_id = None\n\n        # In v3, we have to use an internal Prompt object.\n        self.p = PromptInternal(seq_id=sentence_id,\n                                max_len=max_length,\n                                token_ids=self.input_ids)\n        # latency information\n        self.start_time = None\n        self.finish_time = None\n\n    def finish(self, finish_token_id):\n        self.finish_time = time.time()\n        self.status = PromptStatus.FINISHED\n        self.generated_ids.append(finish_token_id)\n        self.last_generated_id = finish_token_id\n\n    def 
add_token(self, token_id):\n        if self.status == PromptStatus.PROMPT:\n            self.status = PromptStatus.DECODING\n        else:\n            assert self.last_generated_id is not None and self.status == PromptStatus.DECODING\n            self.generated_ids.append(self.last_generated_id)\n        self.last_generated_id = token_id\n        # rewrite the internal object to DecodingToken\n        self.p = DecodingTokenInternal(seq_id=self.sentence_id, token_id=token_id)\n\n    def start(self):\n        self.start_time = time.time()\n\n    @property\n    def prompt_length(self):\n        return len(self.input_ids)\n\n    @property\n    def generation_length(self):\n        return len(self.generated_ids)\n\n    @property\n    def num_prev_tokens(self):\n        if self.status == PromptStatus.PROMPT:\n            return 0\n        else:\n            return self.prompt_length + self.generation_length\n\n    @property\n    def latency(self):\n        if self.status != PromptStatus.FINISHED:\n            raise RuntimeError(\"Unfinished prompt.\")\n        return self.finish_time - self.start_time\n\n    def print(self):\n        print(self.input_ids + \":\" + self.generated_ids)\n\n\nclass IterationLevelInputPool:\n    \"\"\"This pool is for iteration-level scheduling.\"\"\"\n    def __init__(self,\n                 input_pool_config,\n                 model_config,\n                 max_length=None,\n                 max_new_tokens=None):\n        self.batch_size = input_pool_config.batch_size\n        self.cache_size = input_pool_config.cache_size\n        self.model_config = model_config\n        self.max_length = max_length\n        self.max_new_tokens = max_new_tokens\n\n        # Cache space is associated and owned with Pool.\n        self.cache = jax.tree_map(jnp.array, init_cache_np(model_config, self.cache_size))\n        init_cache_manager(cache_size=self.cache_size)\n\n        # input pool states\n        self.todo = queue.Queue()\n        self.wip = 
OrderedSet()\n        self.done = OrderedSet()\n\n        # current batch state\n        self._current_batch = None\n        self._sentence_id_counter = 1\n\n        # model config\n        self.pad = self.model_config.pad if \"pad\" in dir(self.model_config) else 1\n        self.eos = self.model_config.eos_token_id if \"eos_token_id\" in dir(self.model_config) else 2\n\n    def is_finished(self):\n        return self.todo.empty() and len(self.wip) == 0\n\n    def enter_prompts(self, input_sequences: List[List[int]]):\n        \"\"\"Enter a new batch of prompts into self.\"\"\"\n        sentence_ids = self.next_sentence_id(len(input_sequences))\n\n        def max_new_tokens(seq_len):\n            n = 2048\n            if self.max_length:\n                n = min(n, self.max_length - seq_len)\n            if self.max_new_tokens:\n                n = min(n, self.max_new_tokens)\n            return n\n\n        for i, seq in enumerate(input_sequences):\n            p = Prompt(seq, sentence_ids[i], max_length=max_new_tokens(len(seq)) + len(seq))\n            self.todo.put(p)\n\n    def next(self):\n        \"\"\"Get the inputs for the next iteration from the pool.\"\"\"\n        # figure out WIP prompts and put their next token in a list\n        decoding_input = list(self.wip)\n        # re-batch new prompts, concat them into a list\n\n        prompt_input = []\n        proposals = []\n        batch_availability = self.batch_size - len(decoding_input)\n        while not self.todo.empty():\n            proposals.append(self.todo.queue[0])\n            proposals_length = [p.prompt_length for p in proposals]\n            num_new_tokens = sum(proposals_length)\n            # now we check if we can put this prompt into batch\n            if batch_availability < num_new_tokens:\n                break\n            if not can_allocate([p.p.max_len for p in proposals]):\n                break\n            prompt_input.append(self.todo.get())\n        logger.debug(f\"In this 
iteration {len(prompt_input)} new prompts enter.\")\n\n        # make input: prompts must go first\n        input = sum([p.input_ids for p in prompt_input], []) + [p.last_generated_id for p in decoding_input]\n        input = np.array(input + [self.pad] * (self.batch_size - len(input)), dtype=np.int32)\n\n        # make input index\n        input_index = []\n        for p in prompt_input:\n            input_index.extend([p.sentence_id] * p.prompt_length)\n        for p in decoding_input:\n            input_index.append(p.sentence_id)\n        input_index = np.array(input_index + [0] * (self.batch_size - len(input_index)), dtype=np.int32)\n\n        # make position ids\n\n        position_ids = []\n        for p in prompt_input:\n            start_idx = 1 + self.pad + p.num_prev_tokens\n            position_ids.extend([i for i in range(start_idx, start_idx + p.prompt_length)])\n        for p in decoding_input:\n            start_idx = 1 + self.pad + p.num_prev_tokens\n            position_ids.extend([start_idx])\n        position_ids = np.array(position_ids +  [0] * (self.batch_size - len(position_ids)), dtype=np.int32)\n\n        self._current_batch = prompt_input + decoding_input\n        logit_positions = []\n        i = -1\n        for p in prompt_input:\n            i += p.prompt_length\n            logit_positions.append(i)\n        for _ in decoding_input:\n            i += 1\n            logit_positions.append(i)\n\n        # start prompts for recording time\n        for p in prompt_input:\n            p.start()\n\n        # Call prepare_inputs before every inference_step.\n        prepare_inputs([prompt.p for prompt in prompt_input], [prompt.p for prompt in decoding_input])\n        # return inputs\n        return input, input_index, position_ids, logit_positions\n\n    def update(self, generated_ids):\n        \"\"\"Update the pool after one iteration of inference.\"\"\"\n        if self._current_batch is None:\n            raise RuntimeError(\"There is no 
pending batch so update() is unnecessary.\")\n        for generated_id, p in zip(generated_ids, self._current_batch):\n            # check EOS, move finished sentences from wip to finished queue\n            if self.check_exit_condition(p, generated_id):\n                if p.status == PromptStatus.DECODING:\n                    assert p in self.wip\n                    self.wip.remove(p)\n                exit_reason = \"EOS\" if generated_id == self.eos else \"reaching max length\"\n                logger.debug(f\"Prompt {p.sentence_id} exits because of {exit_reason}. \")\n                p.finish(generated_id)\n                free_cache(p.sentence_id)\n                self.done.add(p)\n            elif p.status == PromptStatus.PROMPT:\n                # PROMPT -> DECODING\n                p.add_token(generated_id)\n                self.wip.add(p)\n            elif p.status == PromptStatus.DECODING:\n                # DECODING -> DECODING\n                p.add_token(generated_id)\n            else:\n                raise RuntimeError(f\"Prompt status: {p.status} should not appear here.\" )\n\n    def get_results(self):\n        \"\"\"Return results sorted by their sentence id.\"\"\"\n        sorted_results = sorted(self.done, key=lambda x: x.sentence_id, reverse=False)\n        return [p.input_ids + p.generated_ids for p in sorted_results]\n\n    def get_latency(self):\n        \"\"\"Return the latency of each prompt following their sequence id.\"\"\"\n        sorted_results = sorted(self.done, key=lambda x: x.sentence_id, reverse=False)\n        return [p.latency for p in sorted_results]\n\n    def next_sentence_id(self, number):\n        counter = self._sentence_id_counter\n        if number == 1:\n            ret = [counter]\n        else:\n            ret = list(range(counter, counter + number))\n        self._sentence_id_counter = (counter + number) % (1 << 60)\n        return ret\n\n    def check_exit_condition(self, prompt, generated_id):\n        
\"\"\"Check Exit condition: reaching EOS or reaching max length.\"\"\"\n        if generated_id == self.eos:\n            return True\n        if self.max_new_tokens:\n            if prompt.generation_length + 1 == self.max_new_tokens:\n                return True\n        if self.max_length:\n            if prompt.generation_length + 1 + prompt.prompt_length == self.max_length:\n                return True\n        return False\n\n\ndef unpad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1):\n    if isinstance(inputs, np.ndarray) or isinstance(inputs, torch.Tensor):\n        inputs = inputs.tolist()\n    unpadded_inputs = []\n    for seq in inputs:\n        if pad in seq:\n            unpadded_inputs.append(seq[:seq.index(pad)])\n        else:\n            unpadded_inputs.append(seq)\n    return unpadded_inputs\n\n\ndef pad(inputs: Union[np.ndarray, torch.Tensor, List[List[int]]], pad=1):\n    if isinstance(inputs, np.ndarray) or isinstance(inputs, torch.Tensor):\n        inputs = inputs.tolist()\n    padded_inputs = []\n    target_len = max(len(seq) for seq in inputs)\n    for seq in inputs:\n        if len(seq) < target_len:\n            padded_inputs.append(seq + [pad] * (target_len - len(seq)))\n        else:\n            padded_inputs.append(seq)\n    return padded_inputs\n\n\ndef load_params_np(params, path, config, dummy=False):\n    \"\"\"Load parameterswith numpy arrays.\"\"\"\n    np_dtype = np.float32 if config.dtype == jnp.float32 else np.float16\n    if dummy:\n        return jax.tree_map(lambda x: np.full(x.shape, 1e-9, np_dtype), params)\n\n    def load_array(key):\n        return np.load(os.path.join(path, key))\n\n    def load_param(param_key, loaded_array):\n        param_dict = params\n        param_keys = param_key.split('.')\n        for i, key in enumerate(param_keys):\n            if i == len(param_keys) - 1:\n                if dummy:\n                    param_dict[key] = jax.core.ShapedArray(\n                        
param_dict[key].shape, param_dict[key].dtype)\n                else:\n                    assert param_dict[key].shape == loaded_array.shape\n                    #assert param_dict[key].dtype == loaded_array.dtype\n                    param_dict[key] = loaded_array\n            else:\n                param_dict = param_dict[key]\n\n    head = config.n_head\n    head_dim = config.hidden_size // head\n\n    params = params.unfreeze()\n    load_param(\"params.transformers.embeddings.word_embeddings.embedding\",\n               load_array(\"decoder.embed_tokens.weight\"))\n    load_param(\"params.transformers.embeddings.position_embeddings.embedding\",\n               load_array(\"decoder.embed_positions.weight\"))\n    if config.version > 2:\n        load_param(\"params.transformers.layer_norm.scale\",\n                   load_array(\"decoder.layer_norm.weight\"))\n        load_param(\"params.transformers.layer_norm.bias\",\n                   load_array(\"decoder.layer_norm.bias\"))\n    for i in range(config.num_hidden_layers):\n        param_prefix = f\"params.transformers.encoder.{i}.\"\n        load_prefix = f\"decoder.layers.{i}.\"\n        # Attention weights\n        wq = load_array(load_prefix + \"self_attn.q_proj.weight\")\n        wk = load_array(load_prefix + \"self_attn.k_proj.weight\")\n        wv = load_array(load_prefix + \"self_attn.v_proj.weight\")\n        dim = wq.shape[-1]\n        w_qkv = np.concatenate([wq, wk, wv], axis=0).reshape(\n            (3, -1, dim)).transpose([2, 1, 0]).reshape((dim, -1))\n        load_param(param_prefix + \"attention.self.qkv_combined.kernel\", w_qkv)\n        bq = load_array(load_prefix + \"self_attn.q_proj.bias\")\n        bk = load_array(load_prefix + \"self_attn.k_proj.bias\")\n        bv = load_array(load_prefix + \"self_attn.v_proj.bias\")\n        # b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape(\n        #     (3, dim)).transpose([1, 0]).reshape((-1,))\n        # load_param(param_prefix + 
\"attention.self.qkv_combined.bias\", b_qkv)\n        b_qkv = np.concatenate([bq, bk, bv], axis=0).reshape(\n            (3, head, head_dim)).astype(np_dtype)\n        load_param(param_prefix + \"attention.self.qkv_combined_bias\", b_qkv)\n        load_param(\n            param_prefix + \"attention.dense.kernel\",\n            np.transpose(load_array(load_prefix + \"self_attn.out_proj.weight\")))\n        load_param(param_prefix + \"attention.dense.bias\",\n                   load_array(load_prefix + \"self_attn.out_proj.bias\"))\n        load_param(param_prefix + \"attention.layer_norm.scale\",\n                   load_array(load_prefix + \"self_attn_layer_norm.weight\"))\n        load_param(param_prefix + \"attention.layer_norm.bias\",\n                   load_array(load_prefix + \"self_attn_layer_norm.bias\"))\n        # FFN weights\n        load_param(param_prefix + \"ffn.fc1.bias\",\n                   load_array(load_prefix + \"fc1.bias\"))\n        load_param(param_prefix + \"ffn.fc1.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc1.weight\")))\n        load_param(param_prefix + \"ffn.fc2.bias\",\n                   load_array(load_prefix + \"fc2.bias\"))\n        load_param(param_prefix + \"ffn.fc2.kernel\",\n                   np.transpose(load_array(load_prefix + \"fc2.weight\")))\n        load_param(param_prefix + \"ffn.layer_norm.scale\",\n                   load_array(load_prefix + \"final_layer_norm.weight\"))\n        load_param(param_prefix + \"ffn.layer_norm.bias\",\n                   load_array(load_prefix + \"final_layer_norm.bias\"))\n\n    return flax.core.freeze(params)\n\n\ndef get_jax_executable(config: OPTConfig,\n                       output_attentions: bool = False,\n                       output_hidden_states: bool = False):\n    \"\"\"Get a single-gpu executable.\"\"\"\n    # Note(Hao):\n    model, params = init_model_aval(config, total_input_len=256, total_cache_len=512)\n\n    @jax.jit\n    def 
inference_step(params, batch):\n        output = model.apply(params,\n                             batch[\"input_ids\"],\n                             batch[\"position_ids\"],\n                             attention_cache=batch[\"cache\"],\n                             )\n        return output.logits\n\n    # executables = {}\n    # for length in encoder_chunk_sizes:\n    #     executables[length] = inference_step\n    return inference_step, params\n"
  },
  {
    "path": "examples/llm_serving/model/opt_utils.py",
    "content": "from functools import partial\n\nimport jax\nfrom jax import xla, jit\nfrom jax.core import Primitive\nfrom jax._src.lib import xla_client as xc\nfrom transformers.generation_utils import dataclass\n\n\ndef sync(device_id=0):\n    jax.devices()[device_id].synchronize_all_activity()\n    return\n\n\n@dataclass\nclass TransformerModelConfig:\n    # hidden size\n    H: int = 768\n    # number of layers\n    L: int = 12\n    # number of attention heads\n    n_head: int = 12\n    seq_len: int = 2048\n    vocab_size: int = 50272\n\n\ndef compute_gpt_tflops_inference_with_padding(batch_size, gen_len, seq_len,\n                                              num_layers, hidden_size,\n                                              vocab_size, num_gpus, latency):\n    \"\"\"This calculation assumes that each code decoded attend to seq_len number tokens.\"\"\"\n    factor = 24\n    total_flop = factor * batch_size * gen_len * (hidden_size ** 2) * num_layers * \\\n          (1 + seq_len / (6 * hidden_size)) \\\n          + 2 * batch_size * gen_len * hidden_size * vocab_size\n    # Note (Hao): it should be 4 here because of input embedding, but we will\n    # respect Deepak's eq. instead.\n    tflops = total_flop / latency / num_gpus / 1e12\n    return tflops\n\n\ndef is_power_of_two(n):\n    return (n != 0) and (n & (n-1) == 0)\n\n\nindex_select_p = Primitive(\"index-select\")\n\n\n@partial(jit, static_argnums=(2,))\ndef jax_index_select(input, index, dim=0):\n    return index_select_p.bind(input, index, dim=dim)\n\n\ndef _index_select_eval(input, index, dim):\n    return input\n\n\ndef _index_select_translation(c, input, index, dim):\n    return xc.ops.IndexSelect(input, index, dim)\n\n\nindex_select_p.def_abstract_eval(_index_select_eval)\nindex_select_p.def_impl(partial(xla.apply_primitive, index_select_p))\nxla.translations[index_select_p] = _index_select_translation\n"
  },
  {
    "path": "examples/llm_serving/model/test_cache.py",
    "content": "\"\"\"Test the correctness of cache implementation.\"\"\"\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa.testing import assert_allclose\nfrom llm_serving.model.opt_model import (get_opt_config, init_model_aval,\n                                         inference_step_no_cache,\n                                         init_cache_np,\n                                         build_position_ids,\n                                         load_params_np)\n\n\ndef print_params(params, prefix=\"\"):\n    for key, value in params.items():\n        if isinstance(value, dict):\n            print_params(value, prefix=prefix + key + \".\")\n        else:\n            print(prefix + key, value.shape)\n\n\ndef test_opt_125M(decompose_input):\n    print(\"Testing cache with decompose_input=%s\" % decompose_input)\n    name = \"125M\"\n    config = get_opt_config(name, dtype=jnp.float32)\n    np_weights_folder = f\"/home/ubuntu/opt_weights/{name}_np\"\n    batch_size = 1\n\n    # Init model\n    input_ids = np.array([[5625, 16, 10, 2721, 183, 8, 38, 236, 7]],\n                         dtype=np.int32)\n    input_ids = np.tile(input_ids, [batch_size, 1])\n    position_ids = build_position_ids(input_ids, config.pad)\n    print(\"input_ids\", input_ids)\n\n    model, params = init_model_aval(config)\n    params = load_params_np(params, np_weights_folder, config)\n    params = jax.tree_map(jnp.array, params)\n\n    # Get expected results\n    logits_no_cache = inference_step_no_cache(params, {\n        \"input_ids\": input_ids,\n        \"position_ids\": position_ids,\n    }, model.apply)\n    print(\"logits_no_cache\", logits_no_cache)\n\n    # JIT\n    @jax.jit\n    def inference_step_with_cache(params, batch):\n        print(\"traced\")\n        output = model.apply(params,\n                             batch[\"input_ids\"],\n                             batch[\"position_ids\"],\n                             attention_cache=batch[\"cache\"])\n 
       return output.logits, output.attention_cache\n\n    cache = init_cache_np(config, input_ids.shape[0])\n\n    if decompose_input:\n        # Decompose input so that all input lengths are one.\n        for i in range(input_ids.shape[1]):\n            input_ids_step = input_ids[:, i:i + 1]\n            position_ids_step = np.full_like(input_ids_step, i + config.pad + 1)\n            logits_step, cache = inference_step_with_cache(\n                params, {\n                    \"input_ids\": input_ids_step,\n                    \"position_ids\": position_ids_step,\n                    \"cache\": cache\n                })\n            assert_allclose(logits_step, logits_no_cache[:, i:i + 1])\n    else:\n        # Same as inference_step_no_cache that has input length > 1.\n        logits_step, cache = inference_step_with_cache(\n            params, {\n                \"input_ids\": input_ids,\n                \"position_ids\": position_ids,\n                \"cache\": cache\n            })\n        assert_allclose(logits_step, logits_no_cache)\n\n\nif __name__ == \"__main__\":\n    test_opt_125M(False)\n    test_opt_125M(True)\n"
  },
  {
    "path": "examples/llm_serving/model/wrapper.py",
    "content": "\"\"\"Wrap models to make them compatible with huggingface's generator API.\"\"\"\nimport time\nfrom collections import defaultdict\nfrom typing import Sequence, Any, Optional, List\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport os\nimport torch\nfrom llm_serving.model import opt_model, bloom_model, codegen_model\nfrom llm_serving.model.opt_utils import (TransformerModelConfig,\n                                         jax_index_select)\nfrom tqdm import tqdm\nfrom transformers import OPTForCausalLM, BloomForCausalLM, CodeGenForCausalLM\nfrom transformers.generation_utils import GenerationMixin, ModelOutput, dataclass\n\nimport alpa\nfrom alpa.device_mesh import DistributedArray\nfrom alpa.mesh_executable import get_index_select_mesh_executable\n\n\n@dataclass\nclass InferenceFuncOutput(ModelOutput):\n    logits: Any = None\n    past_key_values: Any = None\n    hidden_states: Any = None\n    attentions: Any = None\n\n\n@dataclass\nclass InferenceFuncConfig:\n    \"\"\"Implements a minimal config class for using huggingface's generator.\n    Note: these parameters might be overwritten by model.generate(**kwargs).\n    \"\"\"\n    bos_token_id: int = 0\n    num_beams: int = 1\n    num_beam_groups: int = 1\n    length_penalty: float = 1.0\n    repetition_penalty: float = 1.0\n    early_stopping: bool = False\n    num_return_sequences: int = 1\n    pad_token_id: int = 1\n    eos_token_id: int = 2\n    unk_token_id: int = 0\n    output_scores: bool = False\n    output_attentions: bool = False\n    output_hidden_states: bool = False\n    return_dict_in_generate: bool = False\n    is_encoder_decoder: bool = False\n    min_length: bool = 0\n    no_repeat_ngram_size: int = 0\n    encoder_no_repeat_ngram_size: int = 0\n    bad_words_ids: Sequence = None\n    diversity_penalty: float = 0.0\n    forced_bos_token_id: int = None\n    forced_eos_token_id: int = None\n    remove_invalid_values: bool = False\n    exponential_decay_length_penalty: 
float = None\n    do_sample: bool = False\n    top_k: int = 50\n    top_p: int = 1.0\n    typical_p: int = 1.0\n    temperature: float = 1.0\n    suppress_tokens: Optional[List[int]] = None\n    begin_suppress_tokens: Optional[List[int]] = None\n    forced_decoder_ids: Optional[List[int]] = None\n\n\nclass WrappedInferenceFunc(GenerationMixin):\n    \"\"\"\n    Wrap an inference func as a GenerationMixin.\n    This class implements the minimal interface for using huggingface's generator.\n    \"\"\"\n\n    def __init__(self, inference_func, config, executable, transformer_config, device):\n        self.inference_func = inference_func\n        self.config = config\n        self.main_input_name = \"input_ids\"\n        self.executable = executable  # An alpa executable\n        self.transformer_config = transformer_config\n        self.index_select_executables = {}\n        self.cache_location = None\n        self.device = device\n\n    def forward(self, attention_mask):\n        # This function is never used\n        raise NotImplementedError()\n\n    def prepare_inputs_for_generation(self, input_ids, attention_mask,\n                                      past=None, **kwargs):\n        # If past is defined, it means we are in the decoding stage,\n        # so we only process the last token\n        if past:\n            input_ids = input_ids[:, -1].unsqueeze(-1)\n\n        ret = {\"input_ids\": input_ids, \"past_key_values\": past,\n               \"attention_mask\": attention_mask}\n        return ret\n\n    def __call__(self,\n                 input_ids,\n                 past_key_values=None,\n                 output_attentions=None,\n                 output_hidden_states=None,\n                 attention_mask=None,\n                 return_dict=None):\n        ret = self.inference_func(input_ids,\n                                  past_key_values,\n                                  attention_mask=attention_mask,\n                                  
output_hidden_states=output_hidden_states,\n                                  output_attentions=output_attentions)\n        return ret\n\n    def _reorder_cache(self, past, beam_idx):\n        # Reorder cache for beam search\n\n        # PyTorch\n        if hasattr(past[0][0], \"index_select\"):\n            return tuple(\n                tuple(\n                    past_state.index_select(0, beam_idx)\n                    for past_state in layer_past)\n                for layer_past in past)\n\n        # Jax (single-device)\n        if not isinstance(past[0][0], DistributedArray):\n            beam_idx = jnp.array(beam_idx.to(\"cpu\").numpy())\n            return tuple(\n                tuple(\n                    jax_index_select(past_state, beam_idx, 0)\n                    for past_state in layer_past)\n                for layer_past in past)\n\n        # Alpa\n        mesh_groups = defaultdict(list)\n        if self.cache_location is None:\n            self.cache_location = []\n            for layer_past in past:\n                tmp_loc = []\n                for past_state in layer_past:\n                    assert isinstance(past_state, DistributedArray)\n                    mesh = past_state.device_mesh\n                    mesh_groups[mesh].append(past_state)\n                    tmp_loc.append((mesh, len(mesh_groups[mesh]) - 1))\n                self.cache_location.append(tmp_loc)\n        else:\n            for layer_past in past:\n                for past_state in layer_past:\n                    assert isinstance(past_state, DistributedArray)\n                    mesh = past_state.device_mesh\n                    mesh_groups[mesh].append(past_state)\n\n        beam_idx = beam_idx.to(\"cpu\").numpy()\n\n        def grouped_reorder_cache(arys, device_mesh):\n            if len(arys) == 0:\n                return []\n            if device_mesh in self.index_select_executables:\n                executable = self.index_select_executables[device_mesh]\n      
      else:\n                dim = 0\n                avals = [ary.aval for ary in arys]\n                specs = [ary.sharding_spec for ary in arys]\n                executable = get_index_select_mesh_executable(\n                    avals, specs, beam_idx, dim, device_mesh,\n                    [False] * len(avals))\n                self.index_select_executables[device_mesh] = executable\n            ret = executable(*arys, beam_idx)\n            for v in ret:\n                v.skip_shard_args_check = True\n            return ret\n\n        results = {\n            mesh: grouped_reorder_cache(mesh_groups[mesh], mesh)\n            for mesh in mesh_groups\n        }\n\n        return tuple(\n            tuple(results[mesh][loc]\n                  for mesh, loc in layer_loc)\n            for layer_loc in self.cache_location)\n\n\ndef get_hf_model(model_name, device):\n    \"\"\"Get a huggingface model.\"\"\"\n    disable_torch_init()\n    if \"opt\" in model_name:\n        model_class = OPTForCausalLM\n    elif \"bloom\" in model_name:\n        model_class = BloomForCausalLM\n    elif \"codegen\" in model_name:\n        model_class = CodeGenForCausalLM\n    else:\n        raise ValueError(f\"Invalid model name: {model_name}\")\n\n    model = model_class.from_pretrained(\n        model_name,\n        torch_dtype=torch.float16 if \"cuda\" in device else torch.float32)\n    model = model.to(device)\n    restore_torch_init()\n\n    def inference_func(input_ids,\n                       past_key_values,\n                       attention_mask,\n                       output_attentions,\n                       output_hidden_states):\n        out = model(input_ids=input_ids,\n                        past_key_values=past_key_values,\n                        attention_mask=attention_mask,\n                        output_attentions=output_attentions,\n                        output_hidden_states=output_hidden_states)\n        return InferenceFuncOutput(out.logits, 
out.past_key_values)\n\n    inference_func_config = InferenceFuncConfig()\n    for key in inference_func_config.__dataclass_fields__.keys():\n        if hasattr(model.config, key):\n            setattr(inference_func_config, key, getattr(model.config, key))\n    if hasattr(model.config, \"max_position_embeddings\"):\n        seq_len = model.config.max_position_embeddings\n    else:\n        seq_len = 2048\n\n    transformer_config = TransformerModelConfig(\n        H=model.config.hidden_size,\n        L=model.config.num_hidden_layers,\n        n_head=model.config.num_attention_heads,\n        seq_len=seq_len,\n        vocab_size=model.config.vocab_size)\n    executable = None\n    return WrappedInferenceFunc(inference_func, inference_func_config,\n                                executable, transformer_config, torch.device(device))\n\n\ndef get_alpa_model(model_name: str,\n                   # Weights\n                   path: str,\n                   dummy: bool = False,\n                   # Batch size and seq length\n                   batch_size: int = 1,\n                   num_micro_batches: int = 1,\n                   max_seq_len: int = 2048,\n                   encoder_chunk_sizes: Sequence[int] = (1, 64),\n                   num_pp_stages: Optional[int] = None,\n                   # Model parameters\n                   dtype=jnp.float16,\n                   torch_device: str = \"cpu\",\n                   # Shared arguments with model.generate\n                   do_sample: bool = False,\n                   num_beams: int = 1,\n                   num_return_sequences: int = 1,\n                   return_dict_in_generate: bool = True,\n                   output_attentions: bool = False,\n                   output_hidden_states: bool = False):\n    \"\"\"Get a alpa-based model that is compatible with HuggingFace's generation API.\"\"\"\n    if num_micro_batches > 1:\n        raise NotImplementedError()\n    assert return_dict_in_generate\n\n    if 1 not in 
encoder_chunk_sizes:\n        encoder_chunk_sizes += [1]\n    encoder_chunk_sizes = list(set(encoder_chunk_sizes))\n    encoder_chunk_sizes.sort()\n\n    # weight path\n    name = model_name.split(\"/\")[1].lower()\n    path = os.path.abspath(os.path.expanduser(os.path.join(path, f\"{name}-np\")))\n    if not dummy:\n        # Download weights if there is no cached weights.\n        if not os.path.exists(path):\n            if name in [\"opt-175b\"]:\n                raise ValueError(f\"Cannot find cached weights under '{path}'. \"\n                                  \"Please follow the instructions to download \"\n                                  \"and convert weights manually. \")\n            print(f\"Cannot find cached weights under '{path}'.\")\n            download_weights(model_name.split(\"/\")[1], path)\n\n        # Do some sanity check\n        assert os.path.exists(path), f\"No such file or directory: '{path}'\"\n        if \"opt\" in name:\n            embed_weight = os.path.join(path, \"decoder.embed_tokens.weight\")\n        elif \"bloom\" in name:\n            embed_weight = os.path.join(path, \"word_embeddings.weight\")\n        elif \"codegen\" in name:\n            embed_weight = os.path.join(path, \"wte.weight\")\n        assert os.path.exists(embed_weight), f\"No such file or directory: '{embed_weight}'\"\n\n    # Figure out the actual input size\n    if do_sample:\n        batch_size = batch_size * num_beams * num_return_sequences\n    else:\n        if num_return_sequences > num_beams:\n            raise ValueError(\n                \"`num_return_sequences` has to be smaller or equal to `num_beams`.\"\n            )\n        batch_size = batch_size * num_beams\n\n    if \"jax\" in model_name:\n        if \"opt\" in model_name:\n            m = opt_model\n        elif \"bloom\" in model_name:\n            m = bloom_model\n        elif \"codegen\" in model_name:\n            m = codegen_model\n        config = m.get_config(name,\n                
              num_pp_stages=None,\n                              mark_boundary=False,\n                              dtype=dtype,\n                              max_seq_len=max_seq_len)\n        transformer_config = TransformerModelConfig(\n            H=config.hidden_size,\n            L=config.num_hidden_layers,\n            n_head=config.n_head,\n            seq_len=config.max_seq_len,\n            vocab_size=config.vocab_size)\n\n        executables, params_aval = m.get_jax_executable(\n            config, encoder_chunk_sizes,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states)\n\n        # load params\n        params = m.load_params_np(params_aval, path, config, dummy)\n        init_cache = m.init_cache_np(config, batch_size=batch_size)\n        params, init_cache = jax.tree_map(jnp.array, (params, init_cache))\n    elif \"alpa\" in model_name:\n        if \"opt\" in model_name:\n            m = opt_model\n        elif \"bloom\" in model_name:\n            m = bloom_model\n        elif \"codegen\" in model_name:\n            m = codegen_model\n\n        alpa.init()\n\n        print(\n            f\"Load model {model_name} ... 
\"\n            f\"(This can take several minutes for very large models)\"\n        )\n\n        if num_pp_stages is None:\n            num_pp_stages = max(2, alpa.get_global_cluster().num_hosts)\n            num_pp_stages = min(num_pp_stages,\n                                alpa.get_global_cluster().num_devices)\n        config = m.get_config(name,\n                              num_pp_stages=num_pp_stages,\n                              dtype=dtype,\n                              max_seq_len=max_seq_len)\n        transformer_config = TransformerModelConfig(\n            H=config.hidden_size,\n            L=config.num_hidden_layers,\n            n_head=config.n_head,\n            seq_len=config.max_seq_len,\n            vocab_size=config.vocab_size)\n\n        print(f\" - Compile executables for encoder_chunk_sizes={encoder_chunk_sizes}. \",\n              end=\"\", flush=True)\n        tic = time.time()\n        executables, params_aval = m.get_pipeshard_executable(\n            config,\n            batch_size=batch_size,\n            num_micro_batches=num_micro_batches,\n            encoder_chunk_sizes=encoder_chunk_sizes,\n            output_attentions=output_attentions,\n            output_hidden_states=output_hidden_states)\n        print(f\"elapsed: {time.time() - tic:.2f} second.\")\n\n        # Load params\n        print(\" - Load parameters. 
\", end=\"\", flush=True)\n        tic = time.time()\n        params = m.load_multi_executable_params_dis_array(\n            path, executables, params_aval, config, dummy)\n\n        init_cache = m.init_multi_executable_cache_dis_array(\n            executables, config, batch_size, dummy=dummy)\n        set_skip_shard_args_check(init_cache)\n\n        for executable in executables.values():\n            executable.sync()\n        print(f\"elapsed: {time.time() - tic:.2f} second.\")\n    else:\n        raise ValueError(f\"Invalid model name: {model_name}\")\n\n    num_valid_tokens = None\n    last_token = None\n    step_ct = 0\n\n    def inference_func(input_ids,\n                       past_key_values,\n                       attention_mask,\n                       output_attentions,\n                       output_hidden_states):\n        assert input_ids.shape[0] == batch_size, (\n            f\"Expect batch size = {batch_size}, but got {input_ids.shape[0]}\")\n        input_ids = input_ids.cpu().numpy()\n        attention_mask = attention_mask.cpu().numpy()\n\n        def run_one(_executable, _input_ids, _past_key_values, _attention_mask, num_internal_pad):\n            nonlocal num_valid_tokens\n            nonlocal last_token\n            nonlocal step_ct\n\n            if _past_key_values is None:\n                # Init all states\n                _past_key_values = init_cache\n                num_valid_tokens = np.zeros((batch_size, 1), dtype=np.int32)\n                last_token = np.zeros((batch_size, 1), dtype=np.int32)\n                step_ct = 0\n\n            if _input_ids.shape[1] == 1:\n                # A fast path for step_len = 1\n                cum_sum = _attention_mask[:, -1:]\n                num_valid_tokens = num_valid_tokens + cum_sum\n                position_ids_step = num_valid_tokens + config.pad\n                last_token = np.where(cum_sum, _input_ids, last_token)\n                _input_ids = last_token\n            else:\n        
        # A general path that works for any step_len\n                cumsum = np.cumsum(_attention_mask[:,step_ct:], axis=1, dtype=np.int32)\n                position_ids_step = num_valid_tokens + cumsum + config.pad\n                num_valid_tokens_step = cumsum[:,-1:]\n                num_valid_tokens = num_valid_tokens + num_valid_tokens_step\n\n                last_token = np.where(num_valid_tokens_step > 0,\n                     np.take_along_axis(_input_ids, num_valid_tokens_step - 1, axis=1),\n                     last_token)\n                _input_ids = np.where(_attention_mask[:, step_ct:], _input_ids, last_token)\n\n            if num_internal_pad:\n                # Use value \"2\" as a special mask to represent internal padding\n                _attention_mask[:,-num_internal_pad:] = 2\n            _attention_mask = pad_attention_mask(_attention_mask, max_seq_len)\n\n            output = _executable(\n                params, {\n                    \"input_ids\": _input_ids,\n                    \"position_ids\": position_ids_step,\n                    \"cache\": _past_key_values,\n                    \"mask\": _attention_mask,\n                })\n\n            step_ct += _input_ids.shape[1] - num_internal_pad\n            set_skip_shard_args_check(output.attention_cache)\n\n            return output\n\n        seq_len = input_ids.shape[1]\n        if seq_len == 1:  # A fast path for seq_len = 1\n            output = run_one(executables[1], input_ids, past_key_values, attention_mask, 0)\n        else:  # A general path that works for all seq_len\n            i = 0\n            while i < seq_len:\n                remaining = seq_len - i\n                step_len = get_padded_step_len(remaining, encoder_chunk_sizes)\n\n                step_input_ids = input_ids[:, i:i + step_len]\n                step_attention_mask = (\n                    attention_mask[:, :attention_mask.shape[1] - remaining + step_len])\n\n                if step_input_ids.shape[1] 
!= step_len:\n                    # Pad the inputs and masks to step_len\n                    # Note that this kind of internal padding is different from\n                    # the padding added by the tokenizer. This internal padding\n                    # should not update cache and step_ct\n                    num_internal_pad = step_len - step_input_ids.shape[1]\n                    pad_shape = (batch_size, num_internal_pad)\n                    step_input_ids = np.concatenate(\n                        (step_input_ids, np.zeros(pad_shape, dtype=np.int32)), axis=1)\n                    step_attention_mask = np.concatenate(\n                        (step_attention_mask, np.zeros(pad_shape, dtype=np.int8)), axis=1)\n                else:\n                    num_internal_pad = 0\n\n                output = run_one(executables[step_len], step_input_ids,\n                                 past_key_values, step_attention_mask,\n                                 num_internal_pad)\n                past_key_values = output.attention_cache\n                i += step_input_ids.shape[1]\n\n        logits_step = torch.from_numpy(np.array(output.logits)).to(torch_device).float()\n        return InferenceFuncOutput(logits_step, output.attention_cache,\n                                   output.hidden_states, output.attentions)\n\n    inference_func_config = InferenceFuncConfig()\n    if \"bloom\" in model_name:\n        inference_func_config.bos_token_id = 1\n        inference_func_config.eos_token_id = 2\n        inference_func_config.pad_token_id = 3\n        inference_func_config.unk_token_id = 0\n    elif \"codegen\" in model_name:\n        inference_func_config.bos_token_id = 1\n        inference_func_config.eos_token_id = 50256\n        inference_func_config.pad_token_id = 50256\n    return WrappedInferenceFunc(inference_func,\n                                inference_func_config,\n                                executables[1],\n                                
transformer_config,\n                                torch.device(torch_device))\n\n\ndef get_model(model_name: str,\n              # Weights\n              path: str,\n              dummy: bool = False,\n              # Batch size and seq length\n              batch_size: int = 1,\n              num_micro_batches: int = 1,\n              max_seq_len: int = 2048,\n              encoder_chunk_sizes: Sequence[int] = (1, 64),\n              num_pp_stages: Optional[int] = None,\n              # Model parameters\n              dtype=jnp.float16,\n              torch_device: str = \"cpu\",\n              # Shared arguments with model.generate\n              do_sample: bool = False,\n              num_beams: int = 1,\n              num_return_sequences: int = 1,\n              return_dict_in_generate: bool = True,\n              output_attentions: bool = False,\n              output_hidden_states: bool = False):\n    \"\"\"Get a model that is compatible with HuggingFace's generation API.\n    Args:\n        model_name: \"facebook/opt-\", or \"alpa/opt-\".\n        path: The path to opt weights.\n        dummy: Use dummy weights for faster debugging.\n        batch_size: The batch size.\n        num_micro_batches: The number of micro batches in pipeline\n          parallelism.\n        max_seq_len: The max sequence length.\n        encoder_chunk_sizes: Compile multiple executables with different\n          chunk sizes. These executables are used to encode prompts\n          chunk by chunk.\n        num_pp_stages: The number of pipeline parallelism stages.\n        dtype: The type of parameters.\n        torch_device: \"cpu\" or \"gpu\". This only controls the device used\n          by pytorch. 
Alpa always runs on GPU.\n        other parameters: shared with huggingface's model.generate API.\n    \"\"\"\n    if \"facebook/opt\" in model_name or \"bigscience/bloom\" in model_name or \"Salesforce/codegen\" in model_name:\n        return get_hf_model(model_name, torch_device)\n    elif (\"jax/opt\" in model_name or \"alpa/opt\" in model_name or\n          \"jax/bloom\" in model_name or \"alpa/bloom\" in model_name or\n          \"jax/codegen\" in model_name or \"alpa/codegen\" in model_name):\n        return get_alpa_model(\n              model_name,\n              path,\n              dummy,\n              batch_size,\n              num_micro_batches,\n              max_seq_len,\n              encoder_chunk_sizes,\n              num_pp_stages,\n              dtype,\n              torch_device,\n              do_sample,\n              num_beams,\n              num_return_sequences,\n              return_dict_in_generate,\n              output_attentions,\n              output_hidden_states)\n    else:\n        raise ValueError(f\"Invalid model name: {model_name}\")\n\n\ndef get_padded_step_len(length, encoder_chunk_sizes):\n    \"\"\"For a given length, find the smallest value in encoder_chunk_sizes that\n    is greater than or equal to the given length.\"\"\"\n    for i in range(len(encoder_chunk_sizes)):\n        if encoder_chunk_sizes[i] >= length:\n            break\n    return encoder_chunk_sizes[i]\n\n\ndef set_skip_shard_args_check(attention_cache):\n    \"\"\"\n    Skip the check in DistributedPhysicalDeviceMesh::shard_args for\n    attention cache. 
We need this hack because attention_cache is\n    a batch var but alpa doesn't implement a fast path for batch vars.\n    \"\"\"\n    if isinstance(attention_cache[0], alpa.device_mesh.DistributedArray):\n        for x in attention_cache:\n            x.skip_shard_args_check = True\n    else:\n        for y in attention_cache:\n            for x in y:\n                if isinstance(x, alpa.device_mesh.DistributedArray):\n                    x.skip_shard_args_check = True\n\n\ndef pad_attention_mask(mask, max_seq_len):\n    \"\"\"Pad attention mask to the shape [B, 1, 1, max_seq_len]. \"\"\"\n    batch_size = mask.shape[0]\n    ret_mask = np.zeros((batch_size, max_seq_len), dtype=np.int8)\n    ret_mask[:, :mask.shape[-1]] = mask\n    ret_mask = ret_mask[:, np.newaxis, np.newaxis, :]\n    return ret_mask\n\n\ndef download_weights(model_name, path):\n    \"\"\"Download weights from huggingface.\"\"\"\n    if \"opt\" in model_name:\n        hf_model_name = \"facebook/\" + model_name\n        model_class = OPTForCausalLM\n    elif \"bloom\" in model_name:\n        hf_model_name = \"bigscience/\" + model_name\n        model_class = BloomForCausalLM\n    elif \"codegen\" in model_name:\n        hf_model_name = \"Salesforce/\" + model_name\n        model_class = CodeGenForCausalLM\n\n    print(f\"Load the pre-trained pytorch weights of {model_name} from huggingface. \"\n          f\"The downloading and cpu loading can take dozens of minutes. 
\"\n          f\"If it seems to get stuck, you can monitor the progress by \"\n          f\"checking the memory usage of this process.\")\n\n    disable_torch_init()\n    model = model_class.from_pretrained(hf_model_name, torch_dtype=torch.float16,\n                                        _fast_init=True)\n    restore_torch_init()\n\n    os.makedirs(path, exist_ok=True)\n\n    print(f\"Convert the weights to alpa format under {path} ...\")\n    if \"opt\" in model_name:\n        for name, param in tqdm(list(model.model.named_parameters())):\n            name = name.replace(\"decoder.final_layer_norm\", \"decoder.layer_norm\")\n            param_path = os.path.join(path, name)\n            with open(param_path, \"wb\") as f:\n                np.save(f, param.cpu().detach().numpy())\n    elif \"bloom\" in model_name:\n        for name, param in tqdm(list(model.transformer.named_parameters())):\n            param_path = os.path.join(path, name)\n            with open(param_path, \"wb\") as f:\n                np.save(f, param.cpu().detach().numpy())\n    elif \"codegen\" in model_name:\n        for name, param in tqdm(list(model.named_parameters())):\n            name = name.replace(\"transformer.\", \"\")\n            param_path = os.path.join(path, name)\n            with open(param_path, \"wb\") as f:\n                np.save(f, param.cpu().detach().numpy())\n\n\n\nglobal torch_linear_init_backup\nglobal torch_layer_norm_init_backup\n\n\ndef disable_torch_init():\n    \"\"\"\n    Disable the redundant torch default initialization to accelerate model creation.\n    \"\"\"\n    global torch_linear_init_backup\n    global torch_layer_norm_init_backup\n\n    torch_linear_init_backup = torch.nn.Linear.reset_parameters\n    setattr(torch.nn.Linear, \"reset_parameters\", lambda self: None)\n\n    torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters\n    setattr(torch.nn.LayerNorm, \"reset_parameters\", lambda self: None)\n\n\ndef restore_torch_init():\n    
\"\"\"Rollback the change made by disable_torch_init.\"\"\"\n    setattr(torch.nn.Linear, \"reset_parameters\", torch_linear_init_backup)\n    setattr(torch.nn.LayerNorm, \"reset_parameters\", torch_layer_norm_init_backup)\n"
  },
  {
    "path": "examples/llm_serving/model/wrapper_1d.py",
    "content": "import logging\nimport time\nfrom typing import Union, List\n\nimport cupy\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport os\nimport torch\nimport tqdm\nfrom llm_serving.model import opt_model_1d\nfrom transformers import OPTForCausalLM, BloomForCausalLM\nfrom transformers.generation_utils import dataclass\n\nfrom alpa.timer import timers\nfrom examples.llm_serving.model import opt_model\nfrom examples.llm_serving.model.opt_model_1d import IterationLevelInputPool, unpad, \\\n    pad\nfrom examples.llm_serving.model.opt_utils import sync\nfrom examples.llm_serving.model.wrapper import disable_torch_init, restore_torch_init\n\nlogger = logging.getLogger(__name__)\nlogger.setLevel(logging.INFO)\n\n\n@dataclass\nclass InputPoolConfig:\n    \"\"\"The config for iterative-level input pool.\"\"\"\n    batch_size: int = 512\n    cache_size: int = 4096\n\n\nclass SequenceGenerator:\n    def __init__(self, executable, params, input_pool_config, model_config):\n        self.executable = executable\n        self.params = params\n        self.input_pool_config = input_pool_config\n        self.model_config = model_config\n        # some other attributes\n        self.pad = self.model_config.pad\n\n    def generate(self,\n                 input: Union[IterationLevelInputPool, List[List[int]], np.ndarray],\n                 max_length=None,\n                 max_new_tokens=None,\n                 do_sample=False,\n                 **kwargs):\n        if max_length == None and max_new_tokens == None:\n            raise RuntimeError(\"Please provide at least one of max_length and max_new_tokens.\")\n\n        if isinstance(input, IterationLevelInputPool):\n            raise NotImplementedError()\n        elif isinstance(input, (List, np.ndarray, torch.Tensor)):\n            unpadded_input = unpad(input)\n            return self.generate_by_batch(unpadded_input,\n                                          max_length=max_length,\n                      
                    max_new_tokens=max_new_tokens,\n                                          do_sample=do_sample)\n        else:\n            raise RuntimeError()\n\n    def generate_by_batch(self,\n                          input_ids: List[List[int]],\n                          max_length=None,\n                          max_new_tokens=None,\n                          do_sample=False):\n        input_pool = IterationLevelInputPool(self.input_pool_config,\n                                             self.model_config,\n                                             max_length=max_length,\n                                             max_new_tokens=max_new_tokens)\n        iter = 0\n        input_pool.enter_prompts(input_ids)\n        while not input_pool.is_finished():\n            # tic = time.time()\n            input, input_index, position_ids, logit_positions = input_pool.next()\n            # timers(\"enter\").suspend(sync)\n            batch = {\n                \"input_ids\": input,\n                \"position_ids\": position_ids,\n                \"cache\": input_pool.cache\n            }\n            # compute\n            # timers(\"compute\").start(sync)\n            logits = self.executable(self.params, batch)\n            # timers(\"compute\").suspend(sync)\n\n            # timers(\"generate\").start(sync)\n            if not do_sample:\n                generated_ids = self._generate_greedy(logits, logit_positions)\n            else:\n                raise NotImplementedError()\n            # timers(\"generate\").suspend(sync)\n\n            # timers(\"update\").start(sync)\n            input_pool.update(generated_ids)\n            # timers(\"update\").suspend(sync)\n            # elapsed = time.time() - tic\n            iter += 1\n            # print(f\"Iter {iter} takes {elapsed}\")\n\n        ret = input_pool.get_results()\n        padded_input = np.array(pad(ret))\n        latency = input_pool.get_latency()\n        return padded_input, latency\n\n 
   @staticmethod\n    def _generate_greedy(logits, positions):\n        # outputs = []\n        next_token = np.array(jnp.argmax(logits, axis=-1))\n        outputs = next_token[positions].tolist()\n        # for pos in positions:\n        #     outputs.append(int(next_token[pos]))\n        return outputs\n\n\ndef get_model(model_name: str,\n                path: str,\n                dummy: bool = False,\n                # batch size, this batch is #tokens\n                batch_size: int = 256,\n                max_seq_len: int = 2048,\n                cache_size: int = 4096,\n                # model parameters\n                dtype=jnp.float16,\n                # Shared arguments with model.generate\n                do_sample: bool = False):\n    \"\"\"Experimental 1D transformer implementation.\"\"\"\n    assert \"opt-1d\" in model_name, \"are you sure you want to use the experimental 1D version?\"\n    name = model_name.split(\"/\")[1].lower()\n    name = name.replace(\"-1d\", \"\")\n    path = os.path.abspath(os.path.expanduser(os.path.join(path, f\"{name}-np\")))\n    if not dummy:\n        # Download weights if there is no cached weights.\n        if not os.path.exists(path):\n            if name in [\"opt-175b\"]:\n                raise ValueError(f\"Cannot find cached weights under '{path}'. \"\n                                  \"Please follow the instructions to download \"\n                                  \"and convert weights manually. 
\")\n            print(f\"Cannot find cached weights under '{path}'.\")\n            download_weights(model_name.split(\"/\")[1], path)\n\n        # Do some sanity check\n        assert os.path.exists(path), f\"No such file or directory: '{path}'\"\n        if \"opt\" in name:\n            embed_weight = os.path.join(path, \"decoder.embed_tokens.weight\")\n        elif \"bloom\" in name:\n            embed_weight = os.path.join(path, \"word_embeddings.weight\")\n        assert os.path.exists(embed_weight), f\"No such file or directory: '{embed_weight}'\"\n    # TODO(Hao): figure out the actual input size\n    model_config = opt_model.get_config(name, dtype=dtype, max_seq_len=max_seq_len)\n    executable, params_aval = opt_model_1d.get_jax_executable(model_config)\n\n    # load params\n    # TODO(Hao): use the same func with 2D\n    params = opt_model_1d.load_params_np(params_aval, path, model_config, dummy)\n    params = jax.tree_map(jnp.array, params)\n\n    input_pool_config = InputPoolConfig(batch_size=batch_size,\n                                        cache_size=cache_size)\n\n    return SequenceGenerator(executable, params, input_pool_config, model_config)\n\n\ndef download_weights(model_name, path):\n    \"\"\"Download weights from huggingface.\"\"\"\n    if \"opt\" in model_name:\n        hf_model_name = \"facebook/\" + model_name\n        model_class = OPTForCausalLM\n    elif \"bloom\" in model_name:\n        hf_model_name = \"bigscience/\" + model_name\n        model_class = BloomForCausalLM\n\n    print(f\"Load the pre-trained pytorch weights of {model_name} from huggingface. \"\n          f\"The downloading and cpu loading can take dozens of minutes. 
\"\n          f\"If it seems to get stuck, you can monitor the progress by \"\n          f\"checking the memory usage of this process.\")\n\n    disable_torch_init()\n    model = model_class.from_pretrained(hf_model_name, torch_dtype=torch.float16,\n                                        _fast_init=True)\n    restore_torch_init()\n\n    os.makedirs(path, exist_ok=True)\n\n    print(f\"Convert the weights to alpa format under {path} ...\")\n    if \"opt\" in model_name:\n        for name, param in tqdm(list(model.model.named_parameters())):\n            name = name.replace(\"decoder.final_layer_norm\", \"decoder.layer_norm\")\n            param_path = os.path.join(path, name)\n            with open(param_path, \"wb\") as f:\n                np.save(f, param.cpu().detach().numpy())\n    elif \"bloom\" in model_name:\n        for name, param in tqdm(list(model.transformer.named_parameters())):\n            param_path = os.path.join(path, name)\n            with open(param_path, \"wb\") as f:\n                np.save(f, param.cpu().detach().numpy())\n"
  },
  {
    "path": "examples/llm_serving/scripts/step_2_consolidate_992_shards_to_singleton.py",
    "content": "\"\"\"Convert the 992 shards into 1 singleton (code adapted from Metaseq and fairscale).\"\"\"\nfrom typing import List, Dict, Any\nimport argparse\nimport gc\nimport logging\nimport os\nimport re\nimport time\nfrom collections import defaultdict, OrderedDict\nfrom glob import glob\nfrom pathlib import Path\nfrom tqdm import tqdm\n\nimport torch\nfrom llm_serving.scripts.utils import load_and_pop_last_optimizer_state\n\nlogger = logging.getLogger(__name__)\n\n\ndef _unpad(shard: torch.Tensor, pad: int) -> torch.Tensor:\n    if pad > 0:\n        shard = shard[:-pad]\n    return shard\n\n\ndef consolidate_shard_weights(\n        shard_weights: List[Dict[str, torch.Tensor]],\n        shard_metadata: List[Dict[str, Any]],\n        with_module_buffers: bool = True,\n        strict: bool = True,\n) -> Dict[str, torch.Tensor]:\n    \"\"\"\n    Given a list of weights and metadata associated to N shards, reconstruct\n    the weights of an equivalent consolidated (non-sharded) state dict.\n    Module parameters are consolidated using the shard metadata.\n    Module buffers are taken from shard 0: this assumes that module buffers\n    are either synchronized or that the shard 0 value is valid for all shards.\n    If this behavior is not correct for your module (for instance if buffers\n    need to be all-reduced instead), you can disable it with `with_module_buffers=False`.\n    This method is used to re-assemble checkpoints of shards without\n    having to instantiate FSDP wrappers with the world size (i.e. 
large\n    number of GPUs) originally used to save the shards.\n    Args:\n        shard_weights (List[Dict[str, torch.Tensor]]):\n            List of dictionaries that contains sharded weights from\n            each rank.\n        shard_metadata (List[Dict[str, Any]]):\n            List of dictionaries that contains metadata from each shard.\n            See `local_metadata_dict` above.\n        with_module_buffers (bool):\n            If shard 0's buffer should be returned in the consolidated\n            weight dict.\n            Default: True.\n        strict (bool):\n            allow incomplete shard weights. if True, every key in the metadata must be present in the weights.\n    \"\"\"\n    if len(shard_weights) != len(shard_metadata) or not len(shard_weights):\n        raise ValueError(\"Require metadata for each shard and non-empty shards\")\n\n    consolidated_weights = {}\n    original_world_size = len(shard_weights)\n\n    # For every FSDP instance.\n    for fsdp_obj_idx, metadata in enumerate(shard_metadata[0][\"param_metadata\"]):\n        fsdp_path = metadata[\"fsdp_path\"]\n        params = metadata[\"params\"]\n        # For every this-FSDP-owned param, flattened or not.\n        for backing_param_name, v in params.items():\n            in_state_dict_key = \".\".join([fsdp_path, backing_param_name]) if fsdp_path else backing_param_name\n            # Get full param back with pad removed.\n            if in_state_dict_key not in shard_weights[0] and (not strict):\n                continue\n            shards = []\n            for rank in range(original_world_size):\n                shard = shard_weights[rank][in_state_dict_key]\n                pad = shard_metadata[rank][\"param_metadata\"][fsdp_obj_idx][\"params\"][backing_param_name][\"padding\"]\n                shards.append(_unpad(shard, pad))\n                if metadata[\"no_broadcast_optim_state\"]:\n                    break\n            full_param = torch.cat(shards, dim=0)\n            # 
(Potentially), split the full param and create original params.\n            names, shapes, numels, _ = v.values()\n            assert sum(numels) == full_param.size(0)\n            for n, t, s in zip(names, full_param.split(numels), shapes):\n                out_state_dict_key = \".\".join([fsdp_path, n]) if fsdp_path else n\n                consolidated_weights[out_state_dict_key] = t.view(s)\n\n    # copy shared parameters\n    for src_path, dest_path in metadata[\"shared_param_info\"]:\n        consolidated_weights[dest_path] = consolidated_weights[src_path]\n\n    # Deal with the buffers, which are not parameters and are not sharded by FSDP\n    # and therefore are replicated among the different shards.\n    # We take the values of the first shard (this assumes that there is some form\n    # of synchronization between shards or that all shards buffers are equivalent).\n    if with_module_buffers:\n        for buffer_name in shard_metadata[0][\"buffer_names\"]:\n            if buffer_name not in shard_weights[0] and (not strict):\n                continue\n            consolidated_weights[buffer_name] = shard_weights[0][buffer_name]\n\n    return consolidated_weights\n\n\ndef _get_shard_number(x) -> int:\n    match = re.search(r\"shard(\\d+).pt\", x)\n    if match is None:\n        raise AssertionError(f\"{x} did not match shard(\\\\d+).pt\")\n    else:\n        return int(match.groups()[0])\n\n\ndef consolidate_fsdp_shards(\n    pth_prefix: str,\n    save_prefix=None,\n    strict=False,\n    new_arch_name=None,\n    no_stitch_megatron=False,\n    megatron_part=None,\n) -> str:\n    if pth_prefix.endswith(\".pt\"):\n        pth_prefix = pth_prefix[:-3]\n    if save_prefix is None:\n        save_prefix = pth_prefix + \"_consolidated\"  # .pt'\n    all_ckpt_files = list(\n        sorted(glob(f\"{pth_prefix}*shard*.pt\"), key=_get_shard_number)\n    )\n    if megatron_part is not None:\n        no_stitch_megatron = True\n        all_ckpt_files = [\n            x 
for x in all_ckpt_files if f\"model_part-{megatron_part}\" in x\n        ]\n    assert all_ckpt_files, f\"no paths matched {pth_prefix}*shard*.pt\"\n    weights = []\n    metadata = []\n    expert_paths = []\n    expert_dest_paths = []\n    expert_ranks = []\n    names = []\n    dense = True\n    t0 = time.time()\n    for p in tqdm(all_ckpt_files):\n        names.append(Path(p).name)\n        if re.search(r\"rank-(\\d+)\", os.path.basename(p)):  # expert checkpoint\n            expert_paths.append(p)\n            r = re.search(r\"rank-(\\d+)\", os.path.basename(p)).groups()[0]\n            assert r not in expert_ranks\n            expert_ranks.append(r)\n            expert_dest_paths.append(f\"{save_prefix}-rank-{r}.pt\")\n        else:\n            ckpt = load_and_pop_last_optimizer_state(p)\n            weights.append(ckpt[\"model\"])\n            metadata.append(ckpt[\"shard_metadata\"])\n    assert weights, f\"all files were considered experts: {all_ckpt_files}\"\n    do_consolidate = True\n    if \"decoder.embed_tokens.weight\" in weights[0].keys():\n        shape = weights[0][\"decoder.embed_tokens.weight\"].shape\n        logger.info(\n            f\"This ckpt does not seem sharded. I see unflat params! like \"\n            f\"decoder.embed_tokens.weight shaped {shape}. 
Will just copy files \"\n            f\"and remove optim_state.\"\n        )\n        do_consolidate = False\n    if do_consolidate:\n        num_parts = find_num_parts(names)\n        if num_parts:\n            #consolidated_weights = consolidate_model_parallel(\n            #    metadata,\n            #    names,\n            #    strict,\n            #    weights,\n            #    parts=num_parts,\n            #    no_stitch_megatron=no_stitch_megatron,\n            #)\n            print(\"- Part 1: consolidate Zero-3 shards.\")\n            consolidated_weights = consolidate_model_parallel_part1(\n                metadata,\n                names,\n                strict,\n                weights,\n                parts=num_parts,\n                no_stitch_megatron=no_stitch_megatron,\n            )\n            del weights, metadata\n            gc.collect()\n            if not no_stitch_megatron:\n                print(\"- Part 2: consolidate model-parallel parts.\")\n                consolidated_weights = consolidate_model_parallel_part2(\n                    consolidated_weights)\n        else:\n            print(\"FSDP.consolidate_shard_weights\")\n            consolidated_weights = consolidate_shard_weights(\n                shard_weights=weights, shard_metadata=metadata, strict=strict\n            )\n        #del weights, metadata\n        #gc.collect()\n        done_consolidate = time.time()\n        print(f\"Done consolidating after {done_consolidate-t0//60} minutes\")\n    else:\n        consolidated_weights = weights[0]\n    if new_arch_name is not None:\n        ckpt[\"cfg\"][\"model\"]._name = new_arch_name\n    if dense:\n\n        def save_checkpoint(weights_to_save, prefix):\n            ckpt_consolidated = dict(\n                model=weights_to_save,\n                cfg=ckpt[\"cfg\"],\n                extra_state=ckpt[\"extra_state\"],\n                optimizer_history=ckpt[\"optimizer_history\"],\n                args=ckpt.get(\"args\"),\n 
           )\n            save_path = f\"{prefix}.pt\"\n            print(f\"- Saving to {save_path} ...\")\n            torch.save(ckpt_consolidated, save_path)\n            print(f\"Done saving after {(time.time() - t0) // 60} minutes\")\n            return save_path\n\n        if no_stitch_megatron:\n            saved_paths = []\n            for part_id, part_consolidated_weights in consolidated_weights.items():\n                saved_paths.append(\n                    save_checkpoint(\n                        part_consolidated_weights, f\"{save_prefix}-model_part-{part_id}\"\n                    )\n                )\n            return saved_paths\n        return save_checkpoint(consolidated_weights, save_prefix)\n\n    ckpt_shared = dict(\n        model=consolidated_weights,\n        cfg=ckpt[\"cfg\"],\n        extra_state=ckpt[\"extra_state\"],\n        optimizer_history=ckpt[\"optimizer_history\"],\n        args=ckpt[\"args\"],\n    )\n    print(\"saving..\")\n    torch.save(ckpt_shared, f\"{save_prefix}-shared.pt\")\n    print(f\"Done saving. 
Total time: {time.time()-t0//60} minutes,  \")\n    # Process experts\n    for src, dst in tqdm(\n        list(zip(expert_paths, expert_dest_paths)), desc=\"expert files\"\n    ):\n        ckpt = load_and_pop_last_optimizer_state(src)\n        if do_consolidate:\n            expert_wt = consolidate_shard_weights(\n                shard_weights=[ckpt[\"model\"]],\n                shard_metadata=[ckpt[\"shard_metadata\"]],\n                strict=False,\n            )\n            ckpt = dict(\n                model=expert_wt,\n                cfg=ckpt[\"cfg\"],\n                extra_state=ckpt[\"extra_state\"],\n                optimizer_history=ckpt[\"optimizer_history\"],\n                args=ckpt[\"args\"],\n            )\n\n        torch.save(ckpt, dst)\n    logger.info(f\"saved consolidated MoE with prefix {save_prefix}.pt\")\n    return f\"{save_prefix}.pt\"\n\n\ndef consolidate_model_parallel(\n    metadata, names, strict, weights, parts=2, no_stitch_megatron=False\n):\n    model_parts = defaultdict(list)\n    metadata_parts = defaultdict(list)\n    for i, n in enumerate(names):\n        for p in range(parts):\n            if f\"part-{p}\" in n:\n                model_parts[p].append(weights[i])\n                metadata_parts[p].append(metadata[i])\n    all_parts_consolidated = defaultdict(list)\n    for k, v in tqdm(model_parts.items()):\n        print(f\"Processing part: {k}, with {len(v)} shards...\")\n        part_weights = consolidate_shard_weights(\n            shard_weights=v, shard_metadata=metadata_parts[k], strict=strict\n        )\n        all_parts_consolidated[k] = part_weights\n    if no_stitch_megatron:\n        return all_parts_consolidated\n    model = glue_megatron_parts(all_parts_consolidated)\n    return model\n\n\ndef consolidate_model_parallel_part1(\n    metadata, names, strict, weights, parts=2, no_stitch_megatron=False\n):\n    model_parts = defaultdict(list)\n    metadata_parts = defaultdict(list)\n    for i, n in 
enumerate(names):\n        for p in range(parts):\n            if f\"part-{p}\" in n:\n                model_parts[p].append(weights[i])\n                metadata_parts[p].append(metadata[i])\n    all_parts_consolidated = defaultdict(list)\n    for k, v in tqdm(model_parts.items()):\n        print(f\"Consolidate shards associated with part: {k}, with {len(v)} shards...\")\n        part_weights = consolidate_shard_weights(\n            shard_weights=v, shard_metadata=metadata_parts[k], strict=strict\n        )\n        all_parts_consolidated[k] = part_weights\n    return all_parts_consolidated\n\n\ndef consolidate_model_parallel_part2(all_parts_consolidated):\n    model = glue_megatron_parts(all_parts_consolidated)\n    return model\n\ndef handle_qkv_proj(model_parts, key):\n    parts = [model_parts[part_id][key] for part_id in range(len(model_parts))]\n    ks, vs, qs = [], [], []\n    for p in parts:\n        k, v, q = torch.split(p, p.shape[0] // 3)\n        ks.append(k)\n        vs.append(v)\n        qs.append(q)\n    return torch.cat(ks, dim=0), torch.cat(vs, dim=0), torch.cat(qs, dim=0)\n\n\ndef _handle_one(parts, is_weight):\n    \"\"\"Make it look like a normal LayerNorm\"\"\"\n    n_parts = len(parts)\n    err_msg = f\"Redundant ModelParallelFusedLayerNorm params have been updated.\"\n    if is_weight:\n        init = 1.0\n        assert not torch.logical_and(parts[0].ne(1), parts[1].ne(1)).any(), err_msg\n\n    else:\n        init = 0.0\n        assert not torch.logical_and(parts[0].ne(0), parts[1].ne(0)).any(), err_msg\n    ret_val = torch.cat([p.unsqueeze(-1) for p in parts], dim=1).sum(1) - (\n        init * (n_parts - 1)\n    )\n    return ret_val\n\n\ndef handle_legacy_ln_(glued_model, n_parts):\n    \"\"\"Consolidate ffn_layernorm.lns.weight.{part_id} -> ffn_layernorm.weight\"\"\"\n    if \"decoder.layers.0.ffn_layernorm.lns.0.weight\" not in glued_model:\n        return\n    n_layers = get_n_layers(glued_model)\n    for i in range(n_layers):\n        
layer_weights = [\n            glued_model.pop(f\"decoder.layers.{i}.ffn_layernorm.lns.{p}.weight\")\n            for p in range(n_parts)\n        ]\n        layer_biases = [\n            glued_model.pop(f\"decoder.layers.{i}.ffn_layernorm.lns.{p}.bias\")\n            for p in range(n_parts)\n        ]\n        glued_model[f\"decoder.layers.{i}.ffn_layernorm.weight\"] = _handle_one(\n            layer_weights, True\n        )\n        glued_model[f\"decoder.layers.{i}.ffn_layernorm.bias\"] = _handle_one(\n            layer_biases, False\n        )\n\n\ndef get_n_layers(glued_model):\n    n_layers = 0\n    while True:\n        if f\"decoder.layers.{n_layers}.fc1.weight\" in glued_model:\n            n_layers += 1\n        else:\n            assert (\n                n_layers > 0\n            ), f\"found 0 layers bc no keys matching decoder.layers.0.fc1.weight\"\n            return n_layers\n\n\ndef glue_megatron_parts(model_parts):\n    glued_model = OrderedDict()\n\n    def assert_all_close(key):\n        for part_id in range(len(model_parts)):\n            if not torch.allclose(model_parts[part_id][key], model_parts[0][key]):\n                err = (\n                    (model_parts[part_id][key] - model_parts[0][key])\n                    .float()\n                    .abs()\n                    .max()\n                    .item()\n                )\n                logger.info(f\"max discrepancy {key}: {err}\")\n\n    for key in model_parts[0]:\n        print(f\"Glue the key {key}...\")\n        if \"qkv\" in key:\n            # Bias of CP gets concatenated\n            if key.endswith(\"bias\"):\n                k, v, q = handle_qkv_proj(model_parts, key)\n            else:\n                assert key.endswith(\"weight\")\n                k, v, q = handle_qkv_proj(model_parts, key)\n            glued_model[key.replace(\"qkv\", \"k\")] = k\n            glued_model[key.replace(\"qkv\", \"v\")] = v\n            glued_model[key.replace(\"qkv\", \"q\")] = q\n       
 elif \"ffn_layernorm\" in key:\n            glued_model[key] = torch.cat(\n                [model_parts[part_id][key] for part_id in range(len(model_parts))]\n            )\n\n        elif \"layer_norm\" in key:\n            assert_all_close(key)\n            glued_model[key] = model_parts[0][key]\n        elif \"fc1\" in key or \"k_proj\" in key or \"q_proj\" in key or \"v_proj\" in key:\n            # Bias of CP gets concatenated\n            if key.endswith(\"bias\"):\n                glued_bias = torch.cat(\n                    [model_parts[part_id][key] for part_id in range(len(model_parts))]\n                )\n                glued_model[key] = glued_bias\n            # weights of CP gets concatenated along dim 0\n            else:\n                assert key.endswith(\"weight\")\n                glued_weight = torch.cat(\n                    [model_parts[part_id][key] for part_id in range(len(model_parts))],\n                    dim=0,\n                )\n                glued_model[key] = glued_weight\n                # FC1 is CP\n        # FC2 is RP\n        elif \"fc2\" in key or \"out_proj\" in key:\n            # Bias of RP gets replicated\n            if key.endswith(\"bias\"):\n                assert_all_close(key)\n                glued_model[key] = model_parts[0][key]\n            # weights of RP gets concatenated along dim 1\n            else:\n                assert key.endswith(\"weight\")\n                glued_weight = torch.cat(\n                    [model_parts[part_id][key] for part_id in range(len(model_parts))],\n                    dim=1,\n                )\n                glued_model[key] = glued_weight\n        elif \"embed_tokens.weight\" in key:\n            glued_weight = torch.cat(\n                [model_parts[part_id][key] for part_id in range(len(model_parts))],\n                dim=0,\n            )\n            glued_model[key] = glued_weight\n        elif \"embed_positions\" in key:\n            if \"_float_tensor\" in 
key:\n                # Assume embed positions are non-learned, i.e. sinusoidal\n                glued_model[key] = torch.zeros([1])\n            else:\n                assert_all_close(key)\n                glued_model[key] = model_parts[0][key]\n        elif \"version\" in key:\n            glued_model[key] = model_parts[0][key]\n        else:\n            assert_all_close(key)\n            glued_model[key] = model_parts[0][key]\n\n    assert len(glued_model.keys()) >= len(model_parts[0].keys())\n    # Consolidate ffn_layernorm.lns.weight.{part_id} -> ffn_layernorm.weight\n    handle_legacy_ln_(glued_model, len(model_parts))\n    assert \"decoder.layers.0.ffn_layernorm.lns.0.weight\" not in glued_model\n    print(\"- Done with consolidating model parallelism parts. See a summary below:\")\n    for key in glued_model:\n        print(f\"    key: {key}, shape: {glued_model[key].shape}\")\n    return glued_model\n\n\ndef find_num_parts(names) -> int:\n    parts = []\n    for n in names:\n        part = re.search(r\"part-(\\d+)-\", n)\n        if part is not None:\n            parts.append(int(part.groups()[0]))\n    if parts:\n        return max(parts) + 1\n    else:\n        return 0\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--read-prefix\", type=str, default=\"checkpoint_last\")\n    parser.add_argument(\"--save-prefix\", type=str, default=\"consolidated\")\n    parser.add_argument(\"--new-arch-name\", type=str, default=\"transformer_lm_gpt\")\n    args = parser.parse_args()\n    consolidate_fsdp_shards(args.read_prefix,\n                            save_prefix=args.save_prefix,\n                            new_arch_name=args.new_arch_name)\n"
  },
  {
    "path": "examples/llm_serving/scripts/step_3_convert_to_numpy_weights.py",
    "content": "\"\"\"Convert Metaseq's OPT model weights into Alpa numpy weights.\"\"\"\nimport time\n\nimport argparse\nimport os\n\nimport numpy as np\nfrom llm_serving.scripts.utils import torch_load_cpu\n\n\ndef save_numpy(weight_dict, to_folder):\n    os.makedirs(to_folder, exist_ok=True)\n    for tensor_name, tensor in weight_dict.items():\n        print(f\"- Writing tensor {tensor_name} with shape {tensor.shape}\")\n        t = tensor.cpu().detach().numpy()\n        with open(to_folder + \"/\" + tensor_name, \"wb\") as g:\n            np.save(g, t)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--ckpt-path\", type=str, default=\"/home/ubuntu/consolidated\")\n    parser.add_argument(\"--output-folder\", type=str, default=\"/home/ubuntu/opt-175b-np\")\n    args = parser.parse_args()\n    start_time = time.time()\n    print(\"- Reading the weight into memory\")\n    state = torch_load_cpu(args.ckpt_path)\n    print(f\"Done with reading: {time.time() - start_time} seconds\")\n    save_numpy(state[\"model\"], args.output_folder)\n"
  },
  {
    "path": "examples/llm_serving/scripts/utils.py",
    "content": "import torch\nfrom omegaconf.dictconfig import DictConfig\n\n\ndef recursively_cast_dictconfigs(cfg):\n    if isinstance(cfg, DictConfig):\n        return {k2: recursively_cast_dictconfigs(v2) for k2, v2 in cfg.items()}\n    else:\n        return cfg\n\n\ndef torch_load_cpu(path):\n    state = torch.load(path, map_location=torch.device(\"cpu\"))\n    # If model was trained with fp16, model from loaded state_dict can be moved to fp16\n    if not isinstance(state, dict):\n        return state\n    if \"cfg\" in state:\n        state[\"cfg\"] = recursively_cast_dictconfigs(state[\"cfg\"])\n        if (\n            state[\"cfg\"][\"common\"][\"fp16\"]\n            or state[\"cfg\"][\"common\"][\"memory_efficient_fp16\"]\n        ):\n            state[\"model\"] = {k: v.half() for k, v in state[\"model\"].items()}\n\n    return state\n\n\ndef load_and_pop_last_optimizer_state(pth):\n    st = torch_load_cpu(pth)\n    st.pop(\"last_optimizer_state\", None)\n    return st\n"
  },
  {
    "path": "examples/llm_serving/service/__init__.py",
    "content": ""
  },
  {
    "path": "examples/llm_serving/service/constants.py",
    "content": "\"\"\"Hyper params for serving Meta's OPT model.\"\"\"\nfrom enum import Enum\n\n# Alpa serve url\nALPA_SERVE_PORT = 20001\nALPA_SERVE_URL = f\"window.location.protocol + '//' + window.location.hostname + ':{ALPA_SERVE_PORT}/completions'\"\n#ALPA_SERVE_URL = f'\"completions\"'\n\n# Generation params\nNUM_BEAMS = 1\nNUM_RETURN_SEQ = 1\n\n# Authentication params\nUSE_RECAPTCHA = False\nUSE_API_KEYS = False\nALLOW_NON_KEY_ACCESS = True\nKEYS_FILENAME = \"/home/ubuntu/efs/alpa/examples/llm_serving/keys_file.json\"\n\n# Scheduler params\nclass AuthGroups(Enum):\n    RECAPTCHA_USER = 1\n    API_KEY_USER = 2\n    NON_KEY_USER = 3\n\nAUTH_GROUP_WEIGHTS = {\n    AuthGroups.RECAPTCHA_USER: 300,\n    AuthGroups.API_KEY_USER: 10,\n    AuthGroups.NON_KEY_USER: 1\n}\nAUTH_GROUP_SCHEDULER_SCALE = 300\nAPI_KEY_SCHEDULER_SCALE = 100\nAPI_KEY_DEFAULT_WEIGHT = 10\nLOGPROBS_PRIORITY_TIME_LIMIT_S = 15\n\n# Logging params\nLOGDIR = \"weblogs\"\n"
  },
  {
    "path": "examples/llm_serving/service/recaptcha.py",
    "content": "\"\"\"\nAdapted from https://github.com/mardix/flask-recaptcha\n\nThe new Google ReCaptcha implementation for Flask without Flask-WTF\nCan be used as standalone\n\"\"\"\n\n__NAME__ = \"Flask-ReCaptcha\"\n__version__ = \"0.5.0\"\n__license__ = \"MIT\"\n__author__ = \"Mardix\"\n__copyright__ = \"(c) 2015 Mardix\"\n\nimport json\n\n#from flask import request\ntry:\n    from jinja2 import Markup\nexcept ImportError:\n    from jinja2.utils import markupsafe\n    Markup = markupsafe.Markup\nimport requests\n\nfrom llm_serving.service.constants import USE_RECAPTCHA, KEYS_FILENAME\n\n\nclass DEFAULTS(object):\n    IS_ENABLED = True\n    THEME = \"light\"\n    TYPE = \"image\"\n    SIZE = \"normal\"\n    LANGUAGE = \"en\"\n    TABINDEX = 0\n\n\nclass ReCaptcha(object):\n\n    VERIFY_URL = \"https://www.recaptcha.net/recaptcha/api/siteverify\"\n\n    def __init__(self, app=None, site_key=None, secret_key=None, is_enabled=True, **kwargs):\n        if app:\n            self.init_app(app=app)\n        else:\n            self.site_key = site_key\n            self.secret_key = secret_key\n            self.is_enabled = is_enabled\n            self.theme = kwargs.get('theme', DEFAULTS.THEME)\n            self.type = kwargs.get('type', DEFAULTS.TYPE)\n            self.size = kwargs.get('size', DEFAULTS.SIZE)\n            self.language = kwargs.get('language', DEFAULTS.LANGUAGE)\n            self.tabindex = kwargs.get('tabindex', DEFAULTS.TABINDEX)\n\n    def init_app(self, app=None):\n        self.__init__(site_key=app.config.get(\"RECAPTCHA_SITE_KEY\"),\n                      secret_key=app.config.get(\"RECAPTCHA_SECRET_KEY\"),\n                      is_enabled=app.config.get(\"RECAPTCHA_ENABLED\", DEFAULTS.IS_ENABLED),\n                      theme=app.config.get(\"RECAPTCHA_THEME\", DEFAULTS.THEME),\n                      type=app.config.get(\"RECAPTCHA_TYPE\", DEFAULTS.TYPE),\n                      size=app.config.get(\"RECAPTCHA_SIZE\", DEFAULTS.SIZE),\n          
            language=app.config.get(\"RECAPTCHA_LANGUAGE\", DEFAULTS.LANGUAGE),\n                      tabindex=app.config.get(\"RECAPTCHA_TABINDEX\", DEFAULTS.TABINDEX))\n\n        @app.context_processor\n        def get_code():\n            return dict(recaptcha=self.get_code())\n\n    def get_code(self):\n        \"\"\"\n        Returns the new ReCaptcha code\n        :return:\n        \"\"\"\n        raw = \"\" if not self.is_enabled else (\"\"\"\n        <script src='//www.recaptcha.net/recaptcha/api.js?hl={LANGUAGE}'></script>\n        <div class=\"g-recaptcha\" data-sitekey=\"{SITE_KEY}\" data-theme=\"{THEME}\" data-type=\"{TYPE}\" data-size=\"{SIZE}\"\\\n         data-tabindex=\"{TABINDEX}\"></div>\n        \"\"\".format(SITE_KEY=self.site_key, THEME=self.theme, TYPE=self.type, SIZE=self.size, LANGUAGE=self.language, TABINDEX=self.tabindex))\n        return Markup(raw)\n\n    def verify(self, response=None, remote_ip=None):\n        if self.is_enabled:\n            data = {\n                \"secret\": self.secret_key,\n                \"response\": response,# or request.json.get('g-recaptcha-response', \"\"),\n                \"remoteip\": remote_ip,# or request.environ.get('REMOTE_ADDR')\n            }\n\n            r = requests.get(self.VERIFY_URL, params=data)\n            return r.json()[\"success\"] if r.status_code == 200 else False\n        return True\n\n\ndef load_recaptcha(use_recaptcha):\n    if use_recaptcha:\n        keys = json.load(open(KEYS_FILENAME, \"r\"))\n        recaptcha = ReCaptcha(site_key=keys[\"RECAPTCHA_SITE_KEY\"],\n                              secret_key=keys[\"RECAPTCHA_SECRET_KEY\"])\n    else:\n        recaptcha = ReCaptcha(is_enabled=False)\n    return recaptcha\n"
  },
  {
    "path": "examples/llm_serving/service/scheduler.py",
    "content": "import asyncio\nimport heapq\nfrom collections import deque, OrderedDict\n\n\nclass WeightedRoundRobin:\n    \"\"\"\n    Scheduler that cycles between queues of different weightings.\n    The interface is the same as it were a queue implemented using deque().\n    This implementation extends the original algorithm by allowing non-integer\n    priorities. All weights in this class are implicitly divided by a scale\n    factor - if all the queue weights are integer multiples of the scale\n    factor, the algorithm behaves just like standard weighted round robin.\n    Using smaller weights makes the scheduler switch between queues more\n    frequently, improving latency.\n    \"\"\"\n    # The scheduling algorithm is implemented using an event list. Each queue\n    # is associated with an hourglass that fills up a certain fraction every\n    # time step. When the hourglass is filled, a task is scheduled from the\n    # corresponding queue. An hourglass is allowed to be filled faster than\n    # 100% per time step - in this case, tasks are consecutively scheduled\n    # from the same queue until the hourglass is no longer full.\n\n    class Hourglass:\n        def __init__(self, update_time, amnt_filled):\n            self.update_time = update_time\n            self.amnt_filled = amnt_filled\n            self.linked_tasks = deque()\n\n        def __repr__(self):\n            return '({}, {}, {})'.format(\n                self.update_time, self.amnt_filled, list(self.linked_tasks))\n\n    def __init__(self, weights, scale, default_weight=None,\n                 max_empty_hourglasses=100):\n        self.weights = weights\n        self.default_weight = default_weight\n        self.scale = scale\n        self.max_empty_hourglasses = max_empty_hourglasses\n        self.curr_item_num = 0\n        self.curr_simulated_time = 0\n        self.tasks = {}\n        self.hourglasses = {}\n        self.event_list = []\n        self.empty_hourglasses = 
OrderedDict()\n\n    def __len__(self):\n        return len(self.tasks)\n\n    def append(self, name_and_item):\n        queue_name, item = name_and_item\n        self.tasks[self.curr_item_num] = item\n        new_event = False\n        if queue_name in self.empty_hourglasses:\n            self.hourglasses[queue_name] = self.empty_hourglasses[queue_name]\n            del self.empty_hourglasses[queue_name]\n            new_event = True\n        if queue_name not in self.hourglasses:\n            self.hourglasses[queue_name] = \\\n                WeightedRoundRobin.Hourglass(0, 0)\n            new_event = True\n        hourglass = self.hourglasses[queue_name]\n        hourglass.linked_tasks.append(self.curr_item_num)\n        if new_event:\n            hourglass.update_time = self.curr_simulated_time\n            self.__add_new_event(hourglass, queue_name)\n        self.curr_item_num += 1\n\n    def extend(self, items):\n        for item in items:\n            self.append(item)\n\n    def popleft(self):\n        event_entry = heapq.heappop(self.event_list)\n        queue_name = event_entry[2]\n        hourglass = self.hourglasses[queue_name]\n        if hourglass.amnt_filled >= self.scale:\n            hourglass.amnt_filled -= self.scale\n        else:\n            self.curr_simulated_time = event_entry[0]\n            weight = self.weights.get(queue_name, self.default_weight)\n            if weight is None:\n                raise KeyError\n            hourglass.amnt_filled += (\n                self.curr_simulated_time - hourglass.update_time) * weight\n            hourglass.amnt_filled -= self.scale\n        hourglass.update_time = self.curr_simulated_time\n        task_num = hourglass.linked_tasks.popleft()\n        task = self.tasks.pop(task_num)\n        if len(hourglass.linked_tasks) == 0:\n            del self.hourglasses[queue_name]\n            self.empty_hourglasses[queue_name] = hourglass\n            if len(self.empty_hourglasses) > 
self.max_empty_hourglasses:\n                self.empty_hourglasses.popitem(last=False)\n        else:\n            self.__add_new_event(hourglass, queue_name)\n        return (queue_name, task)\n\n    def __add_new_event(self, hourglass, queue_name):\n        if hourglass.amnt_filled >= self.scale:\n            event_time = self.curr_simulated_time\n            event_entry = (event_time, hourglass.linked_tasks[0], queue_name)\n            heapq.heappush(self.event_list, event_entry)\n        else:\n            weight = self.weights.get(queue_name, self.default_weight)\n            if weight is None:\n                raise KeyError\n            time_to_full = (\n                self.scale - hourglass.amnt_filled + weight - 1) // weight\n            event_time = self.curr_simulated_time + time_to_full\n            event_entry = (event_time, hourglass.linked_tasks[0], queue_name)\n            heapq.heappush(self.event_list, event_entry)\n\n    def verify_state(self):\n        \"\"\"Checks the invariants of the class\"\"\"\n        task_nums = []\n        try:\n            assert len(self.event_list) == 0 or \\\n                self.curr_simulated_time <= self.event_list[0][0]\n            for queue_name, hourglass in self.hourglasses.items():\n                assert len(hourglass.linked_tasks) > 0\n                for task_num in hourglass.linked_tasks:\n                    assert task_num in self.tasks\n                assert hourglass.amnt_filled >= 0\n                assert queue_name not in self.empty_hourglasses\n                task_nums += list(hourglass.linked_tasks)\n                if hourglass.amnt_filled >= self.scale:\n                    assert self.event_list[0][0] == self.curr_simulated_time\n                    assert self.curr_simulated_time == hourglass.update_time\n            for hourglass in self.empty_hourglasses.values():\n                assert len(hourglass.linked_tasks) == 0\n                assert hourglass.amnt_filled >= 0\n            
assert sorted(task_nums) == sorted(list(self.tasks.keys()))\n        except AssertionError as e:\n            e.args += (repr(self),)\n            raise e\n\n    def __repr__(self):\n        return \"Tasks: {}\\nEvent list: {}\\nHourglasses: {}\\nTime: {}\".format(\n            self.tasks, self.event_list, self.hourglasses,\n            self.curr_simulated_time)\n\n\nclass NestedScheduler:\n    \"\"\"\n    Scheduler where each queue is an independent inner scheduler object.\n    This can be used to implement hierarchies of weights and queues.\n    \"\"\"\n    def __init__(self, outer_scheduler, inner_schedulers):\n        self.outer_scheduler = outer_scheduler\n        self.inner_schedulers = inner_schedulers\n\n    def __len__(self):\n        return len(self.outer_scheduler)\n\n    def append(self, name_and_item):\n        name, item = name_and_item\n        self.outer_scheduler.append((name, None))\n        self.inner_schedulers[name].append(item)\n\n    def extend(self, items):\n        for item in items:\n            self.append(item)\n\n    def popleft(self):\n        name = self.outer_scheduler.popleft()[0]\n        return (name, self.inner_schedulers[name].popleft())\n\n    def __repr__(self):\n        return '\\n'.join(\n            ['Outer: ' + repr(self.outer_scheduler)] +\n            [repr(name) + ': ' + repr(s)\n             for (name, s) in self.inner_schedulers.items()])\n\n\nclass FrontQueueScheduler:\n    \"\"\"\n    Scheduler decorator that allows tasks to be placed at the front of the\n    queue. The front behaves like the front of a deque(), i.e. 
it is LIFO.\n    \"\"\"\n    def __init__(self, scheduler):\n        self.scheduler = scheduler\n        self.front_queue = deque()\n\n    def __len__(self):\n        return len(self.front_queue) + len(self.scheduler)\n\n    def append(self, item):\n        self.scheduler.append(item)\n\n    def extend(self, items):\n        for item in items:\n            self.append(item)\n\n    def popleft(self):\n        if len(self.front_queue) > 0:\n            return self.front_queue.popleft()\n        return self.scheduler.popleft()\n\n    def appendleft(self, item):\n        self.front_queue.appendleft(item)\n\n    def extendleft(self, items):\n        self.front_queue.extendleft(items)\n\n    def __repr__(self):\n        return \"Front queue:{}\\n{}\".format(self.front_queue, self.scheduler)\n\n\nclass AsyncWrapper:\n    \"\"\"\n    Decorator that makes a scheduler object behave like an asyncio.Queue().\n    \"\"\"\n    def __init__(self, scheduler):\n        self.schedule_waitlist = asyncio.Queue()\n        self.scheduler = scheduler\n\n    @property\n    def maxsize(self):\n        return 0\n\n    def qsize(self):\n        return len(self.scheduler) + self.schedule_waitlist.qsize()\n\n    def empty(self):\n        return len(self.scheduler) == 0 and self.schedule_waitlist.empty()\n\n    def full(self):\n        return False\n\n    async def put(self, item):\n        await self.schedule_waitlist.put((item, None))\n\n    def put_nowait(self, item):\n        self.schedule_waitlist.put_nowait((item, None))\n\n    async def get(self):\n        if self.empty():\n            self.__process_waitlist_item(await self.schedule_waitlist.get())\n        while not self.schedule_waitlist.empty():\n            self.__process_waitlist_item(\n                self.schedule_waitlist.get_nowait())\n        return self.scheduler.popleft()\n\n    def get_nowait(self):\n        if self.empty():\n            raise asyncio.QueueEmpty\n        while not self.schedule_waitlist.empty():\n           
 self.__process_waitlist_item(self.schedule_waitlist.get_nowait())\n        return self.scheduler.popleft()\n\n    def __process_waitlist_item(self, waitlist_item):\n        data, strategy = waitlist_item\n        if strategy is None:\n            self.scheduler.append(data)\n        else:\n            strategy(self.scheduler, data)\n\n    def task_done(self):\n        self.schedule_waitlist.task_done()\n\n    async def join(self):\n        await self.schedule_waitlist.join()\n\n    def put_nowait_special(self, strategy, data):\n        \"\"\"Must add exactly one item into the schedule\"\"\"\n        self.schedule_waitlist.put_nowait((data, strategy))\n\n    def __repr__(self):\n        return repr(self.scheduler)\n"
  },
  {
    "path": "examples/llm_serving/service/static/index.html",
    "content": "<html lang=\"en\">\n<head>\n    <meta charset=\"utf-8\">\n    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1\">\n    <title>Serving OPT-175B Language Model with Alpa</title>\n\n    <link rel=\"icon\" type=\"image/x-icon\" href=\"https://raw.githubusercontent.com/alpa-projects/alpa/main/docs/logo/alpa-logo.ico\">\n    <script src=\"//code.jquery.com/jquery-1.11.0.min.js\"></script>\n    <script src=\"https://ajax.googleapis.com/ajax/libs/jquery/3.4.1/jquery.min.js\"></script>\n    <script src=\"https://cdn.jsdelivr.net/npm/bootstrap@5.2.0-beta1/dist/js/bootstrap.bundle.min.js\" integrity=\"sha384-pprn3073KE6tl6bjs2QrFaJGz5/SUsLqktiwsUTF55Jfv3qYSDhgCecCxMW52nD2\" crossorigin=\"anonymous\"></script>\n    <script async defer src=\"https://buttons.github.io/buttons.js\"></script>\n    <script src=\"https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.min.js\"></script>\n    <script src=\"https://superal.github.io/canvas2image/canvas2image.js\"></script>\n    <link href=\"https://cdn.jsdelivr.net/npm/bootstrap@5.2.0-beta1/dist/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-0evHe/X+R7YkIZDRvuzKMRqM+OrBnVFBL6DOitfPri4tjfHxaWutUpFmBp4vmVor\" crossorigin=\"anonymous\">\n    <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.1.1/css/all.min.css\">\n    <link rel=\"stylesheet\" href=\"https://cdnjs.cloudflare.com/ajax/libs/bootstrap-social/5.1.1/bootstrap-social.css\">\n\n    <link rel=\"stylesheet\" type=\"text/css\" href=\"//cdn.jsdelivr.net/npm/slick-carousel@1.8.1/slick/slick.css\"/>\n    <script src=\"https://cdnjs.cloudflare.com/ajax/libs/slick-carousel/1.9.0/slick.min.js\"></script>\n\n    <script type=\"text/javascript\">\n    // these constants are only used for providing user expectations.\n    var OVERHEAD = 1;\n    var PROMPT_TOKEN_PER_SECOND = 40;\n    var DECODING_TOKEN_PER_SECOND = 4;\n\n    // examples for the user\n    var EXAMPLES = {\n        \"fact\": 
{\n            \"prompt\": \"Question: Where were the 2004 Olympics held?\\nAnswer: Athens, Greece\\n\\nQuestion: What is the longest river on the earth?\\nAnswer:\",\n            \"length\": 64\n        },\n        \"chatbot\": {\n            \"prompt\": \"A chat between a curious human and the Statue of Liberty.\\n\\n\" +\n                \"Human: What is your name?\\nStatue: I am the Statue of Liberty.\\n\" +\n                \"Human: Where do you live?\\nStatue: New York City.\\n\" +\n                \"Human: How long have you lived there?\",\n            \"length\": 64\n        },\n        \"airport\": {\n            \"prompt\": \"Extract the airport codes from this text.\\n\\n\" +\n                \"Text: \\\"I want a flight from New York to San Francisco.\\\"\\n\" +\n                \"Airport codes: JFK, SFO.\\n\\n\" +\n                \"Text: \\\"I want you to book a flight from Phoenix to Las Vegas.\\\"\\n\" +\n                \"Airport codes:\",\n            \"length\": 64\n        },\n        \"translation\": {\n            \"prompt\": \"English: I want to go home.\\nChinese: 我想回家。\\n\\n\" +\n                      \"English: I don't know.\\nChinese: 我不知道。\\n\\n\" +\n                      \"English: I am hungry.\\nChinese:\",\n            \"length\": 64\n        },\n        \"cryptocurrency\": {\n            \"prompt\": \"Every year, cryptocurrency experts prepare forecasts for the price of Dogecoin. In 2025, it is estimated that DOGE will\",\n            \"length\": 64\n        },\n        \"programming\": {\n            \"prompt\":\n                \"def fib(k):\\n\" +\n                \"    \\\"\\\"\\\"Returns the k-th Fibonacci number. 
Check the corner cases.\\\"\\\"\\\"\",\n            \"length\": 64\n        },\n        \"math\": {\n            \"prompt\": \"Question: If x is 2 and y is 5, what is x + 2y?\\n\" +\n                      \"Answer: x + 2y = 2 + 2(5) = 2 + 10 = 12\\n\\n\" +\n                      \"Question: If x is 8 and y is 9, what is 3x + y?\\n\" +\n                      \"Answer: 3x + y = 3(8) + 9 = 24 + 9 = 33\\n\\n\" +\n                      \"Question: If x is 7 and y is 6, what is x + 4y?\\n\" +\n                      \"Answer:\",\n            \"length\": 64\n        }\n    };\n\n    function getFormData($form) {\n        var unindexed_array = $form.serializeArray();\n        var indexed_array = {};\n        $.map(unindexed_array, function(n, i){\n            indexed_array[n['name']] = n['value'].replace(\"\\r\\n\", \"\\n\");\n        });\n        indexed_array['model'] = \"default\"\n        return indexed_array;\n    }\n\n    function set_prompt(name) {\n        $(\"#length_slider\").val(EXAMPLES[name][\"length\"]);\n        $(\"#length_slider_output\").text(EXAMPLES[name][\"length\"]);\n        $(\"#textbox\").val(EXAMPLES[name][\"prompt\"]);\n    }\n\n    function takeshot() {\n      let div = document.getElementById('generation');\n      html2canvas(div).then(\n      function (canvas) {\n            // return Canvas2Image.saveAsPNG(canvas);\n                    var url = canvas.toDataURL();\n                      $(\"<a>\", {\n                        href: url,\n                        download: \"my-opt175b-result\"\n                      })\n                      .on(\"click\", function() {$(this).remove()})\n                      .appendTo(\"body\")[0].click()\n        });\n    }\n\n    function test() {\n        $(\"#promptDisplay\").text(\"A chat between a professor and a graduate student in Computer Science.\\n\\nStudent: Which is the best Computer Science graduate school in the US? 
UC Berkeley or CMU?\\nProfessor: \");\n        $(\"#promptDisplay\").text(\"def fib(n):\\n\" +\n                \"    Returns n-th Fibonacci number.\"\n        )\n        $(\"#response\").text(\"Sorry I don't know\\n\");\n        $(\"#error\").text(\"\");\n    }\n\n    $(document).ready(function() {\n      $('.logo-carousel').slick({\n        slidesToShow: 4,\n        slidesToScroll: 1,\n        autoplay: true,\n        autoplaySpeed: 5000,\n        arrows: true,\n        dots: false,\n        pauseOnHover: false,\n        responsive: [{\n          breakpoint: 768,\n          settings: {\n            slidesToShow: 4\n          }\n        }, {\n          breakpoint: 520,\n          settings: {\n            slidesToShow: 2\n          }\n        }]\n      });\n    });\n\n    // actual logic\n    $(document).ready(function() {\n      $(\"#generate-form\").submit(function(event) {\n        event.preventDefault();\n        var prompt_length = $(\"#textbox\").val().split(' ').length;\n        var length = parseInt($(\"#length_slider\").val());\n        var eta = (prompt_length / PROMPT_TOKEN_PER_SECOND + length / DECODING_TOKEN_PER_SECOND + OVERHEAD).toFixed(1);\n        $(\"#eta\").text(eta);\n        $(\"#loader_holder\").css(\"visibility\", \"visible\");\n        $(\"#generate-form-button\").prop(\"disabled\", true);\n        $(\"#error\").text(\"\");\n        var formData = getFormData($(\"form\"));\n        var submitData = JSON.stringify(formData)\n        console.log(\"submitData:\");\n        console.log(submitData);\n        $.ajax({\n            url: {{alpa_serve_url | safe}},\n            type: \"POST\",\n            processData: true,\n            contentType: \"application/json\",\n            data: submitData,\n            complete: function () {\n                $(\"#loader_holder\").css(\"visibility\", \"hidden\");\n                $(\"#generate-form-button\").prop(\"disabled\", false);\n                grecaptcha.reset();\n            },\n            
success: function (data) {\n                console.log(\"Response:\");\n                console.log(data);\n                for (let i = 0; i < data[\"choices\"].length; ++i) {\n                  $(\"#promptDisplay\", \"#result\" + i + \"-content\").text(formData[\"prompt\"]);\n                  $(\"#response\", \"#result\" + i + \"-content\").text(data[\"choices\"][i][\"text\"]);\n                  $(\"#error\", \"#result\" + i + \"-content\").text(\"\");\n                }\n            },\n            error: function (xhr) {\n                console.log(\"Error:\");\n                console.log(xhr);\n                $(\"#promptDisplay\").text(\"\");\n                $(\"#response\").text(\"\");\n                if (\"responseJSON\" in xhr) {\n                  msg = \"Error: \" + xhr.responseJSON.message;\n                  if (msg.includes(\"No replica of model\") ||\n                      msg.includes(\"is not registered\") ||\n                      msg.includes(\"object has no attribute\")) {\n                    msg += \"\\nThe server is probably under regular maintenance. \" +\n                           \"Please come back later.\";\n                  }\n                  $(\"#error\").text(msg);\n                } else {\n                  $(\"#error\").text(\n                      \"Cannot connect to the server due to unknown errors. \" +\n                      \"\\nThe server is probably under regular maintenance. 
\" +\n                      \"Please come back later.\");\n                }\n            }\n        });\n      });\n    });\n    </script>\n</head>\n\n<style>\n/***** For logo slider *****/\n.slick-slide {\n  margin: 0px 20px;\n}\n\n.logo-carousel {\n  overflow: inherit;\n  margin-top: 32px;\n  /*border-top: 1px solid #353535;*/\n  /*border-bottom: 1px solid #353535;*/\n}\n\n.slick-slide img {\n  width: 100%;\n}\n\n.slick-loading {\n  visibility: hidden;\n}\n\n.slick-slide.slick-loading img {\n  display: none;\n}\n\n.slick-slide.dragging img {\n  pointer-events: none;\n}\n\n.slick-loading .slick-slide {\n  visibility: hidden;\n}\n\n.slick-arrow {\n  position: absolute;\n  top: 50%;\n  background: url(https://raw.githubusercontent.com/solodev/infinite-logo-carousel/master/images/arrow.svg?sanitize=true) center no-repeat;\n  color: #fff;\n  filter: invert(77%) sepia(32%) saturate(1%) hue-rotate(344deg) brightness(105%) contrast(103%);\n  border: none;\n  width: 2rem;\n  height: 1.5rem;\n  text-indent: -10000px;\n  margin-top: -16px;\n  z-index: 99;\n}\n\n.slick-arrow.slick-next {\n  right: -40px;\n  transform: rotate(180deg);\n}\n\n.slick-arrow.slick-prev {\n  left: -40px;\n}\n\n@media (max-width: 768px) {\n  .slick-arrow {\n    width: 1rem;\n    height: 1rem;\n  }\n}\n\n.row {\n  overflow: hidden;\n}\n\n/***** Prompt and result display *****/\n.result-block {\n    white-space: pre-wrap;\n    word-wrap: break-word;\n    clear: both;\n    min-height: 10em;\n}\n#promptDisplay {\n    font-weight: 600;\n}\n#error {\n    color: red;\n}\n\n/***** Loader *****/\n#loader_holder {\n    visibility: hidden;\n}\n\n#loaderInline {\n  display: inline-block;\n  vertical-align:middle;\n  width: 20px;\n  height: 20px;\n  margin-left: 5px;\n  margin-right: 5px;\n  border: 2px solid #f3f3f3;\n  border-radius: 50%;\n  border-top: 2px solid #3498db;\n  -webkit-animation: spin 2s linear infinite;\n  animation: spin 2s linear infinite;\n}\n\n@-webkit-keyframes spin {\n  0% {\n    
-webkit-transform: rotate(0deg);\n  }\n  100% {\n    -webkit-transform: rotate(360deg);\n  }\n}\n\n@keyframes spin {\n  0% {\n    transform: rotate(0deg);\n  }\n  100% {\n    transform: rotate(360deg);\n  }\n}\n\n.animate-bottom {\n  position: relative;\n  -webkit-animation-name: animatebottom;\n  -webkit-animation-duration: 1s;\n  animation-name: animatebottom;\n  animation-duration: 1s\n}\n\n@-webkit-keyframes animatebottom {\n  from {\n    bottom: -100px;\n    opacity: 0\n  }\n  to {\n    bottom: 0px;\n    opacity: 1\n  }\n}\n\n@keyframes animatebottom {\n  from {\n    bottom: -100px;\n    opacity: 0\n  }\n  to {\n    bottom: 0;\n    opacity: 1\n  }\n}\n</style>\n\n<!--  <body class=\"d-flex h-100 text-center text-dark bg-dark\">-->\n<body class=\"bg-white wy-text-center\">\n  <div class=\"container\">\n    <header class=\"d-flex flex-wrap justify-content-center py-2 mb-4 border-bottom\">\n     <!--\n      <a href=\"/\" class=\"d-flex align-items-center mb-3 mb-md-0 me-md-auto text-dark text-decoration-none\">\n        <img alt=\"alpa logo\" class=\"bi me-2\" width=\"40\" src=\"https://raw.githubusercontent.com/alpa-projects/alpa/main/docs/logo/alpa-logo-cropped.svg\">\n      </a>\n      -->\n      <ul class=\"nav nav-pills\">\n        <li class=\"nav-item\"><a href=\"#generation\" class=\"nav-link\" aria-current=\"page\">Generation</a></li>\n        <li class=\"nav-item\"><a href=\"#faq\" class=\"nav-link\">FAQs</a></li>\n        <li class=\"nav-item\"><a href=\"#contact\" class=\"nav-link\">Contact</a></li>\n        <li class=\"nav-item\"><a href=\"https://github.com/alpa-projects/alpa\" class=\"nav-link\" target=\"_blank\"><i class=\"fa-brands fa-github\"></i> GitHub</a></li>\n      </ul>\n    </header>\n  </div>\n\n  <div class=\"container my-5 text-center\">\n      <div class=\"py-5\">\n        <img alt=\"alpa logo\" width=\"200\" src=\"https://raw.githubusercontent.com/alpa-projects/alpa/main/docs/logo/alpa-logo-cropped.svg\">\n      </div>\n      <h1 
class=\"display-2\">Large Model for Everyone</h1>\n      <p class=\"lead mb-4\">\n          Alpa is a system for training and serving gigantic machine learning models.\n          <br />\n          Alpa makes training and serving large models like GPT-3 simple, affordable, accessible to everyone.\n      </p>\n      <div class=\"pt-2 pb-4\">\n        <iframe src=\"https://ghbtns.com/github-btn.html?user=alpa-projects&repo=alpa&type=star&count=true&size=large\"  width=\"170\" height=\"35\" title=\"GitHub\"></iframe>\n      </div>\n\n      <div class=\"d-grid gap-4 d-sm-flex justify-content-sm-center\">\n        <a href=\"#generation\" class=\"btn btn-primary px-4 btn-lg\">Try Live Generation (OPT-175B)</a>\n        <a href=\"https://alpa-projects.github.io/tutorials/opt_serving.html\" class=\"btn btn-outline-primary px-4 btn-lg\" target=\"_blank\">Host Your Own Service (OPT, BLOOM, CodeGen)</a>\n      </div>\n  </div>\n\n<!--<div class=\"bg-white\" id=\"generation\">-->\n<div class=\"container py-5\" id=\"generation\">\n    <div class=\"p-4 mb-3 bg-light rounded-4\">\n    <p>\n        <svg xmlns=\"http://www.w3.org/2000/svg\" fill=\"gold\" class=\"bi bi-lightning-fill\" style=\"width:5%;\" viewBox=\"0 0 20 20\">\n            <path d=\"M5.52.359A.5.5 0 0 1 6 0h4a.5.5 0 0 1 .474.658L8.694 6H12.5a.5.5 0 0 1 .395.807l-7 9a.5.5 0 0 1-.873-.454L6.823 9.5H3.5a.5.5 0 0 1-.48-.641l2.5-8.5z\"/>\n        </svg>\n        <strong class=\"display-6 fw-bold\">Free, Unlimited OPT-175B Text Generation</strong>\n    </p>\n    <p> <strong>Warning</strong>: This model might generate something offensive. No safety measures are in place as a free service. 
</p>\n<!--    <p id=\"examples\"> <strong>Examples: </strong> </p>-->\n    <div class=\"gap-2 d-sm-flex justify-content-sm-center\" style=\"line-height: 250%\">\n    <a type=\"button\" class=\"btn btn-outline-primary\" href='javascript:set_prompt(\"fact\");'><i class=\"fa-brands fa-wikipedia-w\"></i> Fact</a>\n    <a type=\"button\" class=\"btn btn-outline-secondary\" href='javascript:set_prompt(\"chatbot\");'><i class=\"fa-solid fa-robot\"></i> Chatbot</a>\n    <a type=\"button\" class=\"btn btn-outline-success\" href='javascript:set_prompt(\"airport\");'><i class=\"fa-solid fa-plane-departure\"></i> Airport Code</a>\n    <a type=\"button\" class=\"btn btn-outline-danger\" href='javascript:set_prompt(\"translation\");'><i class=\"fa-solid fa-language\"></i> Translation</a>\n    <a type=\"button\" class=\"btn btn-outline-warning\" href='javascript:set_prompt(\"cryptocurrency\");'><i class=\"fa-brands fa-bitcoin\"></i> Cryptocurrency</a>\n    <a type=\"button\" class=\"btn btn-outline-info\" href='javascript:set_prompt(\"programming\");'><i class=\"fa-solid fa-rocket\"></i> Code</a>\n    <a type=\"button\" class=\"btn btn-outline-dark\" href='javascript:set_prompt(\"math\");'><i class=\"fa-solid fa-calculator\"></i> Math</a>\n    </div>\n\n<form method=\"POST\" action=\"/generate\" id=\"generate-form\">\n    <div class=\"my-3\">\n        <label for=\"textbox\" class=\"form-label\"></label>\n        <textarea class=\"form-control\" style=\"font-size: 20px;\" name=\"prompt\" rows=\"8\" id=\"textbox\" placeholder=\"Type the prompts here\"></textarea>\n    </div>\n\n    <div class=\"form-group row\" data-html2canvas-ignore=\"true\">\n    <label for=\"length_slider\" class=\"col col-form-label text-end fw-bold\" style=\"white-space: nowrap;\">Response Length:</label>\n    <div class=\"col my-2\">\n        <input type=\"range\" value=\"64\" min=\"32\" max=\"256\" step=\"32\" class=\"form-range\"\n            oninput=\"this.parentNode.nextElementSibling.value = 
this.value\" name=\"max_tokens\"\n            id='length_slider'>\n    </div>\n    <output class='col col-form-label' id=\"length_slider_output\">64</output>\n    </div>\n\n    <div class=\"form-group row\" data-html2canvas-ignore=\"true\" style=\"{{sampling_css}}\">\n        <label for=\"temperature_slider\" class=\"col col-form-label text-end fw-bold\">Temperature:</label>\n        <div class=\"col my-2\">\n            <input type=\"range\" value=\"0.7\" min=\"0.0\" max=\"1.0\" step=\"0.10\" class=\"form-range\"\n            oninput=\"this.parentNode.nextElementSibling.value = this.value\" name=\"temperature\" id=\"temperature_slider\">\n        </div>\n            <output class='col col-form-label'>0.7</output>\n    </div>\n\n    <div class=\"form-group row\" data-html2canvas-ignore=\"true\" style=\"{{sampling_css}}\">\n        <label for=\"topp_slider\" class=\"col col-form-label text-end fw-bold\">Top-p:</label>\n        <div class=\"col my-2\">\n        <input type=\"range\" value=\"0.7\" min=\"{{ '0.1' if num_return_sequences > 1 else '0.0' }}\" max=\"1.0\" step=\"0.1\" class=\"form-range\"\n            oninput=\"this.parentNode.nextElementSibling.value = this.value\" name=\"top_p\" id=\"topp_slider\">\n        </div>\n       <output class='col col-form-label'>0.7</output>\n    </div>\n\n    <div>\n        {{ recaptcha }}\n        <input class=\"btn btn-primary btn-lg mt-2\" type=\"submit\" value=\"Generate\" id=\"generate-form-button\"/>\n        <div id=\"loader_holder\" style=\"display:inline; vertical-align:middle\">\n            <div id=\"loaderInline\"></div> Please be patient. Your generation may take <span id=\"eta\">X</span> seconds.  <!-- Each run may produce different results due to random sampling. 
-->\n        </div>\n    </div>\n</form>\n\n{% if num_return_sequences > 1 %}\n<ul class=\"nav nav-tabs\" id=\"resultTabNav\" role=\"tablist\">\n  {%for i in range(0, num_return_sequences)%}\n  <li class=\"nav-item\" role=\"presentation\">\n    <button class=\"nav-link{{ ' active' if i == 0 else '' }}\" id=\"result-tab{{i}}\" data-bs-toggle=\"tab\" data-bs-target=\"#result{{i}}\" type=\"button\" role=\"tab\" aria-controls=\"result{{i}}\" aria-selected=\"{{ 'true' if i == 0 else 'false' }}\">Result {{i+1}}</button>\n  </li>\n  {% endfor %}\n</ul>\n{% endif %}\n<div class=\"tab-content\" id=\"resultTabContent\">\n  {% for i in range(0, num_return_sequences) %}\n  <div class=\"tab-pane fade{{ ' show active' if i == 0 else '' }}\" id=\"result{{i}}\" role=\"tabpanel\" aria-labelledby=\"result-tab{{i}}\">\n    <div id=\"result{{i}}-content\" class=\"result-block form-control p-2\" style=\"font-size: 20px;\"><span id=\"promptDisplay\"></span><span id=\"response\">\n        </span><span id=\"error\"></span>\n    </div>\n  </div>\n  {% endfor %}\n</div>\n\n</div>\n\n\n<div class=\"d-sm-flex justify-content-center text-center\">\n    <p class=\"lead\">Like the results? 
&#11088;  Support Alpa development by starring Alpa on GitHub  &nbsp;</p>\n    <a class=\"github-button\" href=\"https://github.com/alpa-projects/alpa\" data-color-scheme=\"no-preference: light; light: light; dark: dark;\" data-icon=\"octicon-star\" data-size=\"large\" data-show-count=\"true\" aria-label=\"Star alpa-projects/alpa on GitHub\">Star</a>\n</div>\n\n  <div class=\"d-grid gap-2 d-sm-flex justify-content-sm-center\">\n    <a class=\"btn btn-block btn-tumblr\" onclick=\"takeshot()\"><i class=\"fa-solid fa-camera\"></i> Screenshot</a>\n    <a href=\"https://twitter.com/intent/tweet?text=Prompting%20OPT-175B%20with%20Alpa%20is%20fun!%20Try%20it%20yourself%20(unlimited)%20at%20http%3A%2F%2Fopt.alpa.ai%2F!%20%23alpa\" target=\"_blank\" class=\"btn btn-block btn-twitter\"><i class=\"fa-brands fa-twitter\"></i> Tweet it! #alpa</a>\n  </div>\n</div>\n\n<div class=\"container bg-white py-3\" id=\"faq\">\n    <h1 class=\"display-6 py-3 fw\">Frequently Asked Questions</h1>\n    <div class=\"accordion accordion-flush\" id=\"accordionPanelsStayOpenExample\">\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingOne\">\n          <button class=\"accordion-button\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseOne\" aria-expanded=\"true\" aria-controls=\"panelsStayOpen-collapseOne\">\n              What is Alpa?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseOne\" class=\"accordion-collapse collapse show\" aria-labelledby=\"panelsStayOpen-headingOne\">\n          <div class=\"accordion-body\">\n            <a href=\"https://github.com/alpa-projects/alpa\" target=\"_blank\">Alpa</a> is an open-source system for training and serving large-scale neural networks. 
Alpa aims to automate large-scale distributed training and serving with <strong>just a few lines of code</strong>.\n              Alpa was initially developed by folks in the <a href=\"https://sky.cs.berkeley.edu/\" target=\"_blank\">Sky Lab, UC Berkeley</a>. Some advanced techniques used in Alpa have been written in <a href=\"https://arxiv.org/pdf/2201.12023.pdf\" target=\"_blank\"> a paper published in OSDI'2022</a>.\n              Alpa community is growing with new contributors from Google, Amazon, AnyScale, and <a href=\"https://github.com/alpa-projects/alpa/graphs/contributors\" target=\"_blank\">more</a>.\n          </div>\n        </div>\n      </div>\n\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"what-is-opt-gpt\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#what-is-opt-gpt-collapse\" aria-expanded=\"false\" aria-controls=\"what-is-opt-gpt-collapse\">\n              What are language models and GPT-3? Could you give more general introduction about them and their applications?\n          </button>\n        </h2>\n        <div id=\"what-is-opt-gpt-collapse\" class=\"accordion-collapse collapse\" aria-labelledby=\"what-is-opt-gpt\">\n          <div class=\"accordion-body\">\n              <p>\n            A language model is a probability distribution over sequences of words. 
It predicts the next word based on all the previous words.\n                  It is useful for a variety of AI applications, such as the auto-completion in your email or chatbot service.\n              For more information, check out the <a href=\"https://en.wikipedia.org/wiki/Language_model\" target=\"_blank\">language model Wikipedia page</a>.\n              </p>\n              <p>\n                  <a href=\"https://en.wikipedia.org/wiki/GPT-3\" target=\"_blank\">GPT-3</a> is a very large language model, with 175 billion parameters, that uses deep learning to produce human-like text.\n                  Many researchers and news articles described GPT-3 as \"one of the most interesting and important AI systems ever produced\".\n                  GPT-3 is gradually being used as a backbone in the latest NLP research and applications.\n              </p>\n              <p>\n                  Due to its gigantic size, training and serving GPT-3 are very difficult and expensive, and pose significant challenges to the underlying software systems.\n                  The original GPT-3 trained by OpenAI is closed-source and offered as a paid service --- when using it, users have to pay for every token generated.\n              </p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingTwo\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseTwo\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseTwo\">\n            What is OPT-175B? 
How does it compare to GPT-3?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseTwo\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingTwo\">\n          <div class=\"accordion-body\">\n              <a href=\"https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md\" target=\"_blank\">OPT-175B</a> is a GPT-3 equivalent model trained by Meta. It is by far the largest pretrained language model available with 175 billion parameters.\n              You can request the access to the trained weights by filling <a href=\"https://forms.gle/BDB2i44QwCr2mCJN6\" target=\"_blank\">this form</a>. For detailed performance of OPT-175B,\n              check the <a href=\"https://arxiv.org/pdf/2205.01068.pdf\" target=\"_blank\">OPT paper</a>.\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingThree\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseThree\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseThree\">\n            Any tips for better generation?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseThree\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingThree\">\n          <div class=\"accordion-body\">\n              You can start with the provided examples. Avoid spaces at the end of your query. 
New lines are great though.\n              More examples can be found in the appendix of the <a href=\"https://arxiv.org/pdf/2205.01068.pdf\" target=\"_blank\">OPT paper</a>.\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingSampling\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseSampling\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseSampling\">\n            What sampling method do you use? What do Temperature and Top-p mean?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseSampling\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingSampling\">\n          <div class=\"accordion-body\">\n            <p>Right now we use random sampling, so every time you click \"generate\" the generated result might be different. The <em>temperature</em> controls how <em>sharp</em> the sampling distribution is.\n                Lower temperature pushes the generator to pick the tokens with higher scores from the model.\n                <em>Top-p</em> sampling chooses from the smallest possible set of words whose cumulative probability exceeds the probability <em>p</em>.\n                Small value of <em>p</em> prevents the model to choose from tokens with lower scores.\n                See more detailed description on how to sample on <a href=\"https://huggingface.co/blog/how-to-generate\" target=\"_blank\">this page from huggingface</a>.</p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-more-generation-args\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-more-generation-args-collapse\" aria-expanded=\"false\"\n           
       aria-controls=\"panelsStayOpen-more-generation-args-collapse\">\n            I want more customizations on how to generate, such as using beam search or tuning the repetition penalty. How can I do that?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-more-generation-args-collapse\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-more-generation-args\">\n          <div class=\"accordion-body\">\n            <p>This web interface exposes only three arguments for simplicity, although our backend supports\n                <a href=\"https://huggingface.co/docs/transformers/v4.20.1/en/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate\" target=\"_blank\">a diverse set of generation techniques and arguments</a>.\n            </p>\n              <p>We are developing a RESTFUL API to expose the full set of arguments. Stay tuned.\n              Meanwhile, if you want to try out different generation techniques and hyperparameters now, you can <a href=\"https://alpa-projects.github.io/tutorials/llm_serving.html\" target=\"_blank\">set up your own OPT-175B service using Alpa</a>\n                  and start from <a href=\"https://github.com/alpa-projects/alpa/blob/main/examples/llm_serving/benchmark/benchmark_text_gen.py#L183\" target=\"_blank\">here</a>.</p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-data-collection\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-data-collection-collapse\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-data-collection-collapse\">\n            Are you collecting any data from my inputs when I use this service?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-data-collection-collapse\" class=\"accordion-collapse collapse\" 
aria-labelledby=\"panelsStayOpen-data-collection\">\n          <div class=\"accordion-body\">\n            We are not storing the content of your inputs. We only log the traffic patterns, such as the timestamp when you submitted your inputs and the length of your inputs.\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingFour\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseFour\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseFour\">\n            Why should I choose Alpa over existing systems?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseFour\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingFour\">\n          <div class=\"accordion-body\">\n            <p>High-level speaking, Alpa is <b>more automatic, scalable, and cost-effective</b> compared to existing systems.</p>\n            <p>\n            In more details, if you are an ML developer or data scientist who is looking for a system that can train or serve large models like GPT-3, Alpa provides state-of-the-art performance while requires\n                the least amount of system expertise to setup. 
Meanwhile, Alpa enables training or serving large models on older generations of (hence cheaper) GPUs, such as 40GB A100, V100, T4, M60, etc.,\n                which are common in many in-house clusters and more accessible for many people.\n            </p>\n            <p>\n            If you are a system developer aiming to develop better training or serving systems, Alpa, as a compiler, offers the most flexibility to try out\n                various ML parallelization methods (inter- and intra-op parallelisms), and the richest coverage of big model architectures (GPT-3, MoE, WideResNet, etc.).\n              Alpa might be a good starting point for you to start your prototyping.\n            </p>\n            <p>\n            If you are an amateur in ML/NLP/systems, well &#128539;, you can play with OPT-175B inference for free, while all existing services will charge you for each token generated.\n            </p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingFive\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseFive\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseFive\">\n            How many GPUs are needed to run the serving service for OPT-175B or GPT-3?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseFive\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingFive\">\n          <div class=\"accordion-body\">\n              <p>\n            It depends on which types of GPUs are used. A hard constraint now is that the total GPU memory in the cluster needs to be greater than 350GB in order to successfully run the model inference.\n              Many existing training or serving systems usually rely on using the latest generations of GPUs with the largest memory capacity, such as 80GB A100. 
In contrast, Alpa, due to its more powerful\n              backend, enables serving OPT-175B with more flexible parallelisms on older generations of GPUs, such as 40GB A100, V100, T4, M60, etc.</p>\n              <p>\n                Take an example, if you choose to use 16GB V100 GPUs, then you would need 350 / 16 = 22 V100 GPUs to run the service.\n              </p>\n              <p>\n                We are working on a feature to enable serving models even if you do not have enough GPU memory, stay tuned.\n              </p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"panelsStayOpen-headingSix\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapseSix\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapseSix\">\n            How do you keep this service free?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapseSix\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-headingSix\">\n          <div class=\"accordion-body\">\n          <p>\n            Alpa does not require the latest generation GPUs (such as 80GB A100), hence reduces the machine cost.\n            With that, we leverage older generations of hardware provided by our sponsors: <a href=\"https://mbzuai.ac.ae/\" target=\"_blank\">MBZUAI</a>\n                and <a href=\"https://sky.cs.berkeley.edu/\" target=\"_blank\">Sky Lab, UC Berkeley</a>.\n          </p>\n          <p>\n              If you are interested in any form of donation or sponsorship to help the development of Alpa, please get in touch with Alpa authors in <a href=\"https://docs.google.com/forms/d/e/1FAIpQLScXE0pDOm1FBcKS8C9JxAS6GbD-8b037NqH36ndKRMrGJ3_Cw/viewform\" target=\"_blank\">Alpa Slack</a>.\n          </p>\n          </div>\n        </div>\n      </div>\n\n      <div class=\"accordion-item\">\n   
     <h2 class=\"accordion-header\" id=\"panelsStayOpen-heading\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#panelsStayOpen-collapse\" aria-expanded=\"false\" aria-controls=\"panelsStayOpen-collapse\">\n            Can I use this free service for my business?\n          </button>\n        </h2>\n        <div id=\"panelsStayOpen-collapse\" class=\"accordion-collapse collapse\" aria-labelledby=\"panelsStayOpen-heading\">\n          <div class=\"accordion-body\">\n              <strong>No</strong>. This is a public service provided by the Alpa authors and sponsors.\n              Your usage of this service is subject to Alpa's open source license. Your usage of the OPT-175B model is subject to Meta's <a href=\"https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md\" target=\"_blank\">OPT-175B license</a>,\n              which limits use to research purposes.\n          </div>\n        </div>\n      </div>\n\n    <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"offensive\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#offensive-collapse\" aria-expanded=\"false\" aria-controls=\"offensive-collapse\">\n            Why does this model sometimes generate something very offensive?\n          </button>\n        </h2>\n        <div id=\"offensive-collapse\" class=\"accordion-collapse collapse\" aria-labelledby=\"offensive\">\n          <div class=\"accordion-body\">\n              This is a well-known problem with large language models trained on text corpora collected from Internet.\n              There is an active line of research in the NLP and ML community on addressing this issue.\n              See <a href=\"https://www.deepmind.com/publications/ethical-and-social-risks-of-harm-from-language-models\" target=\"_blank\">this article</a>.\n              We'll incorporate 
latest research results into this service to improve the results in following iterations.\n          </div>\n         </div>\n    </div>\n\n\n    <div class=\"accordion-item\">\n        <h2 class=\"accordion-header\" id=\"ray\">\n          <button class=\"accordion-button collapsed\" type=\"button\" data-bs-toggle=\"collapse\" data-bs-target=\"#ray-collapse\" aria-expanded=\"false\" aria-controls=\"ray-collapse\">\n            What's the relation between Alpa and the Ray project?\n          </button>\n        </h2>\n        <div id=\"ray-collapse\" class=\"accordion-collapse collapse\" aria-labelledby=\"ray\">\n          <div class=\"accordion-body\">\n              Alpa currently runs on top of a Ray cluster, and uses some Ray functionalities to coordinate distributed processes. However, in contrast to Ray,\n              Alpa is designed as a compiler for large-scale distributed machine learning training and serving with high performance.\n          </div>\n         </div>\n    </div>\n</div>\n</div>\n\n<div class=\"container py-3\" id=\"partners\">\n  <h1 class=\"display-6 py-3 my-3 fw\">Alpa Partners</h1>\n  <div class=\"row\">\n      <div class=\"container\">\n        <section class=\"logo-carousel slider px-3 \" data-arrows=\"true\">\n          <div class=\"slide\"><img alt=\"berkeley-logo\" src=\"https://raw.githubusercontent.com/zhisbug/test-alpa-ci/master/ucberkeley-logo.png\" style=\"width: 70%;\"></div>\n          <div class=\"slide\"><img alt=\"mbzuai-logo\" src=\"https://upload.wikimedia.org/wikipedia/en/5/55/Mohamed_bin_Zayed_University_of_Artificial_Intelligence_logo.png\" style=\"width: 80%;\"></div>\n          <div class=\"slide\"><img alt=\"anyscale-logo\" src=\"https://lever-client-logos.s3.us-west-2.amazonaws.com/0114ec37-170e-4864-b9a9-f85452de1ce0-1633971232076.png\" style=\"width:80%;\"></div>\n          <div class=\"slide\"><img alt=\"aws-logo\" 
src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/9/93/Amazon_Web_Services_Logo.svg/512px-Amazon_Web_Services_Logo.svg.png?20170912170050\" style=\"width: 50%;\" ></div>\n          <div class=\"slide\"><img alt=\"google-logo\" src=\"https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTE2WoDOUPB5L3bRgbQZkU-iLCtmRLYSj29iwFqtMAKQCNALlPtbwSUoswoSvqk8sjrT5w&usqp=CAU\" style=\"width: 60%;\"></div>\n          <div class=\"slide\"><img alt=\"casl-logo\" src=\"https://www.casl-project.ai/assets/img/casl_logo-main.svg\" style=\"width: 80%;\"></div>\n        </section>\n      </div>\n  </div>\n</div>\n\n<div class=\"container py-5 my-5 text-center bg-light rounded-4\" id=\"contact\">\n  <h1 class=\"display-5 fw col-lg-10 mx-auto py-5\">Interested in contributing to the Alpa project?</h1>\n    <div class=\"col-lg-9 mx-auto\">\n      <div class=\"d-grid gap-2 d-sm-flex justify-content-sm-center\">\n        <a href=\"https://github.com/alpa-projects/alpa/fork\" target=\"_blank\" class=\"btn btn-primary btn-lg px-4 gap-3\">Fork on GitHub</a>\n        <a href=\"https://docs.google.com/forms/d/e/1FAIpQLScXE0pDOm1FBcKS8C9JxAS6GbD-8b037NqH36ndKRMrGJ3_Cw/viewform\" class=\"btn btn-outline-primary btn-lg px-4\" target=\"_blank\">Join Alpa Slack</a>\n      </div>\n    </div>\n</div>\n\n  <div class=\"container\">\n    <footer class=\"d-flex flex-wrap justify-content-between align-items-center py-3 border-top\">\n    <div class=\"d-flex align-items-center\">\n        <a href=\"#home\" class=\"me-2\">\n          <img alt=\"alpa logo\" style=\"width: 40px;\" src=\"https://raw.githubusercontent.com/alpa-projects/alpa/main/docs/logo/alpa-logo-cropped.svg\">\n        </a>\n        <span class=\"text-muted\">&copy; 2022 Alpa Developers.</span>\n      </div>\n\n      <ul class=\"nav nav-pills\">\n        <li class=\"nav-item\"><a href=\"#generation\" class=\"nav-link px-2 text-muted\">Generation</a></li>\n        <li class=\"nav-item\"><a href=\"#faq\" class=\"nav-link px-2 
text-muted\">FAQs</a></li>\n        <li class=\"nav-item\"><a href=\"https://github.com/alpa-projects/alpa/blob/main/LICENSE\" target=\"_blank\" class=\"nav-link px-2 text-muted\">Alpa License</a></li>\n        <li class=\"nav-item\"><a href=\"https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/MODEL_LICENSE.md\" target=\"_blank\" class=\"nav-link px-2 text-muted\">OPT License</a></li>\n        <li class=\"nav-item\"><a href=\"https://github.com/alpa-projects/alpa\" target=\"_blank\" class=\"nav-link px-2 text-muted\">GitHub</a></li>\n      </ul>\n    </footer>\n  </div>\n\n<!-- Google tag (gtag.js) -->\n<script async src=\"https://www.googletagmanager.com/gtag/js?id=G-XPSB9HFTDS\"></script>\n<script>\n  window.dataLayer = window.dataLayer || [];\n  function gtag(){dataLayer.push(arguments);}\n  gtag('js', new Date());\n\n  gtag('config', 'G-XPSB9HFTDS');\n</script>\n\n</body>\n</html>\n"
  },
  {
    "path": "examples/llm_serving/service/utils.py",
    "content": "\"\"\"Adapted from Metaseq.\"\"\"\nimport datetime\nimport logging\nimport logging.handlers\nimport os\nimport sys\n\nfrom llm_serving.service.constants import LOGDIR\n\n\nhandler = None\n\n\ndef build_logger():\n    global handler\n\n    formatter = logging.Formatter(\n        fmt=\"%(asctime)s | %(levelname)s | %(name)s | %(message)s\",\n        datefmt=\"%Y-%m-%d %H:%M:%S\",\n    )\n\n    # Set the format of root handlers\n    if not logging.getLogger().handlers:\n        logging.basicConfig(level=logging.INFO)\n    logging.getLogger().handlers[0].setFormatter(formatter)\n\n    # Redirect stdout and stderr to loggers\n    stdout_logger = logging.getLogger(\"stdout\")\n    stdout_logger.setLevel(logging.INFO)\n    sl = StreamToLogger(stdout_logger, logging.INFO)\n    sys.stdout = sl\n\n    stderr_logger = logging.getLogger(\"stderr\")\n    stderr_logger.setLevel(logging.ERROR)\n    sl = StreamToLogger(stderr_logger, logging.ERROR)\n    sys.stderr = sl\n\n    # Get logger\n    logger = logging.getLogger(\"alpa.llm_serving\")\n    logger.setLevel(logging.INFO)\n\n    # Add a file handler for all loggers\n    if handler is None:\n        os.makedirs(LOGDIR, exist_ok=True)\n        filename = os.path.join(LOGDIR, f\"llm_serving.worker.log\")\n        handler = logging.handlers.TimedRotatingFileHandler(\n            filename, when='D', utc=True)\n        handler.setFormatter(formatter)\n\n        for name, item in logging.root.manager.loggerDict.items():\n            if isinstance(item, logging.Logger):\n                item.addHandler(handler)\n\n    return logger\n\n\nclass StreamToLogger(object):\n    \"\"\"\n    Fake file-like stream object that redirects writes to a logger instance.\n    \"\"\"\n    def __init__(self, logger, log_level=logging.INFO):\n        self.terminal = sys.stdout\n        self.logger = logger\n        self.log_level = log_level\n        self.linebuf = ''\n\n    def __getattr__(self, attr):\n        return 
getattr(self.terminal, attr)\n\n    def write(self, buf):\n        temp_linebuf = self.linebuf + buf\n        self.linebuf = ''\n        for line in temp_linebuf.splitlines(True):\n            # From the io.TextIOWrapper docs:\n            #   On output, if newline is None, any '\\n' characters written\n            #   are translated to the system default line separator.\n            # By default sys.stdout.write() expects '\\n' newlines and then\n            # translates them so this is still cross platform.\n            if line[-1] == '\\n':\n                self.logger.log(self.log_level, line.rstrip())\n            else:\n                self.linebuf += line\n\n    def flush(self):\n        if self.linebuf != '':\n            self.logger.log(self.log_level, self.linebuf.rstrip())\n        self.linebuf = ''\n"
  },
  {
    "path": "examples/llm_serving/test_completions.py",
    "content": "\"\"\"\nUsage:\n\npython3 test_completions.py --url http://localhost:20001\npython3 test_completions.py --url https://api.alpa.ai --api-key YOUR_KEY\n\"\"\"\nimport argparse\n\nfrom client import Client\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--url\", type=str)\n    parser.add_argument(\"--api-key\", type=str)\n    parser.add_argument(\"--model\", type=str, default=\"default\")\n    args = parser.parse_args()\n\n    client = Client(args.url, api_key=args.api_key, default_model=args.model)\n    ret = client.completions(\n        [\"Paris is the capital city of\",\n         \"Computer science is the study of\"]\n    )\n    print(ret)\n"
  },
  {
    "path": "examples/llm_serving/test_logprobs.py",
    "content": "\"\"\"\nUsage:\n\npython3 test_logprobs.py --url http://localhost:20001\npython3 test_logprobs.py --url https://api.alpa.ai --api-key YOUR_KEY\n\"\"\"\nimport argparse\nimport time\n\nimport numpy as np\nfrom scipy.special import softmax\nfrom transformers import AutoTokenizer\n\nfrom client import Client\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--url\", type=str)\n    parser.add_argument(\"--api-key\", type=str)\n    args = parser.parse_args()\n\n    client = Client(args.url, api_key=args.api_key)\n    tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-30b\", use_fast=False)\n    tokenizer.add_bos_token = False\n\n    prompts = [\n        \"Paris is the capital city of France\",\n        \"Computer science is the\",\n    ]\n\n    input_ids = tokenizer(prompts, padding=\"longest\").input_ids\n    top_k = 50\n\n    output = client.logprobs(input_ids, top_k=top_k)\n\n    tic = time.time()\n    num_tokens = 40\n    for i in range(num_tokens):\n        print(\"=\" * 20 + f\" Step {i} \" + \"=\" * 20)\n        for j in range(len(input_ids)):\n            distribution = np.full((tokenizer.vocab_size + 10), -1e8, dtype=np.float32)\n            for idx, logprob in zip(output['indices'][j], output['logprobs'][j]):\n                distribution[idx] = logprob\n            # distribution = softmax(distribution)\n            # token = np.random.choice(np.arange(len(distribution)), p=distribution)\n            token = distribution.argmax()\n            input_ids[j].append(int(token))\n            print(tokenizer.decode(input_ids[j], skip_special_tokens=True))\n            print(\"-\" * 20)\n        output = client.logprobs(input_ids, top_k=top_k, cache_id=output[\"cache_id\"])\n    time_cost = time.time() - tic\n    print(f\"Generation throughput: {len(prompts) * num_tokens/time_cost:.2f} token/s\")\n"
  },
  {
    "path": "examples/llm_serving/test_textgen.sh",
    "content": "# Test the correctness of textgen.py\nset -x\n\npython3 textgen.py --model bigscience/bloom-560m\npython3 textgen.py --model jax/bloom-560m\npython3 textgen.py --model alpa/bloom-560m\n\npython3 textgen.py --model facebook/opt-1.3b\npython3 textgen.py --model jax/opt-1.3b\npython3 textgen.py --model alpa/opt-1.3b\n"
  },
  {
    "path": "examples/llm_serving/textgen.py",
    "content": "\"\"\"Use huggingface/transformers interface and Alpa backend for distributed inference.\"\"\"\nimport argparse\n\nimport numpy as np\nfrom transformers import AutoTokenizer\n\nfrom llm_serving.model.wrapper import get_model\n\ndef main(args):\n    # Load the tokenizer.\n    if \"opt\" in args.model:\n        # We have to use the 30B version because other versions have some issues.\n        # The 30B version works for all OPT models.\n        tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-30b\")\n        tokenizer.add_bos_token = False\n    elif \"bloom\" in args.model:\n        name = args.model.replace(\"alpa\", \"bigscience\")\\\n                         .replace(\"jax\", \"bigscience\")\n        tokenizer = AutoTokenizer.from_pretrained(name)\n\n    generate_params = {\n        \"do_sample\": args.do_sample,\n        \"num_beams\": args.num_beams,\n        \"num_return_sequences\": args.num_return_sequences\n    }\n    \n    # Load the model\n    model = get_model(model_name=args.model,\n                      path=args.path,\n                      batch_size=args.n_prompts,\n                      **generate_params)\n\n    # Generate\n    prompts = [\n        \"Paris is the capital city of\",\n        \"Today is a good day and I'd like to\",\n        \"Computer Science studies the area of\",\n        \"University of California Berkeley is a public university\"\n    ]\n    prompts = prompts[:args.n_prompts]\n    input_ids = tokenizer(prompts, return_tensors=\"pt\", padding=\"longest\").input_ids\n    output_ids = model.generate(input_ids=input_ids,\n                                max_length=64,\n                                **generate_params)\n    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n    \n    # Print results\n    print(\"Outputs:\\n\" + 100 * '-')\n    for i, output in enumerate(outputs):\n        print(f\"{i}: {output}\")\n        print(100 * '-')\n    \n\nif __name__ == \"__main__\":\n    parser 
= argparse.ArgumentParser()\n    parser.add_argument('--model', type=str, default='alpa/opt-1.3b')\n    parser.add_argument('--path', type=str, default='~/opt_weights')\n    parser.add_argument('--do-sample', action='store_true')\n    parser.add_argument('--num-beams', type=int, default=1)\n    parser.add_argument('--num-return-sequences', type=int, default=1)\n    parser.add_argument('--n-prompts', type=int, default=4)\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/llm_serving/textgen_1d.py",
    "content": "\"\"\"Use huggingface/transformers interface and Alpa backend for distributed inference.\"\"\"\nimport argparse\nimport time\n\nimport numpy as np\nfrom transformers import AutoTokenizer\n\nfrom llm_serving.model.wrapper_1d import get_model\nfrom llm_serving.model.opt_utils import sync\nfrom alpa.timer import timers\n\n\ndef main(args):\n    # Load the tokenizer. We have to use the 30B version because\n    # other versions have some issues. The 30B version works for all OPT models.\n    tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-30b\", use_fast=False)\n    tokenizer.add_bos_token = False\n\n    generate_params = {\n        \"do_sample\": args.do_sample,\n        \"max_new_tokens\": 128,\n        # \"max_length\": 128\n    }\n\n    # Load the model\n    model = get_model(model_name=args.model,\n                      path=\"~/opt_weights\",\n                      batch_size=32,\n                      cache_size=4096)\n\n    prompts = [\n        \"Computer science is the study of computation and\",\n        \"Ion Stoica is a Romanian-American computer scientist specializing in\",\n        \"The University of California, Berkeley is a public\",\n        \"Today is a good day and I want to\",\n        \"What is the valuation of Databricks?\",\n        \"Paris is the capital city of\",\n        \"Which country has the most population?\",\n        \"What do you think about the future of Cryptocurrency?\",\n        \"What do you think about the meaning of life?\",\n        \"Donald Trump is the president of\",\n        \"GPT-3 is a large language model that is capable of\"\n    ]\n\n    input_ids = tokenizer(prompts, return_tensors=\"np\", padding=\"longest\").input_ids\n\n    n_warmup = 10\n    for i in range(n_warmup):\n        sync()\n        tic = time.time()\n        output_ids, latency = model.generate(input_ids, **generate_params)\n        sync()\n        elapsed = time.time() - tic\n        print(f\"- It takes {elapsed}, latency: 
{latency}\")\n\n        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)\n    if False:\n        print(\"Outputs:\\n\" + 100 * '-')\n        for i, output in enumerate(outputs):\n            print(output_ids[i])\n            print(f\"{i + 1}: {output}\")\n            print(100 * '-')\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--model\", type=str, default=\"alpa/opt-1d-1.3b\")\n    parser.add_argument('--do-sample', action='store_true')\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/mnist/README.md",
    "content": "--------------------------------------------------------------------------------\n\nAdapted from https://github.com/google/flax/tree/main/examples/mnist.\n\nUse `alpa.parallelize` to parallelize the training loop.\n\n1. Run training with all local GPUs in a single machine.\n```\npython3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192\n```\nSee `train.py` for a minimal example of using alpa on a single machine.\n\n2. Run training with all GPUs in a ray cluster.\n```\nray start --head\npython3 main.py --workdir=/tmp/mnist --config=configs/default.py --config.batch_size 8192 --use_ray\n```\nSee `train_ray.py` for a minimal example of using alpa on a ray cluster.\n\n--------------------------------------------------------------------------------\n\n## MNIST classification\n\nTrains a simple convolutional network on the MNIST dataset.\n\nYou can run this code and even modify it directly in Google Colab, no\ninstallation required:\n\nhttps://colab.research.google.com/github/google/flax/blob/main/examples/mnist/mnist.ipynb\n\n### Requirements\n* TensorFlow dataset `mnist` will be downloaded and prepared automatically, if necessary\n\n### Example output\n\n|  Name   | Epochs | Walltime | Top-1 accuracy |   Metrics   |                  Workdir                  |\n| :------ | -----: | :------- | :------------- | :---------- | :---------------------------------------- |\n| default |     10 | 7.7m     | 99.17%         | [tfhub.dev] | [gs://flax_public/examples/mnist/default] |\n\n[tfhub.dev]: https://tensorboard.dev/experiment/1G9SvrW5RQyojRtMKNmMuQ/#scalars&_smoothingWeight=0&regexInput=default\n[gs://flax_public/examples/mnist/default]: https://console.cloud.google.com/storage/browser/flax_public/examples/mnist/default\n\n```\nI0828 08:51:41.821526 139971964110656 train.py:130] train epoch: 10, loss: 0.0097, accuracy: 99.69\nI0828 08:51:42.248714 139971964110656 train.py:180] eval epoch: 10, loss: 0.0299, accuracy: 
99.14\n```\n\n### How to run\n\n`python main.py --workdir=/tmp/mnist --config=configs/default.py`\n\n#### Overriding Hyperparameter configurations\n\nThe MNIST example allows specifying a hyperparameter configuration by means of\nsetting the `--config` flag. The configuration flag is defined using\n[config_flags](https://github.com/google/ml_collections/tree/master#config-flags).\n`config_flags` allows overriding configuration fields. This can be done as\nfollows:\n\n```shell\npython main.py \\\n--workdir=/tmp/mnist --config=configs/default.py \\\n--config.learning_rate=0.05 --config.num_epochs=5\n```\n"
  },
  {
    "path": "examples/mnist/configs/default.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Default Hyperparameter configuration.\"\"\"\n\nimport ml_collections\n\n\ndef get_config():\n  \"\"\"Get the default hyperparameter configuration.\"\"\"\n  config = ml_collections.ConfigDict()\n\n  config.learning_rate = 0.1\n  config.momentum = 0.9\n  config.batch_size = 128\n  config.num_epochs = 10\n  return config\n"
  },
  {
    "path": "examples/mnist/main.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"Main file for running the MNIST example.\n\nThis file is intentionally kept short. The majority of logic is in libraries\nthat can be easily tested and imported in Colab.\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nfrom clu import platform\nimport jax\nfrom ml_collections import config_flags\nimport tensorflow as tf\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('workdir', None, 'Directory to store model data.')\nflags.DEFINE_boolean('use_ray', False, 'Whether to use Ray cluster.')\nconfig_flags.DEFINE_config_file(\n    'config',\n    None,\n    'File path to the training hyperparameter configuration.',\n    lock_config=True)\n\n\ndef main(argv):\n  if len(argv) > 1:\n    raise app.UsageError('Too many command-line arguments.')\n\n  # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make\n  # it unavailable to JAX.\n  tf.config.experimental.set_visible_devices([], 'GPU')\n\n  if FLAGS.use_ray:\n    import train_ray as train\n  else:\n    import train\n\n  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)\n\n\nif __name__ == '__main__':\n  flags.mark_flags_as_required(['config', 'workdir'])\n  app.run(main)\n"
  },
  {
    "path": "examples/mnist/requirements.txt",
    "content": "absl-py==1.0.0\nclu==0.0.6\nflax==0.3.6\njax==0.2.21\n--find-links https://storage.googleapis.com/jax-releases/jax_releases.html\njaxlib==0.1.70+cuda110  # Make sure CUDA version matches the base image.\nml-collections==0.1.0\nnumpy==1.21.4\noptax==0.1.0\ntensorflow==2.7.0\ntensorflow-datasets==4.4.0\n"
  },
  {
    "path": "examples/mnist/train.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MNIST example.\n\nLibrary file which executes the training and evaluation loop for MNIST.\nThe data is loaded using tensorflow_datasets.\n\"\"\"\n\n# See issue #620.\n# pytype: disable=wrong-keyword-args\n\nimport time\n\n\nfrom absl import logging\nimport alpa\nfrom flax import linen as nn\nfrom flax.metrics import tensorboard\nfrom flax.training import train_state\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport numpy as np\nimport optax\nimport tensorflow_datasets as tfds\n\n\nclass CNN(nn.Module):\n  \"\"\"A simple CNN model.\"\"\"\n\n  @nn.compact\n  def __call__(self, x):\n    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = x.reshape((x.shape[0], -1))  # flatten\n    x = nn.Dense(features=256)(x)\n    x = nn.relu(x)\n    x = nn.Dense(features=10)(x)\n    return x\n\n\n@alpa.parallelize\ndef train_step(state, images, labels):\n  \"\"\"Computes gradients, loss and accuracy for a single batch.\"\"\"\n  def loss_fn(params):\n    logits = state.apply_fn({'params': params}, images)\n    one_hot = jax.nn.one_hot(labels, 10)\n    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n    return 
loss, logits\n\n  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n  (loss, logits), grads = grad_fn(state.params)\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  state = state.apply_gradients(grads=grads)\n  return state, loss, accuracy\n\n\n@alpa.parallelize(donate_argnums=())\ndef eval_step(state, images, labels):\n  logits = state.apply_fn({'params': state.params}, images)\n  one_hot = jax.nn.one_hot(labels, 10)\n  loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  return loss, accuracy\n\n\ndef train_epoch(state, train_ds, batch_size):\n  \"\"\"Train for a single epoch.\"\"\"\n  train_ds_size = len(train_ds['image'])\n  steps_per_epoch = train_ds_size // batch_size\n\n  epoch_loss = []\n  epoch_accuracy = []\n\n  for i in range(steps_per_epoch):\n    batch_images = train_ds['image'][i*batch_size:(i+1)*batch_size]\n    batch_labels = train_ds['label'][i*batch_size:(i+1)*batch_size]\n    state, loss, accuracy = train_step(state, batch_images, batch_labels)\n    epoch_loss.append(loss)\n    epoch_accuracy.append(accuracy)\n  alpa.prefetch((epoch_loss, epoch_accuracy))\n  train_loss = np.mean(epoch_loss)\n  train_accuracy = np.mean(epoch_accuracy)\n  return state, train_loss, train_accuracy\n\n\ndef get_datasets():\n  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n  ds_builder = tfds.builder('mnist')\n  ds_builder.download_and_prepare()\n  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n  train_ds['image'] = np.float32(train_ds['image']) / 255.\n  test_ds['image'] = np.float32(test_ds['image']) / 255.\n  train_ds['label'] = np.int32(train_ds['label'])\n  test_ds['label'] = np.int32(test_ds['label'])\n  return train_ds, test_ds\n\n\ndef create_train_state(rng, config):\n  \"\"\"Creates initial `TrainState`.\"\"\"\n  cnn = CNN()\n  params = 
cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n  tx = optax.sgd(config.learning_rate, config.momentum)\n  return train_state.TrainState.create(\n      apply_fn=cnn.apply, params=params, tx=tx)\n\n\ndef train_and_evaluate(config: ml_collections.ConfigDict,\n                       workdir: str) -> train_state.TrainState:\n  \"\"\"Execute model training and evaluation loop.\n\n  Args:\n    config: Hyperparameter configuration for training and evaluation.\n    workdir: Directory where the tensorboard summaries are written to.\n\n  Returns:\n    The train state (which includes the `.params`).\n  \"\"\"\n  train_ds, test_ds = get_datasets()\n\n  summary_writer = tensorboard.SummaryWriter(workdir)\n  summary_writer.hparams(dict(config))\n\n  rng = jax.random.PRNGKey(0)\n  state = create_train_state(rng, config)\n\n  for epoch in range(1, config.num_epochs + 1):\n    tic = time.time()\n    state, train_loss, train_accuracy = train_epoch(state, train_ds,\n                                                    config.batch_size)\n    epoch_time = time.time() - tic\n    test_loss, test_accuracy = eval_step(state, test_ds['image'], test_ds['label'])\n    test_accuracy = np.array(test_accuracy)\n    logging.info(\n        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, epoch_time: %.3f'\n        % (epoch, train_loss, train_accuracy * 100, test_loss,\n           test_accuracy * 100, epoch_time))\n\n    summary_writer.scalar('train_loss', train_loss, epoch)\n    summary_writer.scalar('train_accuracy', train_accuracy, epoch)\n    summary_writer.scalar('test_loss', test_loss, epoch)\n    summary_writer.scalar('test_accuracy', test_accuracy, epoch)\n\n  summary_writer.flush()\n  return state\n"
  },
  {
    "path": "examples/mnist/train_ray.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"MNIST example.\n\nLibrary file which executes the training and evaluation loop for MNIST.\nThe data is loaded using tensorflow_datasets.\n\"\"\"\n\n# See issue #620.\n# pytype: disable=wrong-keyword-args\n\nimport time\n\n\nfrom absl import logging\nimport alpa\nfrom flax import linen as nn\nfrom flax.metrics import tensorboard\nfrom flax.training import train_state\nimport jax\nimport jax.numpy as jnp\nimport ml_collections\nimport numpy as np\nimport optax\nimport tensorflow_datasets as tfds\n\n\nclass CNN(nn.Module):\n  \"\"\"A simple CNN model.\"\"\"\n\n  @nn.compact\n  def __call__(self, x):\n    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n    x = nn.relu(x)\n    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n    x = x.reshape((x.shape[0], -1))  # flatten\n    x = nn.Dense(features=256)(x)\n    x = nn.relu(x)\n    x = nn.Dense(features=10)(x)\n    return x\n\n\n@alpa.parallelize\ndef train_step(state, images, labels):\n  \"\"\"Computes gradients, loss and accuracy for a single batch.\"\"\"\n  def loss_fn(params):\n    logits = state.apply_fn({'params': params}, images)\n    one_hot = jax.nn.one_hot(labels, 10)\n    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n    return 
loss, logits\n\n  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n  (loss, logits), grads = grad_fn(state.params)\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  state = state.apply_gradients(grads=grads)\n  return state, loss, accuracy\n\n\n@alpa.parallelize(donate_argnums=())\ndef eval_step(state, images, labels):\n  logits = state.apply_fn({'params': state.params}, images)\n  one_hot = jax.nn.one_hot(labels, 10)\n  loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n  return loss, accuracy\n\n\ndef train_epoch(state, train_data_loader, steps_per_epoch):\n  \"\"\"Train for a single epoch.\"\"\"\n  epoch_loss = []\n  epoch_accuracy = []\n\n  for i in range(steps_per_epoch):\n    batch_images, batch_labels = next(train_data_loader)\n    state, loss, accuracy = train_step(state, batch_images, batch_labels)\n    epoch_loss.append(loss)\n    epoch_accuracy.append(accuracy)\n  alpa.prefetch((epoch_loss, epoch_accuracy))\n  train_loss = np.mean(epoch_loss)\n  train_accuracy = np.mean(epoch_accuracy)\n  return state, train_loss, train_accuracy\n\n\ndef get_datasets():\n  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n  ds_builder = tfds.builder('mnist')\n  ds_builder.download_and_prepare()\n  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n  train_ds['image'] = np.float32(train_ds['image']) / 255.\n  test_ds['image'] = np.float32(test_ds['image']) / 255.\n  train_ds['label'] = np.int32(train_ds['label'])\n  test_ds['label'] = np.int32(test_ds['label'])\n  return train_ds, test_ds\n\n\ndef create_train_state(rng, config):\n  \"\"\"Creates initial `TrainState`.\"\"\"\n  cnn = CNN()\n  params = cnn.init(rng, jnp.ones([1, 28, 28, 1]))['params']\n  tx = optax.sgd(config.learning_rate, config.momentum)\n  return train_state.TrainState.create(\n      
apply_fn=cnn.apply, params=params, tx=tx)\n\n\ndef get_train_data_loader(train_ds, state, batch_size):\n  images_np = train_ds['image']\n  labels_np = train_ds['label']\n  steps_per_epoch = len(images_np) // batch_size\n\n  def input_iter_func(start, end, batch_size):\n    while True:\n      for i in range(steps_per_epoch):\n        idx = start + i * batch_size\n        yield (images_np[idx:idx + batch_size],\n               labels_np[idx:idx + batch_size])\n\n  batch_images = jax.core.ShapedArray(\n      (batch_size, 28, 28, 1), jnp.float32)\n  batch_labels = jax.core.ShapedArray(\n      (batch_size,), jnp.int32)\n  executable = train_step.get_executable(state, batch_images, batch_labels)\n\n  data_loader = alpa.MeshDriverDataLoader(\n      batch_size, len(images_np),\n      input_iter_func, executable.get_input_placement_specs()[1:3],\n      prefetch_size=4, repeat=True)\n  return iter(data_loader), steps_per_epoch\n\n\ndef train_and_evaluate(config: ml_collections.ConfigDict,\n                       workdir: str) -> train_state.TrainState:\n  \"\"\"Execute model training and evaluation loop.\n\n  Args:\n    config: Hyperparameter configuration for training and evaluation.\n    workdir: Directory where the tensorboard summaries are written to.\n\n  Returns:\n    The train state (which includes the `.params`).\n  \"\"\"\n  alpa.init(cluster=\"ray\")\n  train_ds, test_ds = get_datasets()\n\n  summary_writer = tensorboard.SummaryWriter(workdir)\n  summary_writer.hparams(dict(config))\n\n  rng = jax.random.PRNGKey(0)\n  state = create_train_state(rng, config)\n\n  train_data_loader, steps_per_epoch = get_train_data_loader(\n      train_ds, state, config.batch_size)\n\n  for epoch in range(1, config.num_epochs + 1):\n    tic = time.time()\n    state, train_loss, train_accuracy = train_epoch(state, train_data_loader,\n                                                    steps_per_epoch)\n    epoch_time = time.time() - tic\n    test_loss, test_accuracy = eval_step(state, 
test_ds['image'], test_ds['label'])\n    test_accuracy = np.array(test_accuracy)\n    logging.info(\n        'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f, epoch_time: %.3f'\n        % (epoch, train_loss, train_accuracy * 100, test_loss,\n           test_accuracy * 100, epoch_time))\n\n    summary_writer.scalar('train_loss', train_loss, epoch)\n    summary_writer.scalar('train_accuracy', train_accuracy, epoch)\n    summary_writer.scalar('test_loss', test_loss, epoch)\n    summary_writer.scalar('test_accuracy', test_accuracy, epoch)\n\n  summary_writer.flush()\n  return state\n"
  },
  {
    "path": "examples/opt_finetune/README.md",
    "content": "<!---\nCopyright 2021 The HuggingFace Team. All rights reserved.\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n    http://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n-->\n\n# Fine-tuning OPT Language Models\n\n## Instructions\n\n### Launch a Ray cluster\n\n1. Use the command below to launch ray on a head node  \n  ```ray start --head```\n2. (Optional) If you have more nodes, connect them to the head node. The command should look like this, but with the ip address and password printed by the previous command.   \n  ```ray start --address='172.31.34.216:6379' --redis-password='5241590000000000'```\n\n### Run training\n\n**Note**: The command below is tested on AWS p3.16xlarge instances with 8 x 16GB V100 GPUs.\nTo run on other clusters, please tune the arguments `per_device_train_batch_size/num_micro_batches/operator_parallel/pipeline_parallel` to avoid out-of-memory and achieve a good throughput.\n```\npython3 run_clm_flax.py \\\n    --output_dir=\"./output\" \\\n    --model_name_or_path=\"facebook/opt-2.7b\" \\\n    --dataset_name=\"wikitext\" \\\n    --dataset_config_name=\"wikitext-2-raw-v1\" \\\n    --do_train --do_eval \\\n    --block_size=\"1024\" \\\n    --per_device_train_batch_size=\"20\" \\\n    --per_device_eval_batch_size=\"20\" \\\n    --num_micro_batches 4 \\\n    --operator_parallel 4 \\\n    --pipeline_parallel 1 \\\n    --dtype=\"float16\" \\\n    --learning_rate=\"5e-4\" --warmup_steps=\"2000\" \\\n    --adam_beta1=\"0.9\" --adam_beta2=\"0.98\" --weight_decay=\"0.01\" \\\n    --overwrite_output_dir 
\\n    --num_train_epochs=\"8\" \\\n    --logging_steps=\"16\" \\\n    --save_steps=\"2500\" \\\n    --eval_steps=\"2500\"\n```\n\nMore documentation coming soon.\n\n\n# Acknowledgement\nAdapted from https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling\n"
  },
  {
    "path": "examples/opt_finetune/run_125m_shard.sh",
    "content": "python3 run_clm_flax.py \\\n    --output_dir=\"./output\" \\\n    --model_name_or_path=\"facebook/opt-125m\" \\\n    --dataset_name=\"wikitext\" \\\n    --dataset_config_name=\"wikitext-2-raw-v1\" \\\n    --do_train --do_eval \\\n    --block_size=\"1024\" \\\n    --per_device_train_batch_size=\"20\" \\\n    --per_device_eval_batch_size=\"20\" \\\n    --num_micro_batches 4 \\\n    --operator_parallel 4 \\\n    --pipeline_parallel 1 \\\n    --dtype=\"float16\" \\\n    --learning_rate=\"5e-4\" --warmup_steps=\"2000\" \\\n    --adam_beta1=\"0.9\" --adam_beta2=\"0.98\" --weight_decay=\"0.01\" \\\n    --overwrite_output_dir \\\n    --num_train_epochs=\"8\" \\\n    --logging_steps=\"16\" \\\n    --save_steps=\"32\" \\\n    --eval_steps=\"32\"\n"
  },
  {
    "path": "examples/opt_finetune/run_2.7b_pipe.sh",
    "content": "python3 run_clm_flax.py \\\n    --output_dir=\"./output\" \\\n    --model_name_or_path=\"facebook/opt-2.7b\" \\\n    --dataset_name=\"wikitext\" \\\n    --dataset_config_name=\"wikitext-2-raw-v1\" \\\n    --do_train --do_eval \\\n    --block_size=\"1024\" \\\n    --per_device_train_batch_size=\"64\" \\\n    --per_device_eval_batch_size=\"64\" \\\n    --num_micro_batches 64 \\\n    --operator_parallel 1 \\\n    --pipeline_parallel 2 \\\n    --dtype=\"float16\" \\\n    --learning_rate=\"5e-4\" --warmup_steps=\"2000\" \\\n    --adam_beta1=\"0.9\" --adam_beta2=\"0.98\" --weight_decay=\"0.01\" \\\n    --overwrite_output_dir \\\n    --num_train_epochs=\"10\" \\\n    --logging_steps=\"5\" \\\n    --save_steps=\"40\" \\\n    --eval_steps=\"25\"\n"
  },
  {
    "path": "examples/opt_finetune/run_2.7b_shard.sh",
    "content": "python3 run_clm_flax.py \\\n    --output_dir=\"./output\" \\\n    --model_name_or_path=\"facebook/opt-2.7b\" \\\n    --dataset_name=\"wikitext\" \\\n    --dataset_config_name=\"wikitext-2-raw-v1\" \\\n    --do_train --do_eval \\\n    --block_size=\"1024\" \\\n    --per_device_train_batch_size=\"20\" \\\n    --per_device_eval_batch_size=\"20\" \\\n    --num_micro_batches 4 \\\n    --operator_parallel 4 \\\n    --pipeline_parallel 1 \\\n    --dtype=\"float16\" \\\n    --learning_rate=\"5e-4\" --warmup_steps=\"2000\" \\\n    --adam_beta1=\"0.9\" --adam_beta2=\"0.98\" --weight_decay=\"0.01\" \\\n    --overwrite_output_dir \\\n    --num_train_epochs=\"8\" \\\n    --logging_steps=\"16\" \\\n    --save_steps=\"2500\" \\\n    --eval_steps=\"2500\"\n"
  },
  {
    "path": "examples/opt_finetune/run_clm_flax.py",
    "content": "#!/usr/bin/env python\n# coding=utf-8\n# Copyright 2021 The HuggingFace Team All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\"\"\"\nPre-training/Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.\n\nHere is the full list of checkpoints on the hub that can be fine-tuned by this script:\nhttps://huggingface.co/models?filter=text-generation\n\"\"\"\n# You can also adapt this script on your own causal language modeling task. 
Pointers for this are left as comments.\n\nimport json\nimport logging\nimport math\nimport os\nimport sys\nimport time\nfrom dataclasses import asdict, dataclass, field\nfrom enum import Enum\nimport functools\nfrom itertools import chain\nfrom pathlib import Path\nfrom typing import Callable, Optional\n\nimport datasets\nimport numpy as np\nfrom datasets import Dataset, load_dataset\nfrom tqdm import tqdm\n\nimport alpa\nfrom alpa.model.model_util import DynamicScale, TrainState\nfrom alpa import AutoShardingOption, AutoLayerOption, ManualStageOption\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport transformers\nimport tensorflow as tf\nfrom flax import jax_utils, traverse_util\nfrom flax.training import train_state\nfrom flax.training.common_utils import onehot, shard, shard_prng_key\nfrom huggingface_hub import Repository\nfrom transformers import (\n    CONFIG_MAPPING,\n    FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n    AutoConfig,\n    AutoTokenizer,\n    FlaxAutoModelForCausalLM,\n    HfArgumentParser,\n    is_tensorboard_available,\n    set_seed,\n)\n\nalpa.init(cluster=\"ray\")\n\nfrom transformers.testing_utils import CaptureLogger\nfrom transformers.utils import get_full_repo_name, send_example_telemetry\n\ntf.config.experimental.set_visible_devices([], 'GPU')\n\nlogger = logging.getLogger(__name__)\n\nMODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_CAUSAL_LM_MAPPING.keys())\nMODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)\n\n\n@dataclass\nclass TrainingArguments:\n    output_dir: str = field(\n        metadata={\"help\": \"The output directory where the model predictions and checkpoints will be written.\"},\n    )\n    overwrite_output_dir: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Overwrite the content of the output directory. 
\"\n                \"Use this to continue training if output_dir points to a checkpoint directory.\"\n            )\n        },\n    )\n    do_train: bool = field(default=False, metadata={\"help\": \"Whether to run training.\"})\n    do_eval: bool = field(default=False, metadata={\"help\": \"Whether to run eval on the dev set.\"})\n    per_device_train_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for training.\"}\n    )\n    per_device_eval_batch_size: int = field(\n        default=8, metadata={\"help\": \"Batch size per GPU/TPU core/CPU for evaluation.\"}\n    )\n    num_micro_batches: int = field(default=1, metadata={\"help\": \"The number of micro batches for gradient accumulation.\"})\n    operator_parallel: int = field(default=1, metadata={\"help\": \"The degree of operator model parallelism.\"})\n    pipeline_parallel: int = field(default=1, metadata={\"help\": \"The degree of pipeline model parallelism.\"})\n    use_remat: bool = field(default=True, metadata={\"help\": \"Whether or not to use gradient rematerilization/gradient checkpointing.\"})\n    learning_rate: float = field(default=5e-5, metadata={\"help\": \"The initial learning rate for AdamW.\"})\n    weight_decay: float = field(default=0.0, metadata={\"help\": \"Weight decay for AdamW if we apply some.\"})\n    adam_beta1: float = field(default=0.9, metadata={\"help\": \"Beta1 for AdamW optimizer\"})\n    adam_beta2: float = field(default=0.999, metadata={\"help\": \"Beta2 for AdamW optimizer\"})\n    adam_epsilon: float = field(default=1e-8, metadata={\"help\": \"Epsilon for AdamW optimizer.\"})\n    adafactor: bool = field(default=False, metadata={\"help\": \"Whether or not to replace AdamW by Adafactor.\"})\n    num_train_epochs: float = field(default=3.0, metadata={\"help\": \"Total number of training epochs to perform.\"})\n    warmup_steps: int = field(default=0, metadata={\"help\": \"Linear warmup over warmup_steps.\"})\n    
logging_steps: int = field(default=500, metadata={\"help\": \"Log every X updates steps.\"})\n    save_steps: int = field(default=500, metadata={\"help\": \"Save checkpoint every X updates steps.\"})\n    eval_steps: int = field(default=None, metadata={\"help\": \"Run an evaluation every X steps.\"})\n    seed: int = field(default=42, metadata={\"help\": \"Random seed that will be set at the beginning of training.\"})\n    push_to_hub: bool = field(\n        default=False, metadata={\"help\": \"Whether or not to upload the trained model to the model hub after training.\"}\n    )\n    hub_model_id: str = field(\n        default=None, metadata={\"help\": \"The name of the repository to keep in sync with the local `output_dir`.\"}\n    )\n    hub_token: str = field(default=None, metadata={\"help\": \"The token to use to push to the Model Hub.\"})\n\n    def __post_init__(self):\n        if self.output_dir is not None:\n            self.output_dir = os.path.expanduser(self.output_dir)\n\n    def to_dict(self):\n        \"\"\"\n        Serializes this instance while replace `Enum` by their values (for JSON serialization support). 
It obfuscates\n        the token values by removing their value.\n        \"\"\"\n        d = asdict(self)\n        for k, v in d.items():\n            if isinstance(v, Enum):\n                d[k] = v.value\n            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):\n                d[k] = [x.value for x in v]\n            if k.endswith(\"_token\"):\n                d[k] = f\"<{k.upper()}>\"\n        return d\n\n\n@dataclass\nclass ModelArguments:\n    \"\"\"\n    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.\n    \"\"\"\n\n    model_name_or_path: Optional[str] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch.\"\n            )\n        },\n    )\n    model_type: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"If training from scratch, pass a model type from the list: \" + \", \".join(MODEL_TYPES)},\n    )\n    config_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Pretrained config name or path if not the same as model_name\"}\n    )\n    tokenizer_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"Pretrained tokenizer name or path if not the same as model_name\"}\n    )\n    cache_dir: Optional[str] = field(\n        default=None, metadata={\"help\": \"Where do you want to store the pretrained models downloaded from s3\"}\n    )\n    use_fast_tokenizer: bool = field(\n        default=True,\n        metadata={\"help\": \"Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.\"},\n    )\n    dtype: Optional[str] = field(\n        default=\"float32\",\n        metadata={\n            \"help\": (\n                \"Floating-point format in which the model weights should be initialized and trained. 
Choose one of\"\n                \" `[float32, float16, bfloat16]`.\"\n            )\n        },\n    )\n    use_auth_token: bool = field(\n        default=False,\n        metadata={\n            \"help\": (\n                \"Will use the token generated when running `transformers-cli login` (necessary to use this script \"\n                \"with private models).\"\n            )\n        },\n    )\n\n\n@dataclass\nclass DataTrainingArguments:\n    \"\"\"\n    Arguments pertaining to what data we are going to input our model for training and eval.\n    \"\"\"\n\n    dataset_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"The name of the dataset to use (via the datasets library).\"}\n    )\n    dataset_config_name: Optional[str] = field(\n        default=None, metadata={\"help\": \"The configuration name of the dataset to use (via the datasets library).\"}\n    )\n    train_file: Optional[str] = field(default=None, metadata={\"help\": \"The input training data file (a text file).\"})\n    validation_file: Optional[str] = field(\n        default=None,\n        metadata={\"help\": \"An optional input evaluation data file to evaluate the perplexity on (a text file).\"},\n    )\n    max_train_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of training examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    max_eval_samples: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"For debugging purposes or quicker training, truncate the number of evaluation examples to this \"\n                \"value if set.\"\n            )\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n    )\n    validation_split_percentage: 
Optional[int] = field(\n        default=5,\n        metadata={\n            \"help\": \"The percentage of the train set used as validation set in case there's no validation split\"\n        },\n    )\n    block_size: Optional[int] = field(\n        default=None,\n        metadata={\n            \"help\": (\n                \"Optional input sequence length after tokenization. \"\n                \"The training dataset will be truncated in block of this size for training. \"\n                \"Default to the model max input length for single sentence inputs (take into account special tokens).\"\n            )\n        },\n    )\n    overwrite_cache: bool = field(\n        default=False, metadata={\"help\": \"Overwrite the cached training and evaluation sets\"}\n    )\n    preprocessing_num_workers: Optional[int] = field(\n        default=None,\n        metadata={\"help\": \"The number of processes to use for the preprocessing.\"},\n    )\n    keep_linebreaks: bool = field(\n        default=True, metadata={\"help\": \"Whether to keep line breaks when using TXT files or not.\"}\n    )\n\n    def __post_init__(self):\n        if self.dataset_name is None and self.train_file is None and self.validation_file is None:\n            raise ValueError(\"Need either a dataset name or a training/validation file.\")\n        else:\n            if self.train_file is not None:\n                extension = self.train_file.split(\".\")[-1]\n                assert extension in [\"csv\", \"json\", \"txt\"], \"`train_file` should be a csv, a json or a txt file.\"\n            if self.validation_file is not None:\n                extension = self.validation_file.split(\".\")[-1]\n                assert extension in [\"csv\", \"json\", \"txt\"], \"`validation_file` should be a csv, a json or a txt file.\"\n\n\ndef data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int,\n                min_batch_size: int, shuffle: bool = False):\n    \"\"\"\n    Returns batches of size 
`batch_size` from truncated `dataset`, sharded over all local devices.\n    Shuffle batches if `shuffle` is `True`.\n    \"\"\"\n    if len(dataset) < batch_size:\n        assert len(dataset) >= min_batch_size\n        batch_size = len(dataset) // min_batch_size * min_batch_size\n\n    data_collator = transformers.DefaultDataCollator(\"np\")\n    tf_dataset = dataset.to_tf_dataset(batch_size=batch_size,\n                                       columns=dataset.column_names,\n                                       collate_fn=data_collator,\n                                       shuffle=shuffle,\n                                       drop_remainder=True)\n\n    for batch in tf_dataset:\n        batch = {k: v._numpy() for k, v in batch.items()}\n        yield batch\n\n\ndef write_train_metric(summary_writer, train_metrics, train_time, step):\n    summary_writer.scalar(\"train_time\", train_time, step)\n\n    train_metrics = alpa.util.get_metrics(train_metrics)\n    for key, vals in train_metrics.items():\n        tag = f\"train_{key}\"\n        for i, val in enumerate(vals):\n            summary_writer.scalar(tag, val, step - len(vals) + i + 1)\n\n\ndef write_eval_metric(summary_writer, eval_metrics, step):\n    for metric_name, value in eval_metrics.items():\n        summary_writer.scalar(f\"eval_{metric_name}\", value, step)\n\n\ndef create_learning_rate_fn(\n    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float\n) -> Callable[[int], jnp.array]:\n    \"\"\"Returns a linear warmup, linear_decay learning rate function.\"\"\"\n    steps_per_epoch = train_ds_size // train_batch_size\n    num_train_steps = steps_per_epoch * num_train_epochs\n    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)\n    decay_fn = optax.linear_schedule(\n        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps\n    )\n    
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])\n    return schedule_fn\n\n\ndef monkey_patch_remat():\n    # Use monkey patch to add remat for all transformer layers.\n    from transformers.models.opt.modeling_flax_opt import FlaxOPTDecoderLayer, FlaxOPTDecoderLayerCollection\n    from flax.linen.partitioning import remat\n    from flax.linen.module import wrap_method_once\n    import flax.linen as nn\n\n    @wrap_method_once\n    def setup(self):\n        self.layers = [\n            remat(FlaxOPTDecoderLayer, static_argnums=(2, 3, 4))(\n                self.config, name=str(i), dtype=self.dtype)\n            for i in range(self.config.num_hidden_layers)\n        ]\n        self.layerdrop = self.config.layerdrop\n\n    def call(\n        self,\n        hidden_states,\n        attention_mask,\n        deterministic: bool = True,\n        init_cache: bool = False,\n        output_attentions: bool = False,\n        output_hidden_states: bool = False,\n    ):\n        # decoder layers\n        all_hidden_states = () if output_hidden_states else None\n        all_self_attns = () if output_attentions else None\n\n        for decoder_layer in self.layers:\n            if output_hidden_states:\n                all_hidden_states += (hidden_states,)\n\n            layer_outputs = decoder_layer(\n                hidden_states,\n                attention_mask,\n                init_cache,\n                output_attentions,\n                deterministic,\n            )\n\n            hidden_states = layer_outputs[0]\n            if output_attentions:\n                all_self_attns += (layer_outputs[1],)\n\n        outputs = [hidden_states, all_hidden_states, all_self_attns]\n        return outputs\n\n    setattr(FlaxOPTDecoderLayerCollection, \"setup\", setup)\n    setattr(FlaxOPTDecoderLayerCollection, \"__call__\", call)\n\n\ndef main():\n    # See all possible arguments in src/transformers/training_args.py\n    # or by 
passing the --help flag to this script.\n    # We now keep distinct sets of args, for a cleaner separation of concerns.\n\n    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))\n    if len(sys.argv) == 2 and sys.argv[1].endswith(\".json\"):\n        # If we pass only one argument to the script and it's the path to a json file,\n        # let's parse it to get our arguments.\n        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))\n    else:\n        model_args, data_args, training_args = parser.parse_args_into_dataclasses()\n\n    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The\n    # information sent is the one passed as arguments along with your Python/PyTorch versions.\n    send_example_telemetry(\"run_clm\", model_args, data_args, framework=\"flax\")\n\n    if (\n        os.path.exists(training_args.output_dir)\n        and os.listdir(training_args.output_dir)\n        and training_args.do_train\n        and not training_args.overwrite_output_dir\n    ):\n        raise ValueError(\n            f\"Output directory ({training_args.output_dir}) already exists and is not empty.\"\n            \"Use --overwrite_output_dir to overcome.\"\n        )\n\n    # Make one log on every process with the configuration for debugging.\n    logging.basicConfig(\n        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n        datefmt=\"%m/%d/%Y %H:%M:%S\",\n        level=logging.INFO,\n    )\n    # Setup logging, we only want one process per machine to log things on the screen.\n    logger.setLevel(logging.INFO)\n    datasets.utils.logging.set_verbosity_warning()\n    transformers.utils.logging.set_verbosity_info()\n\n    # Set the verbosity to info of the Transformers logger (on main process only):\n    logger.info(f\"Training/evaluation parameters {training_args}\")\n\n    # Set seed before initializing model.\n    
set_seed(training_args.seed)\n\n    # Handle the repository creation\n    if training_args.push_to_hub:\n        if training_args.hub_model_id is None:\n            repo_name = get_full_repo_name(\n                Path(training_args.output_dir).absolute().name, token=training_args.hub_token\n            )\n        else:\n            repo_name = training_args.hub_model_id\n        repo = Repository(training_args.output_dir, clone_from=repo_name)\n\n    #  Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)\n    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/\n    # (the dataset will be downloaded automatically from the datasets Hub).\n    #\n    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called\n    # 'text' is found. You can easily tweak this behavior (see below).\n    #\n    # In distributed training, the load_dataset function guarantees that only one local process can concurrently\n    # download the dataset.\n    if data_args.dataset_name is not None:\n        # Downloading and loading a dataset from the hub.\n        dataset = load_dataset(\n            data_args.dataset_name,\n            data_args.dataset_config_name,\n            cache_dir=model_args.cache_dir,\n            keep_in_memory=False,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n\n        if \"validation\" not in dataset.keys():\n            dataset[\"validation\"] = load_dataset(\n                data_args.dataset_name,\n                data_args.dataset_config_name,\n                split=f\"train[:{data_args.validation_split_percentage}%]\",\n                cache_dir=model_args.cache_dir,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n            dataset[\"train\"] = load_dataset(\n                data_args.dataset_name,\n              
  data_args.dataset_config_name,\n                split=f\"train[{data_args.validation_split_percentage}%:]\",\n                cache_dir=model_args.cache_dir,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n    else:\n        data_files = {}\n        dataset_args = {}\n        if data_args.train_file is not None:\n            data_files[\"train\"] = data_args.train_file\n        if data_args.validation_file is not None:\n            data_files[\"validation\"] = data_args.validation_file\n        extension = data_args.train_file.split(\".\")[-1]\n        if extension == \"txt\":\n            extension = \"text\"\n            dataset_args[\"keep_linebreaks\"] = data_args.keep_linebreaks\n        dataset = load_dataset(\n            extension,\n            data_files=data_files,\n            cache_dir=model_args.cache_dir,\n            **dataset_args,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n\n        if \"validation\" not in dataset.keys():\n            dataset[\"validation\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[:{data_args.validation_split_percentage}%]\",\n                cache_dir=model_args.cache_dir,\n                **dataset_args,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n            dataset[\"train\"] = load_dataset(\n                extension,\n                data_files=data_files,\n                split=f\"train[{data_args.validation_split_percentage}%:]\",\n                cache_dir=model_args.cache_dir,\n                **dataset_args,\n                use_auth_token=True if model_args.use_auth_token else None,\n            )\n    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at\n    # https://huggingface.co/docs/datasets/loading_datasets.html.\n\n    # Load pretrained 
model and tokenizer\n\n    # Distributed training:\n    # The .from_pretrained methods guarantee that only one local process can concurrently\n    # download model & vocab.\n    if model_args.config_name:\n        config = AutoConfig.from_pretrained(\n            model_args.config_name,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    elif model_args.model_name_or_path:\n        config = AutoConfig.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=model_args.cache_dir,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    else:\n        config = CONFIG_MAPPING[model_args.model_type]()\n        logger.warning(\"You are instantiating a new config instance from scratch.\")\n\n    if training_args.use_remat:\n        monkey_patch_remat()\n\n    if model_args.tokenizer_name:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_args.tokenizer_name,\n            cache_dir=model_args.cache_dir,\n            use_fast=model_args.use_fast_tokenizer,\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n    elif model_args.model_name_or_path:\n        tokenizer = AutoTokenizer.from_pretrained(\n            model_args.model_name_or_path,\n            cache_dir=model_args.cache_dir,\n            #use_fast=model_args.use_fast_tokenizer,\n            use_auth_token=True if model_args.use_auth_token else None,\n            use_fast=False,\n        )\n    else:\n        raise ValueError(\n            \"You are instantiating a new tokenizer from scratch. 
This is not supported by this script.\"\n            \"You can do it from another script, save it, and load it from here, using --tokenizer_name.\"\n        )\n\n    if model_args.model_name_or_path:\n        model = FlaxAutoModelForCausalLM.from_pretrained(\n            model_args.model_name_or_path,\n            config=config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n            use_auth_token=True if model_args.use_auth_token else None,\n        )\n        #from transformers import FlaxOPTForCausalLM\n        #config.num_hidden_layers = 2\n        #model = FlaxOPTForCausalLM(\n        #    config=config,\n        #    seed=training_args.seed,\n        #    dtype=getattr(jnp, model_args.dtype),\n        #)\n    else:\n        model = FlaxAutoModelForCausalLM.from_config(\n            config,\n            seed=training_args.seed,\n            dtype=getattr(jnp, model_args.dtype),\n        )\n\n    # Preprocessing the datasets.\n    # First we tokenize all the texts.\n    if training_args.do_train:\n        column_names = dataset[\"train\"].column_names\n    else:\n        column_names = dataset[\"validation\"].column_names\n    text_column_name = \"text\" if \"text\" in column_names else column_names[0]\n\n    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function\n    tok_logger = transformers.utils.logging.get_logger(\"transformers.tokenization_utils_base\")\n\n    def tokenize_function(examples):\n        with CaptureLogger(tok_logger) as cl:\n            output = tokenizer(examples[text_column_name])\n        # clm input could be much much longer than block_size\n        if \"Token indices sequence length is longer than the\" in cl.out:\n            tok_logger.warning(\n                \"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits\"\n                \" before being passed to the model.\"\n            
)\n        return output\n\n    logger.info(\"***** Tokenize dataset *****\")\n    tokenized_datasets = dataset.map(\n        tokenize_function,\n        batched=True,\n        num_proc=data_args.preprocessing_num_workers,\n        remove_columns=column_names,\n        load_from_cache_file=not data_args.overwrite_cache,\n    )\n\n    if data_args.block_size is None:\n        block_size = tokenizer.model_max_length\n        if block_size > config.max_position_embeddings:\n            logger.warning(\n                f\"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). \"\n                \"Picking 1024 instead. You can change that default value by passing --block_size xxx.\"\n            )\n            block_size = 1024\n    else:\n        if data_args.block_size > tokenizer.model_max_length:\n            logger.warning(\n                f\"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model\"\n                f\"({tokenizer.model_max_length}). 
Using block_size={tokenizer.model_max_length}.\"\n            )\n        block_size = min(data_args.block_size, tokenizer.model_max_length)\n\n    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.\n    def group_texts(examples):\n        # Concatenate all texts.\n        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}\n        total_length = len(concatenated_examples[list(examples.keys())[0]])\n        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can\n        # customize this part to your needs.\n        if total_length >= block_size:\n            total_length = (total_length // block_size) * block_size\n        # Split by chunks of max_len.\n        result = {\n            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]\n            for k, t in concatenated_examples.items()\n        }\n        result[\"labels\"] = result[\"input_ids\"].copy()\n        return result\n\n    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder\n    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower\n    # to preprocess.\n    #\n    # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information:\n    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map\n\n    logger.info(\"***** Build dataset *****\")\n    lm_datasets = tokenized_datasets.map(\n        group_texts,\n        batched=True,\n        num_proc=data_args.preprocessing_num_workers,\n        load_from_cache_file=not data_args.overwrite_cache,\n    )\n\n    if training_args.do_train:\n        if \"train\" not in tokenized_datasets:\n            raise ValueError(\"--do_train requires a train dataset\")\n        train_dataset = lm_datasets[\"train\"]\n        if data_args.max_train_samples is not None:\n            max_train_samples = min(len(train_dataset), data_args.max_train_samples)\n            train_dataset = train_dataset.select(range(max_train_samples))\n\n    if training_args.do_eval:\n        if \"validation\" not in tokenized_datasets:\n            raise ValueError(\"--do_eval requires a validation dataset\")\n        eval_dataset = lm_datasets[\"validation\"]\n        if data_args.max_eval_samples is not None:\n            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)\n            eval_dataset = eval_dataset.select(range(max_eval_samples))\n\n    # Adjust batch size and num_micro_batches for small datasets\n    num_devices = alpa.get_global_num_devices()\n    train_min_batch_size = (num_devices // training_args.operator_parallel //\n                            training_args.pipeline_parallel * training_args.num_micro_batches)\n    eval_num_micro_batches = training_args.num_micro_batches\n    eval_min_batch_size = (num_devices // training_args.operator_parallel //\n                           training_args.pipeline_parallel * eval_num_micro_batches)\n    while len(eval_dataset) < eval_min_batch_size:\n        eval_num_micro_batches //= 2\n        eval_min_batch_size = (num_devices // training_args.operator_parallel //\n                               
training_args.pipeline_parallel * eval_num_micro_batches)\n\n    # Enable tensorboard only on the master node\n    has_tensorboard = is_tensorboard_available()\n    if has_tensorboard:\n        try:\n            from flax.metrics.tensorboard import SummaryWriter\n\n            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))\n        except ImportError as ie:\n            has_tensorboard = False\n            logger.warning(\n                f\"Unable to display metrics through TensorBoard because some package are not installed: {ie}\"\n            )\n    else:\n        logger.warning(\n            \"Unable to display metrics through TensorBoard because the package is not installed: \"\n            \"Please run pip install tensorboard to enable.\"\n        )\n\n    # Initialize our training\n    rng = jax.random.PRNGKey(training_args.seed)\n    rng, dropout_rng = jax.random.split(rng)\n\n    # Store some constant\n    num_epochs = int(training_args.num_train_epochs)\n    train_batch_size = int(training_args.per_device_train_batch_size) * num_devices\n    eval_batch_size = int(training_args.per_device_eval_batch_size) * num_devices\n    steps_per_epoch = len(train_dataset) // train_batch_size\n    total_train_steps = steps_per_epoch * num_epochs\n\n    # Create learning rate schedule\n    linear_decay_lr_schedule_fn = create_learning_rate_fn(\n        len(train_dataset),\n        train_batch_size,\n        training_args.num_train_epochs,\n        training_args.warmup_steps,\n        training_args.learning_rate,\n    )\n\n    # We use Optax's \"masking\" functionality to not apply weight decay\n    # to bias and LayerNorm scale parameters. 
decay_mask_fn returns a\n    # mask boolean with the same structure as the parameters.\n    # The mask is True for parameters that should be decayed.\n    # Note that this mask is specifically adapted for FlaxGPT2.\n    # For other models, one should correct the layer norm parameter naming\n    # accordingly.\n    def decay_mask_fn(params):\n        flat_params = traverse_util.flatten_dict(params)\n        flat_mask = {\n            path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n            for path in flat_params\n        }\n        return traverse_util.unflatten_dict(flat_mask)\n\n    # create adam optimizer\n    if training_args.adafactor:\n        # We use the default parameters here to initialize adafactor,\n        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74\n        optimizer = optax.adafactor(\n            learning_rate=linear_decay_lr_schedule_fn,\n        )\n    else:\n        optimizer = optax.chain(\n            optax.clip_by_global_norm(1.0),\n            optax.adamw(\n                learning_rate=linear_decay_lr_schedule_fn,\n                b1=training_args.adam_beta1,\n                b2=training_args.adam_beta2,\n                eps=training_args.adam_epsilon,\n                weight_decay=training_args.weight_decay,\n                mask=decay_mask_fn)\n        )\n\n    # Setup train state\n    if model_args.dtype == \"float16\":\n        use_master_copy = True\n        dynamic_scale = DynamicScale()\n        # Fix a bug in huggingface's implementation (https://github.com/huggingface/transformers/pull/18462)\n        alpa.global_config.flax_always_use_fp16_embedding = True\n    else:\n        use_master_copy = dynamic_scale = None\n    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer,\n                              
dynamic_scale=dynamic_scale, use_master_copy=use_master_copy)\n\n    def loss_fn(logits, labels):\n        shift_logits = logits[..., :-1, :]\n        shift_labels = labels[..., 1:]\n        loss = optax.softmax_cross_entropy(\n            shift_logits,\n            jax.nn.one_hot(shift_labels, logits.shape[-1]))\n        return loss.mean()\n\n    # Define gradient update step fn\n    def train_step(state, batch):\n\n        def compute_loss(params):\n            labels = batch.pop(\"labels\")\n            logits = state.apply_fn(**batch, params=params, deterministic=True)[0]\n            loss = loss_fn(logits, labels)\n            return loss\n\n        dynamic_scale = state.dynamic_scale\n        if dynamic_scale:\n            grad_fn = dynamic_scale.value_and_grad(compute_loss)\n            dynamic_scale, is_fin, loss, grads = grad_fn(state.params)\n        else:\n            grad_fn = alpa.value_and_grad(compute_loss)\n            loss, grads = grad_fn(state.params)\n\n        new_state = state.apply_gradients(grads=grads)\n\n        if dynamic_scale:\n            new_state = new_state.replace(\n                opt_state=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.opt_state, state.opt_state),\n                params=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.params, state.params),\n                master_copy=jax.tree_map(\n                    functools.partial(jnp.where, is_fin),\n                    new_state.master_copy, state.master_copy),\n                dynamic_scale=dynamic_scale)\n\n        metrics = {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}\n\n        return new_state, metrics\n\n    # Define eval fn\n    def eval_step(params, batch):\n        labels = batch.pop(\"labels\")\n        logits = model(**batch, params=params, deterministic=True)[0]\n        loss = loss_fn(logits, labels)\n\n        # 
summarize metrics\n        metrics = {\"loss\": loss}\n        return metrics\n\n    # Create parallel version of the train and eval step\n    method = alpa.get_3d_parallel_method(\n            num_micro_batches=training_args.num_micro_batches,\n            data_parallel=-1,\n            operator_parallel=training_args.operator_parallel,\n            pipeline_parallel=training_args.pipeline_parallel)\n\n    p_train_step = alpa.parallelize(train_step,\n                                    method=method,\n                                    donate_argnums=(0,))\n    p_eval_step = alpa.parallelize(eval_step,\n                                   method=alpa.FollowParallel(\n                                       p_train_step, num_micro_batches=eval_num_micro_batches))\n\n    dump_debug_info_train_step = dump_debug_info_eval_step = True\n\n    logger.info(\"***** Running training *****\")\n    logger.info(f\"  Num examples = {len(train_dataset)}\")\n    logger.info(f\"  Num Epochs = {num_epochs}\")\n    logger.info(f\"  Batch size per device (w. accumulation) = {training_args.per_device_train_batch_size}\")\n    logger.info(f\"  Global train batch size (w. parallel & distributed) = {train_batch_size}\")\n    logger.info(f\"  Total optimization steps = {total_train_steps}\")\n\n    train_time = 0\n    train_metrics = []\n    epochs = tqdm(range(num_epochs), desc=\"Epoch ... \", position=0)\n\n    step_ct = 0\n    last_time = time.time()\n\n    epochs.write(\"Initial compilation. 
This might take some minutes...\")\n\n    for epoch in epochs:\n        # ======================== Training ================================\n        train_start = time.time()\n\n        # Create sampling rng\n        rng, input_rng = jax.random.split(rng)\n\n        # Generate an epoch by shuffling sampling indices from the train dataset\n        train_loader = data_loader(input_rng, train_dataset, train_batch_size,\n                                   train_min_batch_size, shuffle=True)\n        steps_per_epoch = len(train_dataset) // train_batch_size\n        # train\n        for step in tqdm(range(steps_per_epoch), desc=\"Training...\", position=1, leave=False):\n            batch = next(train_loader)\n            batch[\"position_ids\"] = (batch[\"attention_mask\"].cumsum(axis=1) *\n                                     batch[\"attention_mask\"]) - 1\n            state, train_metric = p_train_step(state, batch)\n            train_metrics.append(train_metric)\n\n            cur_step = epoch * (len(train_dataset) // train_batch_size) + step\n\n            if dump_debug_info_train_step:\n                dump_debug_info_train_step = False\n                executable = p_train_step.get_last_executable()\n                executable.sync()\n                executable.dump_debug_info(\"alpa_debug_info\")\n                epochs.write(f\"Initial compilation completed. 
\"\n                             f\"Time elapsed: {time.time() - train_start:.2f} s\")\n\n            step_ct += 1\n            if cur_step % training_args.logging_steps == 0 and cur_step > 0:\n                executable.sync()\n                latency = (time.time() - last_time) / step_ct\n                throughput_tokens = np.prod(batch[\"input_ids\"].shape) / latency\n                throughput_tflops = alpa.util.compute_gpt_tflops(\n                    batch_size=batch[\"input_ids\"].shape[0],\n                    seq_len=batch[\"input_ids\"].shape[1],\n                    num_layers=config.num_hidden_layers,\n                    hidden_size=config.hidden_size,\n                    vocab_size=config.vocab_size,\n                    num_gpus=alpa.get_global_num_devices(),\n                    latency=latency)\n                step_ct = 0\n\n                # Save metrics\n                train_time += time.time() - train_start\n                if has_tensorboard:\n                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)\n\n                train_metric = jax.tree_map(np.mean, train_metric)\n\n                epochs.write(\n                    f\"Step... 
{cur_step} | \"\n                    f\"Loss: {train_metric['loss'].mean():.4f}, \"\n                    f\"Learning Rate: {train_metric['learning_rate'].mean():.5f}, \"\n                    f\"Throughput: {throughput_tokens:.2f} token/s, \"\n                    f\"{throughput_tflops:.2f} TFLOP/s\"\n                )\n\n                train_metrics = []\n                last_time = time.time()\n\n            if cur_step % training_args.eval_steps == 0 and cur_step > 0:\n                # ======================== Evaluating ==============================\n                eval_metrics = []\n                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size,\n                                          eval_min_batch_size)\n                eval_steps = max(len(eval_dataset) // eval_batch_size, 1)\n                for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n                    # Model forward\n                    batch = next(eval_loader)\n                    batch[\"position_ids\"] = (batch[\"attention_mask\"].cumsum(axis=1) *\n                                             batch[\"attention_mask\"]) - 1\n                    metrics = p_eval_step(state.params, batch)\n                    eval_metrics.append(metrics)\n\n                    if dump_debug_info_eval_step:\n                        dump_debug_info_eval_step = False\n                        executable = p_eval_step.get_last_executable()\n                        executable.dump_debug_info(\"alpa_debug_info\")\n\n                # normalize eval metrics\n                eval_metrics = alpa.util.get_metrics(eval_metrics)\n                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n\n                try:\n                    eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n                except OverflowError:\n                    eval_metrics[\"perplexity\"] = float(\"inf\")\n\n                # Print metrics and update progress bar\n 
               desc = (\n                    f\"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:\"\n                    f\" {eval_metrics['perplexity']})\"\n                )\n                epochs.write(desc)\n\n                # Save metrics\n                if has_tensorboard:\n                    write_eval_metric(summary_writer, eval_metrics, cur_step)\n\n            if cur_step % training_args.save_steps == 0 and cur_step > 0:\n                # save checkpoint after each epoch and push checkpoint to the hub\n                epochs.write(\"\\nSave checkpoint...\")\n                alpa.prefetch(state.params)\n                params = alpa.util.map_to_nparray(state.params)\n                model.save_pretrained(training_args.output_dir, params=params)\n                tokenizer.save_pretrained(training_args.output_dir)\n                if training_args.push_to_hub:\n                    repo.push_to_hub(commit_message=f\"Saving weights and logs of step {cur_step}\", blocking=False)\n\n    # Eval after training\n    if training_args.do_eval:\n        eval_metrics = []\n        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size,\n                                  eval_min_batch_size)\n        eval_steps = max(len(eval_dataset) // eval_batch_size, 1)\n        for _ in tqdm(range(eval_steps), desc=\"Evaluating...\", position=2, leave=False):\n            # Model forward\n            batch = next(eval_loader)\n            batch[\"position_ids\"] = (batch[\"attention_mask\"].cumsum(axis=1) *\n                                     batch[\"attention_mask\"]) - 1\n            metrics = p_eval_step(state.params, batch)\n            eval_metrics.append(metrics)\n\n        # normalize eval metrics\n        eval_metrics = alpa.util.get_metrics(eval_metrics)\n        eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)\n\n        try:\n            eval_metrics[\"perplexity\"] = math.exp(eval_metrics[\"loss\"])\n    
    except OverflowError:\n            eval_metrics[\"perplexity\"] = float(\"inf\")\n\n        eval_metrics = {f\"eval_{metric_name}\": value for metric_name, value in eval_metrics.items()}\n        path = os.path.join(training_args.output_dir, \"eval_results.json\")\n        with open(path, \"w\") as f:\n            json.dump(eval_metrics, f, indent=4, sort_keys=True)\n\n    # Save the final model\n    epochs.write(\"\\nSave the final model...\")\n    alpa.prefetch(state.params)\n    params = alpa.util.map_to_nparray(state.params)\n    model.save_pretrained(training_args.output_dir, params=params)\n    tokenizer.save_pretrained(training_args.output_dir)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/setup.py",
    "content": "import sys\nfrom setuptools import find_packages, setup\n\nsetup(name=\"llm_serving\",\n      packages=find_packages())\n"
  },
  {
    "path": "examples/slurm_script_examples/test_cuda.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=test_cuda\n#SBATCH -N 1\n#SBATCH -p GPU-shared\n#SBATCH -t 1:00\n#SBATCH --gpus=v100-16:1\n\n#import modules\nmodule purge\nmodule load cuda\nmodule load nvhpc\n\n#check environments\necho $CUDA_HOME\nnvcc --version\n\n#exit\n"
  },
  {
    "path": "examples/slurm_script_examples/test_prerequisites.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=test_alpa_prerequisites\n#SBATCH -p GPU-shared\n#SBATCH -t 1:00\n#SBATCH --gpus=v100-16:1\n\nmodule load cuda\nmodule load cudnn\nmodule load nvhpc\n\nnvcc --version\n"
  },
  {
    "path": "examples/slurm_script_examples/test_ray_multinode.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=ray_multinode_test\n#SBATCH --cpus-per-task=16\n#SBATCH --mem-per-cpu=1GB\n#SBATCH --ntasks-per-node=1\ngpus_per_node=0\n# load modules\nmodule purge\nconda init bash\nsource ~/.bashrc\n# start conda\nconda activate alpa_environment\n# environment activated, check environment\npython3 -V\npython3 -c \"from cupy.cuda import nccl\"\n# Getting the node names\nnodes=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\")\nnodes_array=($nodes)\n\nhead_node=${nodes_array[0]}\nhead_node_ip=$(srun --nodes=1 --ntasks=1 -w \"$head_node\" hostname --ip-address)\n\n# if we detect a space character in the head node IP, we'll\n# convert it to an ipv4 address. This step is optional.\nif [[ \"$head_node_ip\" == *\" \"* ]]; then\nIFS=' ' read -ra ADDR <<<\"$head_node_ip\"\nif [[ ${#ADDR[0]} -gt 16 ]]; then\n  head_node_ip=${ADDR[1]}\nelse\n  head_node_ip=${ADDR[0]}\nfi\necho \"IPV6 address detected. We split the IPV4 address as $head_node_ip\"\nfi\n\n# start head node\nport=6789\nip_head=$head_node_ip:$port\nexport ip_head\n\nsrun --nodes=1 --ntasks=1 -w \"$head_node\" \\\n\tray start --head --node-ip-address=\"$head_node_ip\" --port=$port \\\n\t--num-cpus \"${SLURM_CPUS_PER_TASK}\" --num-gpus $gpus_per_node --block &\n\n# start worker nodes\n# number of nodes other than the head node\nworker_num=$((SLURM_JOB_NUM_NODES - 1))\n\nfor ((i = 1; i <= worker_num; i++)); do\n    node_i=${nodes_array[$i]}\n    echo \"Starting WORKER $i at $node_i\"\n    srun --nodes=1 --ntasks=1 -w \"$node_i\" \\\n\tray start --address \"$ip_head\" --num-cpus \"${SLURM_CPUS_PER_TASK}\" \\\n\t--num-gpus $gpus_per_node --block &\n    sleep 5\ndone\n# try ray\necho \"test ray status\"\nray list nodes --address \"$ip_head\"\nray list nodes\nray list actors\nray summary tasks\n# end ray\nray stop\n# exit environment\nconda deactivate\nexit\n"
  },
  {
    "path": "examples/slurm_script_examples/textgen_alpa_test.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=ray_singlenode_test\n# load modules\nmodule purge\nmodule load cuda\nmodule load nvhpc\nconda init bash\nsource ~/.bashrc\n# test nvcc\nnvcc --version\n# start environment using conda\nconda activate alpa_environment\n# start ray on head\nray start --head\n# start alpa textgen.py\npython3 alpa/examples/llm_serving/textgen.py --model alpa/bloom-560m --n-prompts 1 --path $PROJECT/alpa_weights\n# end ray\nray stop\n# exit environment\nconda deactivate\nexit\n"
  },
  {
    "path": "examples/slurm_script_examples/textgen_pt_test.sh",
    "content": "#!/bin/bash\n#SBATCH --job-name=ray_singlenode_test\n# load modules\nmodule purge\nmodule load cuda\nmodule load nvhpc\nconda init bash\nsource ~/.bashrc\n# test nvcc\nnvcc --version\n# start environment using conda\nconda activate alpa_environment\n# start ray on head\nray start --head\n# start alpa textgen.py\npython3 alpa/examples/llm_serving/textgen.py --model facebook/opt-125m --n-prompts 1 --path $PROJECT/alpa_weights\n# end ray\nray stop\n# exit environment\nconda deactivate\nexit\n"
  },
  {
    "path": "format.sh",
    "content": "#!/usr/bin/env bash\n# YAPF formatter, adapted from ray and sky.\n#\n# Usage:\n#    # Do work and commit your work.\n#\n#    # Format files that differ from origin/main.\n#    bash format.sh\n#\n#    # Commit changed files with message 'Run yapf and pylint'\n#\n# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.\n# You are encouraged to run this locally before pushing changes for review.\n\n# Cause the script to exit if a single command fails\nset -eo pipefail\n\n# this stops git rev-parse from failing if we run this from the .git directory\nbuiltin cd \"$(dirname \"${BASH_SOURCE:-$0}\")\"\nROOT=\"$(git rev-parse --show-toplevel)\"\nbuiltin cd \"$ROOT\" || exit 1\n\nYAPF_VERSION=$(yapf --version | awk '{print $2}')\nPYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')\n\n# params: tool name, tool version, required version\ntool_version_check() {\n    if [[ $2 != $3 ]]; then\n        echo \"Wrong $1 version installed: $3 is required, not $2.\"\n        exit 1\n    fi\n}\n\ntool_version_check \"yapf\" $YAPF_VERSION \"0.32.0\"\ntool_version_check \"pylint\" $PYLINT_VERSION \"2.14.0\"\n\nYAPF_FLAGS=(\n    '--style' \"$ROOT/.style.yapf\"\n    '--recursive'\n    '--parallel'\n)\n\nYAPF_EXCLUDES=(\n    '--exclude' 'benchmark/cupy/*'\n    '--exclude' 'benchmark/alpa/old_backup/*'\n    '--exclude' 'benchmark/deepspeed/*'\n    '--exclude' 'benchmark/megatron/*'\n    '--exclude' 'build_jaxlib/*'\n    '--exclude' 'docs/*'\n    '--exclude' 'examples/*'\n    '--exclude' 'playground/*'\n    '--exclude' 'third_party/*'\n)\n\n# Format specified files\nformat() {\n    yapf --in-place \"${YAPF_FLAGS[@]}\" \"$@\"\n}\n\n# Format files that differ from main branch. 
Ignores dirs that are not slated\n# for autoformat yet.\nformat_changed() {\n    # The `if` guard ensures that the list of filenames is not empty, which\n    # could cause yapf to receive 0 positional arguments, making it hang\n    # waiting for STDIN.\n    #\n    # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that\n    # exist on both branches.\n    MERGEBASE=\"$(git merge-base origin/main HEAD)\"\n\n    if ! git diff --diff-filter=ACM --quiet --exit-code \"$MERGEBASE\" -- '*.py' &>/dev/null; then\n        git diff --name-only --diff-filter=ACM \"$MERGEBASE\" -- '*.py' | xargs -P 5 \\\n             yapf --in-place \"${YAPF_EXCLUDES[@]}\" \"${YAPF_FLAGS[@]}\"\n    fi\n\n}\n\n# Format all files\nformat_all() {\n    yapf --in-place \"${YAPF_FLAGS[@]}\" \"${YAPF_EXCLUDES[@]}\" alpa tests benchmark\n}\n\n## This flag formats individual files. --files *must* be the first command line\n## arg to use this option.\nif [[ \"$1\" == '--files' ]]; then\n   format \"${@:2}\"\n   # If `--all` is passed, then any further arguments are ignored and the\n   # entire python directory is formatted.\nelif [[ \"$1\" == '--all' ]]; then\n   format_all\nelse\n   # Format only the files that changed in last commit.\n   format_changed\nfi\n\n# Run Pylint\necho 'Alpa Pylint:'\npylint alpa\n\n# Run Pylint on tests (TODO(zhuohan) enable linting on tests)\n# echo 'Alpa Tests Pylint:'\n# pylint tests\n\nif ! git diff --quiet &>/dev/null; then\n    echo 'Reformatted files. Please review and stage the changes.'\n    echo 'Changes not staged for commit:'\n    echo\n    git --no-pager diff --name-only\n\n    exit 1\nfi\n"
  },
  {
    "path": "playground/alpa_micro_benchmark/benchmark_dist_save_load.py",
    "content": "import os\nimport subprocess\nimport time\n\nfrom flax.training.checkpoints import save_checkpoint, restore_checkpoint\nimport jax\nimport jax.numpy as jnp\nfrom jax import random\nimport numpy as np\n\nimport alpa\nfrom alpa import save_checkpoint as alpa_save_checkpoint\nfrom alpa import restore_checkpoint as alpa_restore_checkpoint\nfrom alpa import PipeshardParallel, DistributedArray\nfrom alpa.testing import (MLPModel, create_train_state, get_mlp_train_step)\nfrom alpa.device_mesh import get_global_cluster\n\n\ndef _get_efs_mount_point():\n    # Hacky function to get the EFS mount point\n    for line in subprocess.check_output(\"df -h\",\n                                        shell=True).decode().split('\\n'):\n        cols = line.split(' ')\n        if \"efs\" in cols[0]:\n            return cols[-1] + \"/\"\n    return None\n\n\ndef _get_save_prefix(to_efs):\n    if to_efs:\n        # Get EFS mount point for the multi-host test\n        save_prefix = _get_efs_mount_point()\n    else:\n        # Use tmp dir for the single-host test\n        save_prefix = \"/tmp/\"\n    return save_prefix\n\n\nLOOP_CNT = 2\n\n\ndef benchmark_ndarray_save_load(mode=\"flax\", to_efs=True):\n    \"\"\"\n    EFS performance: https://docs.aws.amazon.com/efs/latest/ug/performance.html\n\n    if mode == \"flax\": use flax.training.checkpoints.save_checkpoint/restore_checkpoint\n    elif mode == \"alpa\": use alpa.serialization.save_checkpoint/restore_checkpoint\n    elif mode == \"numpy: use np.save/load\n\n    Benchmark results on EFS: \n    - flax.save_checkpoint:    save average run time: 15.0580 seconds, save average throughput: 0.5313 Gbps\n    - flax.restore_checkpoint: load average run time:  6.8287 seconds, load average throughput: 1.2225 Gbps\n\n    - alpa.save_checkpoint:    save average run time: 12.8583 seconds, save average throughput: 0.6222 Gbps\n                 use cache:    \n    - alpa.restore_checkpoint: N/A\n\n    - np.save:                 save 
average run time: 10.4157 seconds, save average throughput: 0.7682 Gbps\n    - np.load:                 load average run time:  2.9987 seconds, load average throughput: 4.9950 Gbps\n\n    Benchmark results on local filesystem:\n    - flax.save_checkpoint:    save average run time: 5.5268 seconds, save average throughput: 1.4475 Gbps\n    - flax.restore_checkpoint: load average run time: 5.1856 seconds, load average throughput: 1.5428 Gbps\n\n    - alpa.save_checkpoint:    save average run time: 10.3145 seconds, save average throughput: 0.7756 Gbps\n    - alpa.restore_checkpoint: N/A\n\n    - np.save:                 save average run time: 0.8104 seconds, save average throughput:  9.8718 Gbps\n    - np.load:                 load average run time: 0.7327 seconds, load average throughput: 10.9179 Gbps\n    \"\"\"\n    rngkey = random.PRNGKey(0)\n    #arr_sizes = [1024*1024, 4*1024*1024, 16*1024*1024, 32*1024*1024] # 4M, 16M, 64M, 128M\n    arr_sizes = [256 * 1024 * 1024]  # 1G\n    benchmark_arrs = [\n        random.normal(rngkey, (arr_size,)) for arr_size in arr_sizes\n    ]\n    for arr in benchmark_arrs:\n        save_tot_duration = 0.0\n        save_tot_throughput = 0.0\n        load_tot_duration = 0.0\n        load_tot_throughput = 0.0\n        prefix = _get_save_prefix(to_efs)\n        for i in range(LOOP_CNT):\n            assert (prefix is not None)\n            outdir = os.path.join(prefix, \"benchmark_checkpoint\")\n            # clean working directory\n            subprocess.run([\"rm\", \"-rf\", outdir])\n            # rebuild working directory\n            os.mkdir(outdir)\n            print(f\"save to {outdir}\")\n            ckpt_path = os.path.join(outdir, \"checkpoint_1.npy\")  # numpy-only\n\n            # save benchmark\n            start = time.time()\n            if mode == \"flax\":\n                save_checkpoint(outdir, arr, i)\n            elif mode == \"alpa\":\n                alpa_save_checkpoint(outdir, arr, i, \"/tmp\")\n            
else:\n                np.save(ckpt_path, arr)\n            duration = time.time() - start\n            throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration\n            if i >= 1:\n                save_tot_duration += duration\n                save_tot_throughput += throughput\n            print(\n                f\"loop {i} save, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps\"\n            )\n\n            gpus = jax.devices(\"gpu\")\n            # load benchmark\n            start = time.time()\n            if mode == \"flax\":\n                restore_checkpoint(outdir, None, None)\n            elif mode == \"alpa\":\n                print(\"alpa skip load array benchmark\")\n                continue\n            else:\n                jax.block_until_ready(\n                    jax.device_put(np.load(ckpt_path), gpus[0]))\n\n            duration = time.time() - start\n            throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration\n            if i >= 1:\n                load_tot_duration += duration\n                load_tot_throughput += throughput\n            print(\n                f\"loop {i} load, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps\"\n            )\n\n        print(\n            f\"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n        )\n        print(\n            f\"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n        )\n\n\ndef count_params(model):\n    return sum(x.size for x in jax.tree_leaves(model))\n\n\ndef benchmark_mlp_save(mode=\"flax\", to_efs=True):\n    \"\"\"\n    Benchmark results on EFS: \n    - flax.save_checkpoint: average run time: 45.19087886810303 seconds, average throughput: 0.5313484040513637 Gbps\n    - alpa.save_checkpoint: average run time: 16.15189399719238, average 
throughput: 1.4860819837013484 Gbps\n                 use cache: \n    - np.save:              average run time: 20.618193340301513, average throughput: 1.1642373201358331 Gbps\n\n    Benchmark results on local disk:\n    - flax.save_checkpoint: average run time: 16.1341721534729, average throughput: 1.4877078603042466 Gbps\n    - alpa.save_checkpoint: average run time: 10.663438653945922, average throughput: 2.2509621962263244 Gbps\n    - np.save:              average run time: 20.618193340301513, average throughput: 1.1642373201358331 Gbps\n    \"\"\"\n    # Init model and optimizer\n    batch_size = 64\n    hidden_dim = 8192  # 3072M\n    input_dim = output_dim = hidden_dim\n    model = MLPModel(hidden_dim=hidden_dim,\n                     output_dim=output_dim,\n                     manual_pipeline_layer=True)\n\n    # Init batch args\n    rngkey = random.PRNGKey(0)\n    x = random.normal(rngkey, (batch_size, input_dim), jnp.float32)\n    state = create_train_state(rngkey, model, [x])\n    model_size = count_params(state)\n    print(f\"model size: {model_size * 4 / 1024 / 1024} MB\")\n\n    tot_duration = 0.0\n    tot_throughput = 0.0\n    prefix = _get_save_prefix(to_efs)\n    for i in range(LOOP_CNT):\n        assert (prefix is not None)\n        outdir = os.path.join(prefix, \"benchmark_checkpoint\")\n        ckpt_path = os.path.join(outdir, f\"checkpoint_1.npy\")  # numpy-only\n        # clean working directory\n        subprocess.run([\"rm\", \"-rf\", outdir])\n        # rebuild working directory\n        os.mkdir(outdir)\n        print(f\"save to {outdir}\")\n\n        start = time.time()\n        if mode == \"flax\":\n            save_checkpoint(outdir, state, i)\n        elif mode == \"alpa\":\n            alpa_save_checkpoint(outdir, state, i, \"/tmp\")\n        else:\n            np.save(ckpt_path, state.params)\n            np.save(ckpt_path, state.opt_state)\n        duration = time.time() - start\n\n        throughput = model_size * 32 / 1024 / 
1024 / 1024 / duration\n        tot_duration += duration\n        tot_throughput += throughput\n        print(\n            f\"loop {i}, time: {duration} seconds, throughput: {throughput} Gbps\"\n        )\n    print(\n        f\"average run time: {tot_duration/LOOP_CNT}, average throughput: {tot_throughput/LOOP_CNT} Gbps\"\n    )\n\n\ndef benchmark_dist_arr_save(to_efs=False):\n    \"\"\"\n    Benchmark results on local disk:\n    - one host:\n        - TensorStore: save average run time: 9.9292 seconds, save average throughput: 0.8057 Gbps\n        - np.save      save average run time: 0.8113 seconds, save average throughput: 9.8601 Gbps\n\n    - two hosts:\n        - TensorStore: save average run time: 3.9092 seconds, save average throughput: 2.0465 Gbps\n        - np.save:     save average run time: 0.4702 seconds, save average throughput: 17.0149 Gbps\n    \"\"\"\n    device_cluster = get_global_cluster()\n    physical_mesh = device_cluster.get_physical_mesh()\n    logical_mesh = physical_mesh.get_logical_mesh()\n\n    rngkey = random.PRNGKey(0)\n    arr_shape = (64 * 1024, 16 * 1024)  #1GB\n    arr = random.normal(rngkey, arr_shape)\n\n    sharding_spec = logical_mesh.make_tile_spec(arr, [0, 1], [0, 1])\n    input_indices = sharding_spec.indices(arr.shape).flatten()\n    (dist_arr,) = physical_mesh.shard_args_to_arrays(\n        (jax.ShapedArray(arr.shape, jnp.int32),), (input_indices,),\n        (sharding_spec,), (arr,))\n\n    save_tot_duration = 0.0\n    save_tot_throughput = 0.0\n    outdir = \"/tmp/benchmark_save\"\n    for i in range(LOOP_CNT):\n        # Save the DistributedArray (one replica only)\n        subprocess.run([\"rm\", \"-rf\", outdir])\n        print(f\"save to {outdir}\")\n\n        start = time.time()\n        jax.block_until_ready(dist_arr.save(outdir))\n        duration = time.time() - start\n        throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration\n        if i >= 1:\n            save_tot_duration += duration\n            
save_tot_throughput += throughput\n        print(\n            f\"loop {i} save, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps\"\n        )\n    print(\n        f\"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n    )\n\n\ndef benchmark_dist_arr_load():\n    \"\"\"\n    Benchmark results on local disk:\n    - one host:\n        TensorStore: load average run time: 4.0709 seconds, load average throughput: 1.9651 Gbps\n        np.load:     load average run time: 1.5235 seconds, load average throughput: 5.2512 Gbps\n    \n    - two hosts:\n        TensorStore: load average run time: 3.6650 seconds, load average throughput: 2.1828 Gbps\n        np.load:     load average run time: 0.7644 seconds, load average throughput: 10.4655 Gbps\n    \"\"\"\n    device_cluster = get_global_cluster()\n    physical_mesh = device_cluster.get_physical_mesh()\n    logical_mesh = physical_mesh.get_logical_mesh()\n\n    rngkey = random.PRNGKey(0)\n    arr_shape = (64 * 1024, 16 * 1024)  #1GB\n    arr = random.normal(rngkey, arr_shape)\n\n    sharding_spec = logical_mesh.make_tile_spec(arr, [0, 1], [0, 1])\n\n    load_tot_duration = 0.0\n    load_tot_throughput = 0.0\n    outdir = \"/tmp/benchmark_save\"\n    for i in range(LOOP_CNT):\n        print(f\"load from {outdir}\")\n\n        # load benchmark\n        start = time.time()\n        print(\"start\", time.time())\n        jax.block_until_ready(\n            DistributedArray.load(outdir, jax.ShapedArray(arr.shape, jnp.int32),\n                                  physical_mesh, sharding_spec))\n        print(\"end\", time.time())\n        duration = time.time() - start\n        throughput = arr.size * 32 / 1024 / 1024 / 1024 / duration\n        if i >= 1:\n            load_tot_duration += duration\n            load_tot_throughput += throughput\n        print(\n            f\"loop {i} load, time: {duration:.4f} seconds, 
throughput: {throughput:.4f} Gbps\"\n        )\n    print(\n        f\"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n    )\n\n\ndef benchmark_mlp_dist_save():\n    \"\"\"\n    Benchmark results on EFS:\n    - alpa.save_checkpoint:\n        save average run time: 161.8653 seconds, save average throughput: 0.1483 Gbps\n        load average run time:  40.2772 seconds, load average throughput: 0.5965 Gbps\n    \n    Benchmark results on local disk:\n    - one host:\n        np.save (batch version) save average run time: 1.3313 seconds, save average throughput: 18.0300 Gbps\n\n    - two hosts:\n        TensorStore:            save average run time: 19.9880 seconds, save average throughput: 1.2009 Gbps\n        np.save:                save average run time:  2.4631 seconds, save average throughput: 9.7452 Gbps\n        np.save (batch version) save average run time: 1.2081 seconds, save average throughput: 19.8683 Gbps\n    \n    - four hosts:\n        np.save (batch version) \n    \"\"\"\n    # Init model and optimizer\n    batch_size = 64\n    hidden_dim = 8192  # 3072M\n    input_dim = output_dim = hidden_dim\n    model = MLPModel(hidden_dim=hidden_dim,\n                     output_dim=output_dim,\n                     manual_pipeline_layer=True)\n\n    # Init batch args\n    rngkey = random.PRNGKey(0)\n    x = random.normal(rngkey, (batch_size, input_dim), jnp.float32)\n    y = jax.random.normal(rngkey, (batch_size, output_dim), jnp.float32)\n    batch = {'x': x, 'y': y}\n\n    state = create_train_state(rngkey, model, [x])\n    model_size = count_params(state)\n    print(f\"model size: {model_size * 4 / 1024 / 1024} MB\")\n\n    # Compile\n    method = PipeshardParallel(num_micro_batches=2)\n    parallel_train_step = get_mlp_train_step(method, True, False, False)\n    parallel_state = parallel_train_step(state, batch)[0]\n\n    save_tot_duration = 0.0\n    
save_tot_throughput = 0.0\n    outdir = \"/home/ubuntu/efs/benchmark_mlp_save\"\n    cachedir = \"/tmp/benchmark_mlp_save\"\n    for i in range(LOOP_CNT):\n        subprocess.run([\"rm\", \"-rf\", outdir])\n        subprocess.run([\"rm\", \"-rf\", cachedir])\n        print(f\"save to {outdir}\")\n        # benchmark saving\n        start = time.time()\n        if i == 0:\n            alpa_save_checkpoint(\"/tmp/warmup\", parallel_state, 1)\n            jax.block_until_ready(parallel_state)\n        else:\n            alpa_save_checkpoint(outdir, parallel_state, 1, cachedir)\n            #alpa_save_checkpoint(\"/tmp/warmup\", parallel_state, 1)\n            jax.block_until_ready(parallel_state)\n        duration = time.time() - start\n        throughput = model_size * 32 / 1024 / 1024 / 1024 / duration\n        if i >= 1:\n            save_tot_duration += duration\n            save_tot_throughput += throughput\n        print(\n            f\"loop {i} save, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps\"\n        )\n\n    print(\n        f\"save average run time: {save_tot_duration/(LOOP_CNT - 1):.4f} seconds, save average throughput: {save_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n    )\n\n\ndef benchmark_mlp_dist_load():\n    \"\"\"\n    Benchmark results on local disk:\n    - one hosts:\n        np.load (batch version) load average run time: 1.6670 seconds, load average throughput: 14.3985 Gbps\n\n    - two hosts:\n        TensorStore:            load average run time: 4.4443 seconds, load average throughput: 5.4008 Gbps\n        np.load:                load average run time: 3.2214 seconds, load average throughput: 7.4511 Gbps\n        np.load (batch version) load average run time: 1.6163 seconds, load average throughput: 14.8510 Gbps\n    \n    - four hosts:\n        np.load (batch version) \n    \"\"\"\n    # Init model and optimizer\n    batch_size = 64\n    hidden_dim = 8192  # 3072M\n    input_dim = output_dim = hidden_dim\n    model = 
MLPModel(hidden_dim=hidden_dim,\n                     output_dim=output_dim,\n                     manual_pipeline_layer=True)\n\n    # Init batch args\n    rngkey = random.PRNGKey(0)\n    x = random.normal(rngkey, (batch_size, input_dim), jnp.float32)\n    y = jax.random.normal(rngkey, (batch_size, output_dim), jnp.float32)\n    batch = {'x': x, 'y': y}\n\n    state = create_train_state(rngkey, model, [x])\n    model_size = count_params(state)\n    print(f\"model size: {model_size * 4 / 1024 / 1024} MB\")\n\n    # Compile\n    method = PipeshardParallel(num_micro_batches=2)\n    parallel_train_step = get_mlp_train_step(method, True, False, False)\n    executable = parallel_train_step.get_executable(state, batch)\n    state_ss, _ = executable.get_load_info()\n    _ = parallel_train_step(state, batch)[0]\n\n    load_tot_duration = 0.0\n    load_tot_throughput = 0.0\n    outdir = \"/tmp/benchmark_mlp_load\"\n    for i in range(LOOP_CNT):\n        print(f\"load from {outdir}\")\n        # benchmark loading\n        start = time.time()\n        load_state = alpa_restore_checkpoint(outdir, 1, state_ss)\n        jax.block_until_ready(load_state)\n        duration = time.time() - start\n        throughput = model_size * 32 / 1024 / 1024 / 1024 / duration\n        if i >= 1:  # first loop for warmup\n            load_tot_duration += duration\n            load_tot_throughput += throughput\n        print(\n            f\"loop {i} load, time: {duration:.4f} seconds, throughput: {throughput:.4f} Gbps\"\n        )\n\n    print(\n        f\"load average run time: {load_tot_duration/(LOOP_CNT - 1):.4f} seconds, load average throughput: {load_tot_throughput/(LOOP_CNT - 1):.4f} Gbps\"\n    )\n\n\nif __name__ == \"__main__\":\n    alpa.init(cluster=\"ray\")\n    # print(\"ndarray benchmark on EFS:\")\n    # print(\"flax\")\n    # benchmark_ndarray_save_load(mode=\"flax\")\n    # print(\"\\nalpa\")\n    # benchmark_ndarray_save_load(mode=\"alpa\")\n    # print(\"\\nnumpy\")\n    # 
benchmark_ndarray_save_load(mode=\"numpy\")\n\n    # print(\"\\n\\nndarray benchmark on local disk:\")\n    # print(\"flax\")\n    # benchmark_ndarray_save_load(mode=\"flax\", to_efs=False)\n    # print(\"\\nalpa\")\n    # benchmark_ndarray_save_load(mode=\"alpa\", to_efs=False)\n    # print(\"\\nnumpy\")\n    # benchmark_ndarray_save_load(mode=\"numpy\", to_efs=False)\n\n    # print(\"mlp benchmark on EFS:\")\n    # benchmark_mlp_save(mode=\"flax\")\n    # benchmark_mlp_save(mode=\"alpa\")\n    # benchmark_mlp_save(mode=\"numpy\")\n\n    # print(\"mlp benchmark on local disk:\")\n    # benchmark_mlp_save(mode=\"flax\", to_efs=False)\n    # benchmark_mlp_save(mode=\"alpa\", to_efs=False)\n    # benchmark_mlp_save(mode=\"numpy\", to_efs=False)\n\n    # print(\"dist array save/load benchmark:\")\n    # benchmark_dist_arr_save()\n    # benchmark_dist_arr_load()\n\n    # print(\"mlp dist save/load benchmark:\")\n    # benchmark_mlp_dist_save()\n    benchmark_mlp_dist_load()\n    alpa.shutdown()\n"
  },
  {
    "path": "playground/alpa_micro_benchmark/test_export_hlo.py",
    "content": "\"\"\"Benchmark one case of intra-op only parallelism.\"\"\"\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nimport alpa\nfrom alpa import (parallelize, global_config, LocalPhysicalDeviceMesh,\n                  ShardParallel, AutoShardingOption)\nfrom alpa.model.bert_model import BertConfig, FlaxBertForMaskedLMModule, TrainState\nfrom alpa.model.gpt_model import FlaxGPTForLMModule\nfrom alpa.timer import timers\nfrom alpa.util import map_to_shape, count_communication_primitives, print_used_time, GB\n\n\ndef compute_gpt_parameter_count(num_layers, hidden_size, vocab_size):\n    return num_layers * (\n        # self-attention\n        hidden_size * (3 * hidden_size + 1) + hidden_size * (hidden_size + 1) +\n        # mlp\n        hidden_size * (4 * hidden_size + 1) + hidden_size * 4 *\n        (hidden_size + 1) +\n        # layer norm\n        hidden_size * 4) + vocab_size * (hidden_size + 1)\n\n\ndef create_train_state(rngkey, model, dtype, batch):\n    params = model.init_dummy(rngkey, batch[\"input_ids\"],\n                              batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                              batch[\"position_ids\"])\n\n    def weight_decay_mask(pytree):\n        # do not use weight decay on layer norm and bias.\n        return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n    tx = optax.chain(\n        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this\n        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))\n\n    mixed_precision = (dtype == jnp.float16)\n\n    state = TrainState.create(apply_fn=model.apply,\n                              params=params,\n                              tx=tx,\n                              mixed_precision=mixed_precision,\n                              dynamic_scale=None)\n    return state\n\n\ndef create_train_state_aval(rngkey, model, batch, dtype):\n    params = jax.eval_shape(model.init, 
rngkey, batch[\"input_ids\"],\n                            batch[\"attention_mask\"], batch[\"token_type_ids\"],\n                            batch[\"position_ids\"])\n\n    def weight_decay_mask(pytree):\n        # do not use weight decay on layer norm and bias.\n        return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n    tx = optax.chain(\n        #optax.clip_by_global_norm(1.0),  # TODO(lmzheng): fix reduce-scatter for this\n        optax.adamw(learning_rate=1e-2, mask=weight_decay_mask))\n    mixed_precision = (dtype == jnp.float16)\n    state = TrainState.create_aval(apply_fn=model.apply,\n                                   params=params,\n                                   tx=tx,\n                                   mixed_precision=mixed_precision,\n                                   dynamic_scale=None)\n    return state\n\n\ndef get_train_step(grad_func, method):\n\n    @parallelize(method=method)\n    def train_step(state, batch, rng_key):\n\n        def loss_func(params):\n            rngs = {\"dropout\": rng_key}\n            logits = state.apply_fn(params,\n                                    batch[\"input_ids\"],\n                                    batch[\"attention_mask\"],\n                                    batch[\"token_type_ids\"],\n                                    batch[\"position_ids\"],\n                                    deterministic=True,\n                                    rngs=rngs)[0]\n            label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n            labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n            loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                            axis=-1)\n            loss = (label_mask * loss).sum() / label_mask.sum()\n            return loss\n\n        grads = grad_func(loss_func)(state.params)\n        new_state = state.apply_gradients(grads=grads)\n        # TODO(lmzheng): add dynamic scaling for mixed-precision training\n        return 
new_state\n\n    return train_step\n\n\ndef benchmark_2d_one_case_gpt_bert(physical_mesh, model_type, benchmark_case):\n    print_used_time(None)\n\n    # Model configs\n    (batch_size, seq_len, hidden_size, num_layers, num_heads, vocab_size,\n     num_micro_batches, parallel_mode, parallel_args) = benchmark_case\n    (prefer_reduce_scatter, use_remat, (dp, op, pp),\n     force_batch_dim_mapping) = parallel_args\n\n    dtype = jnp.float16\n\n    # Parallel configs\n    assert pp == 1, \"Do not support pipeline parallelism\"\n    if num_micro_batches > 1:\n        grad_func = alpa.grad\n    else:\n        num_micro_batches = None\n        grad_func = jax.grad\n\n    as_option = AutoShardingOption()\n    if force_batch_dim_mapping:  # Always map batch dim to mesh dim 0\n        as_option.force_batch_dim_to_mesh_dim = 0\n    as_option.prefer_reduce_scatter = prefer_reduce_scatter\n    if parallel_mode == \"zero-3\":\n        as_option.force_zero_stage_3 = True\n    elif parallel_mode in [\"shard-largest\"]:\n        as_option.force_simple_heuristic = other\n        global_config.remat_using_while = True\n\n    logical_mesh = physical_mesh.get_logical_mesh([dp, op])\n    method = ShardParallel(devices=logical_mesh,\n                           num_micro_batches=num_micro_batches,\n                           auto_sharding_option=as_option)\n    print_used_time(\"Setup device mesh\")\n\n    # Prepare input batch\n    batch = {\n        \"input_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"attention_mask\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"token_type_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"position_ids\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n        \"labels\": jnp.ones((batch_size, seq_len), dtype=jnp.int32),\n    }\n    print_used_time(\"Prepare input\")\n\n    # Init train state\n    if model_type == \"gpt\":\n        model = FlaxGPTForLMModule(BertConfig(\n            
num_hidden_layers=num_layers,\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 4,\n            num_attention_heads=num_heads,\n            vocab_size=vocab_size,\n            max_position_embeddings=seq_len,\n            type_vocab_size=0,\n            gradient_checkpointing=use_remat,\n        ),\n                                   dtype=dtype)\n    elif model_type == \"bert\":\n        model = FlaxBertForMaskedLMModule(BertConfig(\n            num_hidden_layers=num_layers,\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 4,\n            num_attention_heads=num_heads,\n            vocab_size=vocab_size,\n            max_position_embeddings=seq_len,\n            type_vocab_size=0,\n            gradient_checkpointing=use_remat,\n        ),\n                                          dtype=dtype)\n    else:\n        raise ValueError(f\"Invalid model {model_type}\")\n\n    rngkey = jax.random.PRNGKey(0)\n    state = create_train_state_aval(rngkey, model, batch, dtype)\n    print_used_time(\"Create train state\")\n\n    # Compile executable\n    train_step = get_train_step(grad_func, method)\n    executable = train_step.get_executable(state, batch, rngkey)\n    print_used_time(\"Compile (driver)\")\n\n    return executable\n\n\nif __name__ == \"__main__\":\n    global_config.xla_gpu_autotune_level = 0\n    model_type = \"gpt\"\n\n    num_nodes = 2\n    num_devices_per_node = 8\n    _ = None\n\n    # B = batch_size, S = seq_len, H = hidden_size, L = num_layers, V = vocab_size\n    # head = num_heads,\n    # NB = num_micro_batches, PM = parallel_mode\n    # 3D config = 3D parallel config (Data, Operator, Pipeline)\n    # RS = prefer_reduce_scatter, Remat = use_rematerialization,\n    # FM = force_batch_dim_mapping\n    #B,  S,     H      L,  #head, V,     NB,\n    benchmark_case = (\n        8,\n        1024,\n        1024,\n        6,\n        32,\n        51200,\n        1,\n        #PM,        RS,    
Remat, 3D config,  FM\n        \"manual\",\n        (False, True, (2, 8, 1), False))\n    num_devices = num_nodes * num_devices_per_node\n\n    num_layers, hidden_size, vocab_size = (benchmark_case[3], benchmark_case[2],\n                                           benchmark_case[5])\n    param_count = compute_gpt_parameter_count(num_layers, hidden_size,\n                                              vocab_size)\n    print(f\"Param count: {param_count/1e9:.2f} B\")\n\n    # Define a fake physical mesh\n    physical_mesh = LocalPhysicalDeviceMesh(devices=[None] * num_devices)\n\n    # Compile a mesh executable\n    executable = benchmark_2d_one_case_gpt_bert(physical_mesh, \"gpt\",\n                                                benchmark_case)\n    print(f\"Auto sharding time: {timers('auto-sharding').elapsed():.2f} s\\n\")\n\n    # Write hlo ir to a file\n    print(\"Write hlo module to files...\")\n    with open(\"optimized_hlo.txt\", \"w\") as fout:\n        hlo_text = executable.get_hlo_text()\n        fout.write(hlo_text)\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all =\\\n            count_communication_primitives(hlo_text)\n        print(\n            f\"#total: {n_total}, #all-reduce: {n_all_reduce}, \"\n            f\"#all-gather: {n_all_gather}, #reduce-scatter: {n_reduce_scatter}, \"\n            f\"#all-to-all: {n_all_to_all}\")\n        print(\n            f\"Allocation: {executable.get_total_allocation_size() / (1<<30):.2f} GB\"\n        )\n\n    with open(\"after_spmd_partitioner_hlo.txt\", \"w\") as fout:\n        fout.write(executable.hlo_module.to_string())\n\n    with open(\"executable_hlo.proto\", \"wb\") as fout:\n        fout.write(executable.hlo_module.as_serialized_hlo_module_proto())\n\n    # Get the sharding specs of the inputs and outputs of the hlo module\n    # print(executable.input_sharding_specs)\n    # print(executable.output_sharding_specs)\n"
  },
  {
    "path": "playground/alpa_micro_benchmark/test_shard_array.py",
    "content": "import jax\nimport jax.numpy as jnp\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import (ShardingSpec,\n    NoSharding, Replicated, Chunked, ShardedAxis)\nimport numpy as np\nimport ray\n\nimport alpa\n\ndef benchmark(physical_mesh, shape, sharding_spec):\n    avals = []\n    shard_indices = []\n    sharding_specs = []\n    donated_invars = []\n    args = []\n\n    number = 2\n\n    for i in range(number):\n        array = jnp.ones(shape, jnp.float32)\n        indices = sharding_spec.indices(array.shape)\n\n        avals.append(jax.ShapedArray(array.shape, array.dtype))\n        sharding_specs.append(sharding_spec)\n        shard_indices.append(indices.flatten())\n        donated_invars.append(True)\n        args.append(array)\n\n    print(sharding_spec)\n    buffers = physical_mesh.shard_args_to_bufs(shard_indices, donated_invars, args)\n\n    return buffers\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\")\n\n    cluster = alpa.DeviceCluster()\n    physical_mesh = cluster.get_physical_mesh()\n\n    shape = (8192, 8192)\n\n    sharding_specs = [\n        ShardingSpec(\n            sharding=[NoSharding(), NoSharding(),],\n            mesh_mapping=[Replicated(8),]),\n        ShardingSpec(\n            sharding=[Chunked([8]), NoSharding(),],\n            mesh_mapping=[ShardedAxis(0),]),\n        ShardingSpec(\n            sharding=[NoSharding(), Chunked([8])],\n            mesh_mapping=[ShardedAxis(0),]),\n        ShardingSpec(\n            sharding=[Chunked([2]), Chunked([4])],\n            mesh_mapping=[ShardedAxis(0), ShardedAxis(1)]),\n    ]\n\n    for spec in sharding_specs:\n        benchmark(physical_mesh, shape, spec)\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/README.md",
    "content": "# A Prototype of Auto-sharding Solver\n\nThis is only a prototype written in Python. It is not used by Alpa.\n\n## Requirements\n```\npip3 install pulp\n```\n\n## Examples\n```\npython3 test_solver_mlp.py\n```\n"
  },
  {
    "path": "playground/auto_sharding_solver/cluster_env.py",
    "content": "\"\"\"Cluster Environment\"\"\"\nimport numpy as np\n\nfrom hlo import ShardingSpec, ShardingSpecType\nfrom common import compute_bytes, get_dim_last_value\n\n\nclass ClusterEnvironment:\n    def __init__(self, device_mesh, mesh_alpha, mesh_beta, memory_per_device, solver_option=None):\n        self.device_mesh = np.array(device_mesh)\n        self.mesh_alpha = mesh_alpha\n        self.mesh_beta = mesh_beta\n        assert len(self.mesh_alpha) == len(self.device_mesh.shape)\n        assert len(self.mesh_beta) == len(self.device_mesh.shape)\n        self.memory_per_device = memory_per_device\n        self.all_gather_penalty = 0\n        self.all_reduce_penalty = 0\n        self.reduce_scatter_penalty = 0\n        self.partial_reduction_penalty = 10\n        self.num_devices = np.prod(self.device_mesh.shape)\n\n        self.force_all_gather_cost = None\n        self.force_all_reduce_cost = None\n        self.force_reduce_scatter_cost = None\n\n        if solver_option:\n            self.force_all_gather_cost = solver_option.force_all_gather_cost\n            self.force_all_reduce_cost = solver_option.force_all_reduce_cost\n            self.force_reduce_scatter_cost = solver_option.force_reduce_scatter_cost\n\n    def all_gather_cost(self, num_bytes, mesh_dim=0):\n        if self.force_all_gather_cost:\n            return self.force_all_gather_cost\n\n        num_devices = self.device_mesh.shape[mesh_dim]\n        return (int(self.mesh_alpha[mesh_dim] +\n                self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes) +\n                0.1) + self.all_gather_penalty\n\n    def all_reduce_cost(self, num_bytes, mesh_dim=0):\n        if self.force_all_reduce_cost:\n            return self.force_all_reduce_cost\n\n        num_devices = self.device_mesh.shape[mesh_dim]\n        return (int(self.mesh_alpha[mesh_dim] +\n                self.mesh_beta[mesh_dim] * 2 * (num_devices - 1) / num_devices * num_bytes) +\n                0.01) 
+ self.all_reduce_penalty\n\n    def reduce_scatter_cost(self, num_bytes, mesh_dim=0):\n        if self.force_reduce_scatter_cost:\n            return self.force_reduce_scatter_cost\n\n        num_devices = self.device_mesh.shape[mesh_dim]\n        return (int(self.mesh_alpha[mesh_dim] +\n                self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices * num_bytes) +\n                0.001)\n\n    def all_to_all_cost(self, num_bytes, mesh_dim=0):\n        num_devices = self.device_mesh.shape[mesh_dim]\n        penalty_factor = 1.5;\n        return (int(self.mesh_alpha[mesh_dim] +\n                self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices /\\\n                    num_devices * num_bytes * penalty_factor) +\n                0.001);\n\n    def get_tensor_dim_to_mesh_dim(self, shape, spec):\n        \"\"\"Map the tensor dimention to mesh dimension, -1 means replicated\"\"\"\n        if spec.type == ShardingSpecType.REPLICATED:\n            return [-1] * len(shape)\n\n        tile_assignment = np.array(spec.tile_assignment_devices).\\\n            reshape(spec.tile_assignment_dimensions)\n\n        tensor_dim_vals = tuple(get_dim_last_value(tile_assignment, i)\n            for i in range(len(shape)))\n\n        mesh_dim_vals = tuple(get_dim_last_value(self.device_mesh, j)\n            for j in range(len(self.device_mesh.shape)))\n\n        ret = [-1] * len(shape)\n        for i in range(len(shape)):\n            if spec.tile_assignment_dimensions[i] != 1:\n                found = False\n                for j in range(len(self.device_mesh.shape)):\n                    if tensor_dim_vals[i] == mesh_dim_vals[j]:\n                        ret[i] = j\n                        found = True\n                assert found\n\n        return ret\n\n    def resharding_cost(self, shape, src_spec, dst_spec):\n        if src_spec == dst_spec:\n            return 0\n\n        src_tensor_dim_to_mesh_dim = self.get_tensor_dim_to_mesh_dim(shape, src_spec)\n        
dst_tensor_dim_to_mesh_dim = self.get_tensor_dim_to_mesh_dim(shape, dst_spec)\n\n        cost = 0\n        for i in range(len(shape)):\n            src_mesh_dim = src_tensor_dim_to_mesh_dim[i]\n            if src_mesh_dim == -1:\n                continue\n            if src_mesh_dim == dst_tensor_dim_to_mesh_dim[i]:\n                continue\n            cost += self.all_gather_cost(compute_bytes(shape), src_mesh_dim)\n\n        return cost\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/common.py",
    "content": "\"\"\"Common Utilities\"\"\"\n\nimport numpy as np\n\n\ndef append_flatten_elements(result, array, indices, cur_depth, cur_indices):\n    \"\"\"Append elements of `array` to `result`. The `indices` is a generalized\n       multi-dimensional index that can index a whole row (use -1 to indicate this)\"\"\"\n    if cur_depth == len(array.shape) - 1:\n        result.append(array[tuple(cur_indices)])\n    else:\n        next_depth = cur_depth + 1\n        index = indices[next_depth]\n\n        if index == -1:\n            for i in range(array.shape[next_depth]):\n                cur_indices[next_depth] = i\n                append_flatten_elements(result, array, indices, next_depth, cur_indices)\n        else:\n            cur_indices[next_depth] = index\n            append_flatten_elements(result, array, indices, next_depth, cur_indices)\n\n\ndef get_dim_last_value(array, dim):\n    \"\"\"Get the value of the last element in a dimension\"\"\"\n    indices = tuple(0 if i != dim else array.shape[dim] - 1 for i in range(len(array.shape)))\n    return array[indices]\n\n\ndef transpose_flatten(array, shape, dimensions):\n    \"\"\"Transpose a flatten array\"\"\"\n    array = np.array(array)\n    return np.array(np.transpose(array.reshape(shape), dimensions)).flatten()\n\n\ndef reshape_flatten(array, shape, new_shape):\n    \"\"\"Reshape a flatten array\"\"\"\n    array = np.array(array)\n    return np.array(array.reshape(shape)).flatten()\n\n\ndef compute_bytes(shape):\n    return np.prod(shape) * 4\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/hlo.py",
    "content": "\"\"\"Definition of HLO Instructions\"\"\"\n\nfrom collections import defaultdict\nfrom enum import Enum, auto\n\nimport numpy as np\n\nfrom common import compute_bytes, append_flatten_elements, transpose_flatten, reshape_flatten\n\n\nclass ShardingSpecType(Enum):\n    REPLICATED = auto()\n    MAXIMAL = auto()\n    OTHER = auto()\n    TUPLE = auto()\n    PARTIAL_REDUCTION = auto()\n\n\nINF_COST = 1e10  # infinity cost\n\n\nclass ShardingSpec:\n    def __init__(self, type_, tile_assignment_dimensions, tile_assignment_devices,\n                 replicate_on_last_tile_dim, partial_reduce_replication):\n        self.type = type_\n        self.tile_assignment_dimensions = tuple(tile_assignment_dimensions)\n        self.tile_assignment_devices = tuple(tile_assignment_devices)\n        self.replicate_on_last_tile_dim = replicate_on_last_tile_dim\n        self.partial_reduce_replication = partial_reduce_replication\n\n    def num_tile_devices(self):\n        if self.type == ShardingSpecType.REPLICATED:\n            return 1\n\n        assert self.type == ShardingSpecType.OTHER\n        ret = np.prod(self.tile_assignment_dimensions)\n        if self.replicate_on_last_tile_dim:\n            ret /= self.tile_assignment_dimensions[-1]\n        return ret\n\n    def transpose(self, dimensions):\n        if self.type == ShardingSpecType.REPLICATED:\n            return self\n\n        assert self.type == ShardingSpecType.OTHER\n\n        spec_trans_dims = list(dimensions)\n        if self.replicate_on_last_tile_dim:\n            spec_trans_dims.append(len(dimensions))\n\n        tile_assignment_dimensions = [self.tile_assignment_dimensions[i]\n            for i in spec_trans_dims]\n        tile_assignment_devices = transpose_flatten(self.tile_assignment_devices,\n            self.tile_assignment_dimensions, spec_trans_dims)\n\n        ret = ShardingSpec(self.type,\n                           tile_assignment_dimensions,\n                           
tile_assignment_devices,\n                           self.replicate_on_last_tile_dim,\n                           self.partial_reduce_replication)\n        return ret\n\n    def broadcast(self, new_shape, dimensions):\n        if self.type == ShardingSpecType.REPLICATED:\n            return self\n\n        assert self.type == ShardingSpecType.OTHER\n\n        tile_assignment_dimensions = []\n        for i in range(len(new_shape)):\n            if i in dimensions:\n                tile_assignment_dimensions.append(\n                    self.tile_assignment_dimensions[dimensions.index(i)])\n            else:\n                tile_assignment_dimensions.append(1)\n\n        if self.replicate_on_last_tile_dim:\n            tile_assignment_dimensions.append(self.tile_assignment_dimensions[-1])\n\n        output_spec = ShardingSpec(self.type,\n                                   tile_assignment_dimensions,\n                                   self.tile_assignment_devices,\n                                   self.replicate_on_last_tile_dim,\n                                   self.partial_reduce_replication)\n        return output_spec\n\n    def reshape(self, old_shape, new_shape):\n        if self.type == ShardingSpecType.REPLICATED:\n            return self\n\n        assert self.type == ShardingSpecType.OTHER\n\n        # Construct a map that maps an old dimension to its corresponding new dimension\n        dim_mapping = {}\n        new_pt = -1\n        old_pt = -1\n        old_prod = 1\n        new_prod = 1\n        while True:\n            move_new = False\n            move_old = False\n\n            if new_prod == old_prod:\n                dim_mapping[old_pt + 1] = new_pt + 1\n                move_new = move_old = True\n            elif new_prod < old_prod:\n                move_new = True\n            else:\n                move_old = True\n\n            if move_new:\n                new_pt += 1\n                if new_pt < len(new_shape):\n                    
new_prod *= new_shape[new_pt]\n                else:\n                    break\n            if move_old:\n                old_pt += 1\n                if old_pt < len(old_shape):\n                    old_prod *= old_shape[old_pt]\n                else:\n                    break\n\n        tile_assignment_dimensions = []\n        cur_prod = 1\n        state = 1  # 0: start  1: middle\n        i = 0\n\n        failed = False\n        while i < len(old_shape) and not failed:\n            if state == 0:\n                assert i in dim_mapping\n                while len(tile_assignment_dimensions) < dim_mapping[i]:\n                    tile_assignment_dimensions.append(1)\n                tile_assignment_dimensions.append(\n                    self.tile_assignment_dimensions[i])\n                state = 1\n                i += 1\n            elif state == 1:\n                if i in dim_mapping:\n                    state = 0\n                else:\n                    if self.tile_assignment_dimensions[i] == 1:\n                        i += 1\n                    else:\n                        failed = True\n\n        if failed:\n            return None\n\n        while len(tile_assignment_dimensions) < len(new_shape):\n            tile_assignment_dimensions.append(1)\n\n        if self.replicate_on_last_tile_dim:\n            tile_assignment_dimensions.append(self.tile_assignment_dimensions[-1])\n        output_spec = ShardingSpec(self.type,\n                                   tile_assignment_dimensions,\n                                   self.tile_assignment_devices,\n                                   self.replicate_on_last_tile_dim,\n                                   self.partial_reduce_replication)\n        return output_spec\n\n    @staticmethod\n    def tile_internal(shape, tensor_dims, mesh_dims, cluster_env, partial_reduce_replication):\n        assert len(tensor_dims) == len(mesh_dims)\n\n        tile_assignment_dimensions = [1] * len(shape)\n\n        # 
Split on certain mesh dimensions\n        split_prod = 1\n        for tensor_dim, mesh_dim in zip(tensor_dims, mesh_dims):\n            tile_assignment_dimensions[tensor_dim] = cluster_env.device_mesh.shape[mesh_dim]\n            split_prod *= cluster_env.device_mesh.shape[mesh_dim]\n\n        if split_prod == 1:\n            return ShardingSpec.replicated(cluster_env)\n\n        # Replicate on reminding mesh dimensions\n        if split_prod < cluster_env.num_devices:\n            tile_assignment_dimensions.append(cluster_env.num_devices // split_prod)\n            replicate_on_last_tile_dim = True\n        else:\n            replicate_on_last_tile_dim = False\n\n        # Map device ids from device_mesh to tile_assignment_devices\n        tile_assignment_devices = []\n        tmp_indices = [None] * len(cluster_env.device_mesh.shape)\n        def generate_tile_assignment_devices(tensor_dim, mesh_indices):\n            if tensor_dim == len(shape) - 1:\n                append_flatten_elements(tile_assignment_devices, cluster_env.device_mesh,\n                                        mesh_indices, -1, tmp_indices)\n            else:\n                next_tensor_dim = tensor_dim + 1\n                next_mesh_dim = -1\n\n                if next_tensor_dim in tensor_dims:\n                    next_mesh_dim = mesh_dims[tensor_dims.index(next_tensor_dim)]\n\n                for i in range(tile_assignment_dimensions[next_tensor_dim]):\n                    if next_mesh_dim != -1:\n                        mesh_indices[next_mesh_dim] = i\n                    generate_tile_assignment_devices(next_tensor_dim, mesh_indices)\n\n        generate_tile_assignment_devices(-1, [-1] * len(cluster_env.device_mesh.shape))\n\n        return ShardingSpec(ShardingSpecType.OTHER,\n                            tile_assignment_dimensions, tile_assignment_devices,\n                            replicate_on_last_tile_dim,\n                            False)\n\n    @staticmethod\n    def 
tile(shape, tensor_dims, mesh_dims, cluster_env):\n        return ShardingSpec.tile_internal(shape, tensor_dims, mesh_dims, cluster_env, False)\n\n    @staticmethod\n    def tile_partial_reduce(shape, tensor_dims, mesh_dims, cluster_env):\n        return ShardingSpec.tile_internal(shape, tensor_dims, mesh_dims, cluster_env, True)\n\n    @staticmethod\n    def replicated(cluster_env):\n        tile_assignment_devices = range(cluster_env.num_devices)\n        return ShardingSpec(ShardingSpecType.REPLICATED, (), tile_assignment_devices,\n                            False, False)\n\n    @staticmethod\n    def split(shape, dim, cluster_env):\n        tile_assignment_dimensions = [1] * len(shape)\n        tile_assignment_dimensions[dim] = cluster_env.num_devices\n        tile_assignment_devices = range(cluster_env.num_devices)\n        return ShardingSpec(ShardingSpecType.OTHER,\n                            tile_assignment_dimensions, tile_assignment_devices,\n                            False, False)\n\n    @staticmethod\n    def tuple():\n        return ShardingSpec(ShardingSpecType.TUPLE, (), (), False, False)\n\n    def __str__(self):\n        return f\"{self.tile_assignment_dimensions}\"\\\n               f\"{list(self.tile_assignment_devices)}\"\n\n    def __eq__(self, other):\n        return (self.type == other.type and\n                self.tile_assignment_dimensions == other.tile_assignment_dimensions and\n                self.tile_assignment_devices == other.tile_assignment_devices and\n                self.replicate_on_last_tile_dim == other.replicate_on_last_tile_dim and\n                self.partial_reduce_replication == other.partial_reduce_replication)\n\n\ndef resharding_cost_vector(cluster_env, source_ins, required_spec):\n    cost_vector = []\n    for strategy in source_ins.strategies:\n        cost_vector.append(cluster_env.resharding_cost(source_ins.shape,\n            strategy.output_spec, required_spec))\n    return cost_vector\n\n\ndef 
follow_ins_cost_vector(source_ins, index):\n    ret = [INF_COST] * len(source_ins.strategies)\n    ret[index] = 0\n    return ret\n\n\nclass InstructionStrategy:\n    def __init__(self, name, output_spec):\n        self.name = name\n        self.output_spec = output_spec\n\n\nclass OpCode(Enum):\n    PARAMETER = auto()\n    CONSTANT = auto()\n    BROADCAST = auto()\n    RESHAPE = auto()\n    TRANSPOSE = auto()\n    IDENTITY = auto()\n    EXP = auto()\n    FORCE_REPLICATED = auto()\n    ADD = auto()\n    SUBTRACT = auto()\n    MULTIPLY = auto()\n    DIV = auto()\n    COMPARE = auto()\n    SELECT = auto()\n    REDUCE = auto()\n    DOT = auto()\n    TUPLE = auto()\n\nop_code_ct = defaultdict(int)\n\n\n\nclass HloInstruction:\n    def __init__(self, op_code, shape, operands=[]):\n        # Attributes\n        self.op_code = op_code\n        self.shape = shape\n        self.operands = operands\n        self.name = f\"{str(op_code)[7:].lower()}.{op_code_ct[op_code]}\"\n        op_code_ct[op_code] += 1\n\n        # Cost\n        self.strategies = []\n        self.compute_costs = []\n        self.communication_costs = []\n        self.memory_costs = []\n        self.resharding_costs = []\n        self.follow_ins = None\n        self.depth = None\n\n        # The index in HloComputation\n        self.index = HloComputation.cur_env.append(self)\n        self.batch_dim = None\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        raise NotImplementedError(f\"{self.op_code}\")\n\n    def propagate_batch_dim(self, operand):\n        raise NotImplementedError(f\"{self.op_code}\")\n\n\nclass HloParameter(HloInstruction):\n    def __init__(self, shape, fix_strategy=None):\n        super().__init__(OpCode.PARAMETER, shape, [])\n        self.fix_strategy = fix_strategy\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        for i in range(len(self.shape)):\n            for j in range(len(cluster_env.device_mesh.shape)):\n                
if (cluster_env.device_mesh.shape[j] == 1 or\n                    self.shape[i] < cluster_env.device_mesh.shape[j]):\n                    continue\n\n                name = f\"S{i} @ {j}\"\n                output_spec = ShardingSpec.tile(self.shape, [i], [j], cluster_env)\n                self.strategies.append(InstructionStrategy(name, output_spec))\n                self.compute_costs.append(0)\n                self.communication_costs.append(0)\n                self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n\n        self.strategies.append(InstructionStrategy(\"R\", ShardingSpec.replicated(cluster_env)))\n        self.compute_costs.append(2)\n        self.communication_costs.append(0)\n        self.memory_costs.append(compute_bytes(self.shape))\n\n        if self.fix_strategy:\n            new_strategies = []\n            new_compute_costs = []\n            new_communication_costs = []\n            new_memory_costs = []\n\n            # filter strategies\n            for i in range(len(self.strategies)):\n                if self.strategies[i].name == self.fix_strategy:\n                    new_strategies.append(self.strategies[i])\n                    new_compute_costs.append(self.compute_costs[i])\n                    new_communication_costs.append(self.communication_costs[i])\n                    new_memory_costs.append(self.memory_costs[i])\n\n            self.strategies = new_strategies\n            self.compute_costs = new_compute_costs\n            self.communication_costs = new_communication_costs\n            self.memory_costs = new_memory_costs\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = parameter()\"\n\n\nclass HloConstant(HloInstruction):\n    def __init__(self, value):\n        super().__init__(OpCode.CONSTANT, (), [])\n        self.value = value\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        self.strategies.append(InstructionStrategy(\"R\", 
ShardingSpec.replicated(cluster_env)))\n        self.compute_costs.append(0)\n        self.communication_costs.append(0)\n        self.memory_costs.append(compute_bytes(self.shape))\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = constant({self.value})\"\n\n\nclass HloBroadcast(HloInstruction):\n    def __init__(self, operand, shape, dimensions=()):\n        for i in dimensions:\n            assert shape[i] == operand.shape[dimensions.index(i)]\n        super().__init__(OpCode.BROADCAST, shape, [operand])\n        self.dimensions = dimensions\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        follow = self.operands[0]\n        self.follow_ins = follow\n\n        for sid in range(len(follow.strategies)):\n            output_spec = follow.strategies[sid].output_spec.broadcast(\n                    self.shape, self.dimensions)\n            name = f\"{output_spec.tile_assignment_dimensions}\"\n            self.strategies.append(InstructionStrategy(name, output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(0)\n            self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n            self.resharding_costs.append([follow_ins_cost_vector(follow, sid)])\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = broadcast({self.operands[0].name})\"\n\n\nclass HloReshape(HloInstruction):\n    def __init__(self, operand, new_shape):\n        # todo: mark this as inplace\n        assert np.prod(operand.shape) == np.prod(new_shape)\n        super().__init__(OpCode.RESHAPE, new_shape, [operand])\n        self.new_shape = new_shape\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        follow = self.operands[0]\n        self.follow_ins = follow\n        old_shape = self.operands[0].shape\n        new_shape = self.new_shape\n\n        for sid in range(len(follow.strategies)):\n            output_spec = 
follow.strategies[sid].output_spec.reshape(\n                    follow.shape, self.shape)\n            if output_spec is None:\n                continue\n\n            name = f\"{output_spec.tile_assignment_dimensions}\"\n            self.strategies.append(InstructionStrategy(name, output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(0)\n            self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n            self.resharding_costs.append([follow_ins_cost_vector(follow, sid)])\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = reshape({self.operands[0].name})\"\n\n\nclass HloTranspose(HloInstruction):\n    def __init__(self, operand, dimensions):\n        assert len(dimensions) == len(operand.shape)\n        new_shape = tuple(operand.shape[i] for i in dimensions)\n        super().__init__(OpCode.TRANSPOSE, new_shape, [operand])\n        self.dimensions = dimensions\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        follow = self.operands[0]\n        self.follow_ins = follow\n\n        for sid in range(len(follow.strategies)):\n            output_spec = follow.strategies[sid].output_spec.transpose(self.dimensions)\n            name = f\"{output_spec.tile_assignment_dimensions}\"\n            self.strategies.append(InstructionStrategy(name, output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(0)\n            self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n            self.resharding_costs.append([follow_ins_cost_vector(follow, sid)])\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = transpose({self.operands[0].name}) \" +\\\n               f\"dimensions={self.dimensions}\"\n\n\nclass HloElementwise(HloInstruction):\n    def __init__(self, op_code, operands):\n        for i in range(0, len(operands)):\n            assert 
operands[0].shape == operands[i].shape\n        super().__init__(op_code, operands[0].shape, operands)\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        depths = [operand.depth for operand in self.operands]\n        follow_idx = np.argmax(depths)\n\n        follow = self.operands[follow_idx]\n        self.follow_ins = follow\n\n        for sid in range(len(follow.strategies)):\n            output_spec = follow.strategies[sid].output_spec\n\n            name = f\"{output_spec.tile_assignment_dimensions}\"\n            self.strategies.append(InstructionStrategy(name, output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(0)\n            self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n\n            resharding_costs = []\n            for k in range(len(self.operands)):\n                if k == follow_idx:\n                    resharding_costs.append(\n                        follow_ins_cost_vector(follow, sid))\n                else:\n                    resharding_costs.append(\n                    resharding_cost_vector(cluster_env, self.operands[k], output_spec))\n            self.resharding_costs.append(resharding_costs)\n\n    def propagate_batch_dim(self, ins):\n        self.batch_dim = ins.batch_dim\n        return True\n\n    def __str__(self):\n        fun_name = str(self.op_code)[7:].lower()\n        args = \", \".join(f\"{self.operands[i].name}\" for i in range(len(self.operands)))\n        return f\"{self.name} {self.shape} = {fun_name}({args})\"\n\n\nclass HloIdentity(HloElementwise):\n    def __init__(self, operand):\n        super().__init__(OpCode.IDENTITY, [operand])\n\n\nclass HloExp(HloElementwise):\n    def __init__(self, operand):\n        super().__init__(OpCode.EXP, [operand])\n\n\nclass HloForceReplicated(HloElementwise):\n    def __init__(self, operand):\n        super().__init__(OpCode.FORCE_REPLICATED, [operand])\n\n    def 
build_strategy_and_cost(self, cluster_env, solver_option):\n        self.strategies.append(InstructionStrategy(\"R\",\n            ShardingSpec.replicated(cluster_env)))\n        self.compute_costs.append(0)\n        self.communication_costs.append(0)\n        self.memory_costs.append(0)\n        self.resharding_costs.append([\n            resharding_cost_vector(cluster_env, self.operands[0],\n                ShardingSpec.replicated(cluster_env))\n        ])\n\n\nclass HloAdd(HloElementwise):\n    def __init__(self, lhs, rhs):\n        super().__init__(OpCode.ADD, [lhs, rhs])\n\n\nclass HloSubtract(HloElementwise):\n    def __init__(self, lhs, rhs):\n        super().__init__(OpCode.SUBTRACT, [lhs, rhs])\n\n\nclass HloMutiply(HloElementwise):\n    def __init__(self, lhs, rhs):\n        super().__init__(OpCode.MULTIPLY, [lhs, rhs])\n\n\nclass HloDiv(HloElementwise):\n    def __init__(self, lhs, rhs):\n        super().__init__(OpCode.DIV, [lhs, rhs])\n\n\nclass HloCompare(HloElementwise):\n    def __init__(self, lhs, rhs):\n        super().__init__(OpCode.COMPARE, [lhs, rhs])\n\n\nclass HloSelect(HloElementwise):\n    def __init__(self, pred, true_value, false_value):\n        super().__init__(OpCode.SELECT, [pred, true_value, false_value])\n\n\nclass HloReduce(HloInstruction):\n    def __init__(self, operand, dimensions):\n        new_shape = tuple(operand.shape[i] for i in range(len(operand.shape)) if i not in dimensions)\n        super().__init__(OpCode.REDUCE, new_shape, [operand])\n        self.dimensions = dimensions\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        operand = self.operands[0]\n        self.follow_ins = operand\n\n        # Map old dims to new dim\n        old_dim_to_new_dim = []\n        pt = 0\n        for old_dim in range(len(operand.shape)):\n            if old_dim in self.dimensions:\n                old_dim_to_new_dim.append(-1)\n            else:\n                old_dim_to_new_dim.append(pt)\n                
pt += 1\n        assert pt == len(self.shape)\n\n        # Create follow strategies\n        for sid in range(len(operand.strategies)):\n            tensor_dim_to_mesh = cluster_env.get_tensor_dim_to_mesh_dim(\n                operand.shape, operand.strategies[sid].output_spec)\n\n            tile_tensor_dims = []\n            tile_mesh_dims = []\n            all_reduce_dims = []\n\n            for tensor_dim in range(len(operand.shape)):\n                mesh_dim = tensor_dim_to_mesh[tensor_dim]\n                if tensor_dim in self.dimensions:\n                    if mesh_dim == -1:  # reduce on a replicated dim\n                        continue\n                    else:               # reduce on a split dim\n                        all_reduce_dims.append(mesh_dim)\n                else:\n                    if mesh_dim == -1: # follow replicated dim\n                        pass\n                    else:              # follow split dim\n                        tile_tensor_dims.append(old_dim_to_new_dim[tensor_dim])\n                        tile_mesh_dims.append(mesh_dim)\n\n            output_spec = ShardingSpec.tile(self.shape, tile_tensor_dims, tile_mesh_dims, cluster_env)\n\n            mem_cost = compute_bytes(self.shape) / output_spec.num_tile_devices()\n            comm_cost = 0\n            for mesh_dim in all_reduce_dims:\n                comm_cost += cluster_env.all_reduce_cost(mem_cost, mesh_dim)\n\n            reduce_dims_str = \"\".join([str(x) for x in all_reduce_dims])\n            if reduce_dims_str:\n                name = f\"follow (allreduce @ {reduce_dims_str})\"\n            else:\n                name = f\"{output_spec.tile_assignment_dimensions}\"\n\n            self.strategies.append(InstructionStrategy(name, output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(comm_cost)\n            self.memory_costs.append(mem_cost)\n            
self.resharding_costs.append([follow_ins_cost_vector(operand, sid)])\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = reduce({self.operands[0].name}) \" +\\\n               f\"dimensions={self.dimensions}\"\n\n\nclass HloDot(HloInstruction):\n    def __init__(self, lhs, rhs,\n                 lhs_batch_dims=(), lhs_contracting_dims=(1,),\n                 rhs_batch_dims=(), rhs_contracting_dims=(0,)):\n        # shape inference\n        lhs_space_shape = \\\n            tuple(lhs.shape[i] for i in range(len(lhs.shape))\n                  if i not in lhs_contracting_dims and i not in lhs_batch_dims)\n        rhs_space_shape = \\\n            tuple(rhs.shape[i] for i in range(len(rhs.shape))\n                  if i not in rhs_contracting_dims and i not in rhs_batch_dims)\n        lhs_batch_shape = tuple(lhs.shape[i] for i in lhs_batch_dims)\n\n        shape = lhs_batch_shape + lhs_space_shape + rhs_space_shape\n\n        for i, j in zip(lhs_contracting_dims, rhs_contracting_dims):\n            assert lhs.shape[i] == rhs.shape[j]\n        for i, j in zip(lhs_batch_dims, rhs_batch_dims):\n            assert lhs.shape[i] == rhs.shape[j]\n\n        super().__init__(OpCode.DOT, shape, [lhs, rhs])\n        self.lhs = lhs\n        self.lhs_batch_dims = lhs_batch_dims\n        self.lhs_contracting_dims = lhs_contracting_dims\n        self.lhs_space_dims = tuple(set(range(len(lhs.shape))) - set(self.lhs_batch_dims) - set(self.lhs_contracting_dims))\n        assert len(self.lhs_contracting_dims) == 1\n        assert len(self.lhs_space_dims) == 1\n        self.rhs = rhs\n        self.rhs_batch_dims = rhs_batch_dims\n        self.rhs_contracting_dims = rhs_contracting_dims\n        self.rhs_space_dims = tuple(set(range(len(rhs.shape))) - set(self.rhs_batch_dims) - set(self.rhs_contracting_dims))\n        assert len(self.rhs_contracting_dims) == 1\n        assert len(self.rhs_space_dims) == 1\n\n    def build_strategy_and_cost(self, cluster_env, 
solver_option):\n        lhs = self.lhs\n        lhs_batch_dims = self.lhs_batch_dims\n        lhs_space_dim = self.lhs_space_dims[0]\n        lhs_con_dim = self.lhs_contracting_dims[0]\n\n        rhs = self.rhs\n        rhs_batch_dims = self.rhs_batch_dims\n        rhs_space_dim = self.rhs_space_dims[0]\n        rhs_con_dim = self.rhs_contracting_dims[0]\n\n        space_base_dim = len(self.lhs_batch_dims)\n\n        assert len(cluster_env.device_mesh.shape) == 2\n\n        # Split lhs space dim + rhs space dim\n        # @ {0, 1}\n        output_spec =\\\n            ShardingSpec.tile(self.shape, [space_base_dim, space_base_dim + 1], [0, 1], cluster_env)\n        self.strategies.append(InstructionStrategy(\"SS = SR x RS @ {0,1}\", output_spec))\n        self.compute_costs.append(0)\n        self.communication_costs.append(0)\n        self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n        self.resharding_costs.append([\n            resharding_cost_vector(cluster_env, lhs,\n                ShardingSpec.tile(lhs.shape, [lhs_space_dim], [0], cluster_env)),\n            resharding_cost_vector(cluster_env, rhs,\n                ShardingSpec.tile(rhs.shape, [rhs_space_dim], [1], cluster_env))\n        ])\n\n        # @ {1, 0}\n        output_spec =\\\n            ShardingSpec.tile(self.shape, [space_base_dim, space_base_dim + 1], [1, 0], cluster_env)\n        self.strategies.append(InstructionStrategy(\"SS = SR x RS @ {1,0}\", output_spec))\n        self.compute_costs.append(0)\n        self.communication_costs.append(0)\n        self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n        self.resharding_costs.append([\n            resharding_cost_vector(cluster_env, lhs,\n                ShardingSpec.tile(lhs.shape, [lhs_space_dim], [1], cluster_env)),\n            resharding_cost_vector(cluster_env, rhs,\n                ShardingSpec.tile(rhs.shape, [rhs_space_dim], [0], cluster_env))\n        
])\n\n        # Split lhs space dim + contracting dim\n        # @ {0, 1}\n        if cluster_env.device_mesh.shape[1] > 1:\n            output_spec = ShardingSpec.tile(self.shape, [space_base_dim], [0], cluster_env)\n            memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices()\n            self.strategies.append(\n                InstructionStrategy(\"SR = SS x SR @ {0,1} (allreduce @ 1)\", output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 1))\n            self.memory_costs.append(memory_cost)\n            self.resharding_costs.append([\n                resharding_cost_vector(cluster_env, lhs,\n                    ShardingSpec.tile(lhs.shape, [lhs_space_dim, lhs_con_dim], [0, 1], cluster_env)),\n                resharding_cost_vector(cluster_env, rhs,\n                    ShardingSpec.tile(rhs.shape, [rhs_con_dim], [1], cluster_env))\n            ])\n\n        # @ {1, 0}\n        if cluster_env.device_mesh.shape[0] > 1:\n            output_spec = ShardingSpec.tile(self.shape, [space_base_dim], [1], cluster_env)\n            memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices()\n            self.strategies.append(\n                InstructionStrategy(\"SR = SS x SR @ {1,0} (allreduce @ 0)\", output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 0))\n            self.memory_costs.append(memory_cost)\n            self.resharding_costs.append([\n                resharding_cost_vector(cluster_env, lhs,\n                    ShardingSpec.tile(lhs.shape, [lhs_space_dim, lhs_con_dim], [1, 0], cluster_env)),\n                resharding_cost_vector(cluster_env, rhs,\n                    ShardingSpec.tile(rhs.shape, [rhs_con_dim], [0], cluster_env))\n            ])\n\n        # Split rhs space dim + contracting dim\n        # @ {0, 1}\n        if 
cluster_env.device_mesh.shape[0] > 1 and cluster_env.device_mesh.shape[1] > 1:\n            output_spec = ShardingSpec.tile(self.shape, [space_base_dim+1], [1], cluster_env)\n            memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices()\n            self.strategies.append(\n                InstructionStrategy(\"RS = RS x SS @ {0,1} (allreduce @ 0)\", output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 0))\n            self.memory_costs.append(memory_cost)\n            self.resharding_costs.append([\n                resharding_cost_vector(cluster_env, lhs,\n                    ShardingSpec.tile(lhs.shape, [lhs_con_dim], [0], cluster_env)),\n                resharding_cost_vector(cluster_env, rhs,\n                    ShardingSpec.tile(rhs.shape, [rhs_con_dim, rhs_space_dim], [0, 1], cluster_env))\n            ])\n\n        # @ {1, 0}\n        if cluster_env.device_mesh.shape[0] > 1 and cluster_env.device_mesh.shape[1] > 1:\n            output_spec = ShardingSpec.tile(self.shape, [space_base_dim+1], [0], cluster_env)\n            memory_cost = compute_bytes(self.shape) / output_spec.num_tile_devices()\n            self.strategies.append(\n                InstructionStrategy(\"RS = RS x SS @ {1,0} (allreduce @ 1)\", output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(cluster_env.all_reduce_cost(memory_cost, 1))\n            self.memory_costs.append(memory_cost)\n            self.resharding_costs.append([\n                resharding_cost_vector(cluster_env, lhs,\n                    ShardingSpec.tile(lhs.shape, [lhs_con_dim], [1], cluster_env)),\n                resharding_cost_vector(cluster_env, rhs,\n                    ShardingSpec.tile(rhs.shape, [rhs_con_dim, rhs_space_dim], [1, 0], cluster_env))\n            ])\n\n        # Split one batch dim\n        for i in range(len(self.lhs_batch_dims)):\n         
   for j in range(len(cluster_env.device_mesh.shape)):\n                if (cluster_env.device_mesh.shape[j] == 1 or\n                    self.shape[i] < cluster_env.device_mesh.shape[j]):\n                    continue\n\n                output_spec = ShardingSpec.tile(self.shape, [i], [j], cluster_env)\n                self.strategies.append(InstructionStrategy(f\"Sb_{i} = Sb x Sb @ {j}\", output_spec))\n                self.compute_costs.append(0)\n                self.communication_costs.append(0)\n                self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n                self.resharding_costs.append([\n                    resharding_cost_vector(cluster_env, lhs,\n                        ShardingSpec.tile(lhs.shape, [lhs_batch_dims[i]], [j], cluster_env)),\n                    resharding_cost_vector(cluster_env, rhs,\n                        ShardingSpec.tile(rhs.shape, [rhs_batch_dims[i]], [j], cluster_env))\n                ])\n\n        # Split two batch dims\n        if len(self.lhs_batch_dims) == 2 and cluster_env.device_mesh.shape[0] > 1\\\n                and cluster_env.device_mesh.shape[1] > 1:\n\n            self.strategies = []\n            self.compute_costs = []\n            self.communication_costs = []\n            self.memory_costs = []\n            self.resharding_costs = []\n\n            # Split two batch dims\n            output_spec = ShardingSpec.tile(self.shape, [0, 1], [0, 1], cluster_env)\n            self.strategies.append(InstructionStrategy(\"Sb = Sb x Sb @ {0,1}\", output_spec))\n            self.compute_costs.append(0)\n            self.communication_costs.append(0)\n            self.memory_costs.append(compute_bytes(self.shape) / output_spec.num_tile_devices())\n            self.resharding_costs.append([\n                resharding_cost_vector(cluster_env, lhs,\n                    ShardingSpec.tile(lhs.shape, [lhs_batch_dims[0], lhs_batch_dims[1]], [0, 1], cluster_env)),\n                
resharding_cost_vector(cluster_env, rhs,\n                    ShardingSpec.tile(rhs.shape, [rhs_batch_dims[0], rhs_batch_dims[1]], [0, 1], cluster_env))\n            ])\n\n        # If force batch dim to a mesh dim, filter out invalid strategies\n        if solver_option.force_batch_dim_to_mesh_dim is not None and self.batch_dim is not None:\n            filter_indices = []\n            for i in range(len(self.strategies)):\n                tensor_dim_to_mesh_dim = cluster_env.get_tensor_dim_to_mesh_dim(\n                    self.shape, self.strategies[i].output_spec)\n                if tensor_dim_to_mesh_dim[self.batch_dim] == solver_option.force_batch_dim_to_mesh_dim:\n                    filter_indices.append(i)\n\n            self.strategies = [self.strategies[i] for i in filter_indices]\n            self.compute_costs = [self.compute_costs[i] for i in filter_indices]\n            self.communication_costs = [self.communication_costs[i] for i in filter_indices]\n            self.memory_costs = [self.memory_costs[i] for i in filter_indices]\n            self.resharding_costs = [self.resharding_costs[i] for i in filter_indices]\n\n    def propagate_batch_dim(self, operand):\n        index = self.operands.index(operand)\n\n        if index == 0:\n            for i in range(len(self.lhs_batch_dims)):\n                if operand.batch_dim == self.lhs_batch_dims[i]:\n                    self.batch_dim = i\n                    return True\n            if operand.batch_dim == self.lhs_space_dims[0]:\n                self.batch_dim = len(self.lhs_batch_dims)\n                return True\n            if operand.batch_dim in self.lhs_contracting_dims:\n                return False\n        else:\n            for i in range(len(self.rhs_batch_dims)):\n                if operand.batch_dim == self.rhs_batch_dims[i]:\n                    self.batch_dim = i\n                    return True\n            if operand.batch_dim == self.rhs_space_dims[0]:\n                
self.batch_dim = len(self.rhs_batch_dims)\n                return True\n            if operand.batch_dim in self.rhs_contracting_dims:\n                return False\n\n    def __str__(self):\n        return f\"{self.name} {self.shape} = dot({self.lhs.name}, {self.rhs.name}) \"\\\n               f\" lhs_con_dim={self.lhs_contracting_dims},\"\\\n               f\" rhs_con_dim={self.rhs_contracting_dims}\"\n\n\nclass HloTuple(HloInstruction):\n    def __init__(self, operands):\n        super().__init__(OpCode.TUPLE, (), operands)\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        self.strategies.append(InstructionStrategy(\"tuple\", ShardingSpec.tuple()))\n        self.memory_costs.append(0)\n        self.compute_costs.append(0)\n        self.communication_costs.append(0)\n        self.resharding_costs.append([np.zeros(len(operand.strategies))\n            for operand in self.operands])\n\n    def __str__(self):\n        names = tuple(x.name for x in self.operands)\n        return f\"{self.name} {self.shape} = tuple{names}\"\n\n\nclass HloComputation:\n    cur_env = None\n\n    def __init__(self):\n        self.ct = 0\n        self.instructions = []\n        self.alias_list = []\n        self.alias_cost_vector = []\n\n        self.parameters = []\n\n        self.strategy_built = False\n\n    def append(self, instruction):\n        ct = len(self.instructions)\n        self.instructions.append(instruction)\n\n        if instruction.op_code == OpCode.PARAMETER:\n            self.parameters.append(instruction)\n\n        return ct\n\n    def liveness_analysis(self):\n        liveness_dict = dict()\n\n        live_set = set()\n\n        for t in range(len(self.instructions)-1, -1, -1):\n            inst = self.instructions[t]\n\n            live_set.add(inst)\n            for operand in inst.operands:\n                live_set.add(operand)\n\n            liveness_dict[t] = set(live_set)\n\n            live_set.remove(inst)\n\n        return 
liveness_dict\n\n    def set_alias(self, alias_list):\n        self.alias_list = alias_list\n\n    def concurrency_analysis(self):\n        frontier_list = []\n        edge_dict = defaultdict(list)\n\n        # Build degree dict\n        #out_degree = defaultdict(lambda : 0)\n        #for ins in self.instructions:\n        #    for operand in ins.operands:\n        #        out_degree[operand] += 1\n\n        degree = defaultdict(lambda : 0)\n        for ins in self.instructions:\n            for operand in ins.operands:\n                degree[ins] += 1\n                edge_dict[operand].append(ins)\n\n        # Init frontier\n        collected = 0\n        current_frontier = []\n        for ins in self.instructions:\n            if degree[ins] == 0:\n                current_frontier.append(ins)\n                collected += 1\n        frontier_list.append(current_frontier)\n\n        # Push forward frontier\n        while collected < len(self.instructions):\n            current_frontier = frontier_list[-1]\n            next_frontier = []\n            for ins in current_frontier:\n                for node in edge_dict[ins]:\n                    degree[node] -= 1\n                    if degree[node] == 0:\n                        next_frontier.append(node)\n                        collected += 1\n            frontier_list.append(next_frontier)\n\n        for i, frontier in enumerate(frontier_list):\n            print(i)\n            for ins in frontier:\n                print(ins)\n\n    def forward_backward_analysis(self):\n        used_by = defaultdict(list)\n        for ins in self.instructions:\n            for operand in ins.operands:\n                used_by[operand].append(ins.index)\n\n        sep_id = 0\n        for param in self.parameters:\n            if len(used_by[param]) > 2:\n                backward_id = used_by[param][0]\n                sep_id = max(sep_id, backward_id + 1)\n\n        return sep_id\n\n    def batch_dim_analysis(self):\n        # 
Build used by dict\n        used_by = defaultdict(list)\n        for ins in self.instructions:\n            for operand in ins.operands:\n                used_by[operand].append(ins)\n\n        # Find source.\n        # Rule: The first dim of parameters that are only used once\n        #possible_inputs = []\n        #for param in self.parameters:\n        #    if len(used_by[param]) == 1:\n        #        possible_inputs.append(param)\n        #source = possible_inputs[0]\n        source = self.instructions[0]\n        source.batch_dim = 0\n\n        # Dim propagation\n        queue = [source]\n        visited = set([source])\n\n        while len(queue) > 0:\n            ins = queue.pop(0)\n\n            # Propagate to operand\n\n            # Propagate to used_by\n            for consumer in used_by[ins]:\n                #print(f\"Propagate from {ins} to {consumer}\")\n                success = consumer.propagate_batch_dim(ins)\n                if not success:\n                    continue\n                if consumer.index not in visited:\n                    visited.add(consumer)\n                    queue.append(consumer)\n\n    def depth_analysis(self):\n        edge_dict = defaultdict(list)\n\n        degree = defaultdict(lambda : 0)\n        for ins in self.instructions:\n            for operand in ins.operands:\n                degree[ins] += 1\n                edge_dict[operand].append(ins)\n\n        # Init frontier\n        collected = 0\n        current_frontier = []\n        for ins in self.instructions:\n            if degree[ins] == 0:\n                ins.depth = 0\n                current_frontier.append(ins)\n                collected += 1\n\n        # Push forward frontier\n        depth = 0\n        while collected < len(self.instructions):\n            next_frontier = []\n            for ins in current_frontier:\n                for node in edge_dict[ins]:\n                    degree[node] -= 1\n                    if degree[node] == 0:\n     
                   next_frontier.append(node)\n                        collected += 1\n\n            depth += 1\n            current_frontier = next_frontier\n            for ins in current_frontier:\n                ins.depth = depth\n\n    def build_strategy_and_cost(self, cluster_env, solver_option):\n        if self.strategy_built:\n            for ins in self.instructions:\n                ins.strategies = []\n                ins.compute_costs = []\n                ins.communication_costs = []\n                ins.memory_costs = []\n                ins.resharding_costs = []\n                ins.follow_ins = None\n\n            self.alias_cost_vector = []\n\n        # Analyze depth for all instructions\n        self.depth_analysis()\n\n        # Analyze batch dim\n        if solver_option.force_batch_dim_to_mesh_dim is not None:\n            batch_dim = self.batch_dim_analysis()\n            print(\"===== Batch Dim Analysis =====\")\n            for i in range(len(self.instructions)):\n                print(f\"Time {i:2d}: {self.instructions[i]}  Batch: {self.instructions[i].batch_dim}\")\n\n        # Build strategies and costs for each instruction\n        for ins in self.instructions:\n            ins.build_strategy_and_cost(cluster_env, solver_option)\n\n        # Build alias costs\n        for (ins_a, ins_b) in self.alias_list:\n            assert ins_a.shape == ins_b.shape\n            cost_vector = []\n            for stra_a in ins_a.strategies:\n                for stra_b in ins_b.strategies:\n                    if stra_a.output_spec == stra_b.output_spec:\n                        cost_vector.append(0)\n                    else:\n                        cost_vector.append(1)\n            self.alias_cost_vector.append(cost_vector)\n\n        self.strategy_built = True\n\n    def __enter__(self):\n        assert HloComputation.cur_env is None\n        HloComputation.cur_env = self\n\n    def __exit__(self, *args, **kwargs):\n        HloComputation.cur_env 
= None\n\n    def __str__(self):\n        strs = []\n        for i, ins in enumerate(self.instructions):\n            strs.append(f\"{i:2d}: \" + str(ins))\n        return \"\\n\".join(strs)\n \n"
  },
  {
    "path": "playground/auto_sharding_solver/run_all.sh",
    "content": "#!/bin/bash\n\npython3 -m unittest -bv *.py\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/solver.py",
    "content": "\"\"\"ILP Solver\"\"\"\nimport numpy as np\n\nfrom alpa.shard_parallel.auto_sharding import _call_solver_serialized_args\n\n\ndef call_solver(N, M, s_len, s_follow, E, A, L, c, d, m, r, v, s_init):\n    \"\"\"Serialize python lists to flatten numpy arraies and call solver\"\"\"\n    # Serialize strategy lengths\n    s_len_np = np.array(s_len, dtype=np.int32)\n    s_follow_np = np.array(s_follow, dtype=np.int32)\n\n    # Serialize edge set\n    len_edges = len(E)\n    E_np = np.empty((len_edges, 2), dtype=np.int32)\n    for (idx, (i, j)) in enumerate(E):\n        E_np[idx][:] = [i, j]\n\n    # Serialize alias set\n    len_aliases = len(A)\n    A_np = np.empty((len_aliases, 2), dtype=np.int32)\n    for (idx, (i, j)) in enumerate(A):\n        A_np[idx][:] = [i, j]\n\n    # Serialize liveness set\n    len_liveness_set = N + sum(len(v) for v in L)\n    L_np = np.empty((len_liveness_set,), dtype=np.int32)\n    L_np[0:N] = [len(v) for v in L]\n    L_np[N:] = [x for v in L for x in v]\n\n    # Serialize node costs\n    len_node_costs = sum(len(v) for v in c)\n    c_np = np.empty((len_node_costs,), dtype=np.float32)\n    d_np = np.empty((len_node_costs,), dtype=np.float32)\n    m_np = np.empty((len_node_costs,), dtype=np.float32)\n    c_np[:] = [x for v in c for x in v]\n    d_np[:] = [x for v in d for x in v]\n    m_np[:] = [x for v in m for x in v]\n\n    # Serialize edge costs\n    len_edge_costs = sum(len(vec) for vec in r)\n    r_np = np.empty((len_edge_costs,), dtype=np.float32)\n    r_np[:] = [x for vec in r for x in vec]\n\n    # Serialize alias costs\n    len_alias_costs = sum(len(vec) for vec in v)\n    v_np = np.empty((len_alias_costs,), dtype=np.float32)\n    v_np[:] = [x for vec in v for x in vec]\n\n    # Serialize init value\n    s_init_np = None\n\n    return _call_solver_serialized_args(\n        N, M, s_len_np, s_follow_np, E_np, A_np, L_np,\n        c_np, d_np, m_np, r_np, v_np, s_init_np)\n\n\nclass CostGraph:\n    def __init__(self, 
node_lens, edges, edge_costs, to_merge_pair):\n        self.node_lens = node_lens\n        self.adjacency = dict()   # map a node to its neighbors\n        self.edge_costs = dict()  # map an edge to its cost matrix\n        self.reindexing_vector = dict()  # map a node to its reindexing vector\n        self.merged_to = dict()   # map an merged node to its destination\n        self.to_merge_pair = to_merge_pair  # the input follow pairs\n\n        for i in range(len(node_lens)):\n            self.adjacency[i] = set()\n\n        # For redundant edges, we will overwrite the results with\n        # the last value\n        for ((i, j), cost) in zip(edges, edge_costs):\n            cost = np.reshape(cost, (self.node_lens[i], self.node_lens[j]))\n\n            self.add_edge_cost(i, j, cost)\n\n    def get_edge_cost(self, i, j):\n        if i <= j:\n            return self.edge_costs[(i, j)]\n        else:\n            return self.edge_costs[(j, i)].transpose()\n\n    def add_edge_cost(self, i, j, cost):\n        if i > j:\n            i, j = j, i\n            cost = cost.transpose()\n\n        if (i, j) in self.edge_costs:\n            assert i in self.adjacency[j]\n            assert j in self.adjacency[i]\n            self.edge_costs[(i, j)] += cost\n        else:\n            self.adjacency[i].add(j)\n            self.adjacency[j].add(i)\n            self.edge_costs[(i, j)] = cost\n\n    def remove_edge(self, i, j):\n        if i > j:\n            i, j = j, i\n\n        assert j in self.adjacency[i]\n        assert i in self.adjacency[j]\n        assert (i, j) in self.edge_costs\n\n        self.adjacency[i].remove(j)\n        self.adjacency[j].remove(i)\n        del self.edge_costs[(i, j)]\n\n    def merge_node(self, src, dst):\n        \"\"\"Merge node src to node dst\"\"\"\n        print(f\"merge {src} to {dst}\")\n        assert dst in self.adjacency[src]\n        assert src in self.adjacency[dst]\n        assert dst not in self.merged_to\n        assert src != 
dst\n\n        edge_cost = self.get_edge_cost(dst, src)\n\n        # Find the strategy to follow greedily\n        reindexing = []\n        candidates = list(range(self.node_lens[src]))\n        for i in range(self.node_lens[dst]):\n            # Pick the strategy with the lowest cost to follow.\n            # If there are multiple strategies with the same lowest costs,\n            # prefer to follow \"replicated\", which has the largest index.\n            keys = [(edge_cost[i][j], -j) for j in range(self.node_lens[src])]\n            candidates.sort(key=lambda j: keys[j])\n            reindexing.append(candidates[0])\n\n        self.merged_to[src] = dst\n        self.reindexing_vector[src] = reindexing\n\n        # Merge edge cost matrix\n        adj_list = list(self.adjacency[src])\n        for adj in adj_list:\n            if adj == dst:\n                continue\n            added_edge_cost = np.empty((self.node_lens[dst], self.node_lens[adj]))\n            for i in range(self.node_lens[dst]):\n                j = reindexing[i]\n                edge_cost_src_adj = self.get_edge_cost(src, adj)\n                for k in range(self.node_lens[adj]):\n                    added_edge_cost[i][k] = edge_cost_src_adj[j][k] + edge_cost[i][j]\n\n            self.add_edge_cost(dst, adj, added_edge_cost)\n\n        # Remove edges\n        for adj in adj_list:\n            self.remove_edge(src, adj)\n\n    def query_destination(self, node):\n        if node in self.merged_to:\n            old_dst = self.merged_to[node]\n            new_dst = self.query_destination(old_dst)\n            if old_dst != new_dst:\n                # Compress path\n                old_reindexing_vector = self.reindexing_vector[node]\n                new_reindexing_vector = []\n                for i in range(self.node_lens[new_dst]):\n                    new_reindexing_vector.append(\n                        old_reindexing_vector[self.reindexing_vector[old_dst][i]])\n\n                
self.reindexing_vector[node] = new_reindexing_vector\n                self.merged_to[node] = new_dst\n            return new_dst\n        else:\n            return node\n\n    def simplify(self):\n        for (src, dst) in self.to_merge_pair:\n            assert src not in self.merged_to\n            dst = self.query_destination(dst)\n            if src != dst:\n                self.merge_node(src, dst)\n\n    def export_result(self):\n        E = []\n        r = []\n        s_follow = []\n\n        for i in range(len(self.node_lens)):\n            if i in self.merged_to:\n                s_follow.append(self.query_destination(i))\n            else:\n                s_follow.append(-1)\n\n        for ((i, j), v) in self.edge_costs.items():\n            v = v.reshape(-1)\n            E.append((i, j))\n            r.append(v)\n\n            assert len(v) == self.node_lens[i] * self.node_lens[j]\n\n        return s_follow, E, r, self.reindexing_vector\n\n    def __str__(self):\n        ret = \"\"\n        for i in range(len(self.node_lens)):\n            ret += f\"Node {i}: {self.node_lens[i]}\\n\"\n\n        edges = list(self.edge_costs.keys())\n        edges.sort()\n\n        for (i, j) in edges:\n            ret += f\"Edge {(i, j)}:\\n\"\n            ret += str(self.edge_costs[(i, j)]) + \"\\n\"\n\n        return ret\n\n\nclass SolverOption:\n    def __init__(self):\n        self.force_batch_dim_to_mesh_dim = None\n\n        self.forward_backward_sep_id = None\n        self.force_all_reduce_cost = None\n        self.force_all_gather_cost = None\n        self.force_reduce_scatter_cost = None\n\n\ndef solve_auto_sharding(computation, cluster_env, solver_option=None):\n    print(\"===== Hlo Computation =====\")\n    print(computation, \"\\n\")\n\n    print(\"===== Liveness Analysis =====\")\n    liveness_dict = computation.liveness_analysis()\n    for i in range(len(computation.instructions)):\n        names = [ins.name for ins in liveness_dict[i]]\n        
names.sort()\n        print(f\"Time: {i}, Live set: {names}\")\n\n    if solver_option is None:\n        solver_option = SolverOption()\n\n    # Build strategies and costs\n    computation.build_strategy_and_cost(cluster_env, solver_option)\n\n    # Build all constants for ILP\n    N = len(computation.instructions)\n    M = cluster_env.memory_per_device\n\n    s_len = []\n    follow_pair = []\n    E = []\n    A = []\n    L = []\n    c = []\n    d = []\n    m = []\n    r = []\n    v = []\n    for i in range(N):\n        ins = computation.instructions[i]\n        s_len.append(len(ins.strategies))\n        L.append([ins.index for ins in liveness_dict[i]])\n        c.append(ins.compute_costs)\n        d.append(ins.communication_costs)\n        m.append(ins.memory_costs)\n\n        if ins.follow_ins is not None:\n            follow_pair.append((ins.index, ins.follow_ins.index))\n\n        for op_idx, operand in enumerate(ins.operands):\n            E.append((operand.index, i))\n\n            src = operand.index\n            dst = i\n\n            #ins.resharding_costs  # [s_i, operand_idx, s_operand]\n            cost = []\n            for p in range(len(computation.instructions[src].strategies)):\n                for q in range(len(computation.instructions[dst].strategies)):\n                    cost.append(ins.resharding_costs[q][op_idx][p])\n            r.append(cost)\n\n    # Simplify the graph by merging nodes\n    cost_graph = CostGraph(s_len, E, r, follow_pair)\n    cost_graph.simplify()\n    s_follow, E, r, reindexing_vector = cost_graph.export_result()\n\n    for src, dst in enumerate(s_follow):\n        if dst >= 0:\n            s_len[src] = len(reindexing_vector[src])\n            c[src] = np.array(c[src])[reindexing_vector[src]]\n            d[src] = np.array(d[src])[reindexing_vector[src]]\n            m[src] = np.array(m[src])[reindexing_vector[src]]\n\n    # Deal with alias\n    for ((ins_a, ins_b), cost_vector) in zip(computation.alias_list,\n            
                                 computation.alias_cost_vector):\n\n        idx_a, idx_b = ins_a.index, ins_b.index\n        cost_vector = np.array(cost_vector).reshape(\n            len(ins_a.strategies), len(ins_b.strategies))\n\n        if s_follow[idx_a] >= 0:\n            reindexing_a = reindexing_vector[idx_a]\n            idx_a = s_follow[idx_a]\n        else:\n            reindexing_a = range(len(ins_a.strategies))\n\n        if s_follow[idx_b] >= 0:\n            reindexing_b = reindexing_vector[idx_b]\n            idx_b = s_follow[idx_b]\n        else:\n            reindexing_b = range(len(ins_b.strategies))\n\n        if idx_a != idx_b:\n            A.append((idx_a, idx_b))\n            new_cost_vector = []\n            for i in reindexing_a:\n                for j in reindexing_b:\n                    new_cost_vector.append(cost_vector[i, j])\n            v.append(new_cost_vector)\n\n    s_val, e_val, objective, status = call_solver(N, M, s_len, s_follow, E, A, L,\n                                                  c, d, m, r, v, s_init=None)\n\n    if True:\n        # Print sharding spec\n        instructions = computation.instructions\n        print(\"===== Sharding Strategy =====\")\n        for i in range(N):\n            if s_follow[i] < 0:\n                stra_idx = s_val[i]\n                name = instructions[i].strategies[stra_idx].name\n                follow_map = \"\"\n                spec = instructions[i].strategies[stra_idx].output_spec\n            else:\n                dst = s_follow[i]\n                stra_idx = reindexing_vector[i][s_val[i]]\n                name = instructions[i].strategies[stra_idx].name + f\" follow {dst}\"\n                spec = instructions[i].strategies[stra_idx].output_spec\n\n                follow_map = \"\"\n                for idx in range(len(reindexing_vector[i])):\n                    stra_idx = reindexing_vector[i][idx]\n                    follow_map += f\"[{instructions[dst].strategies[idx].name} -> 
\"\\\n                            f\"{instructions[i].strategies[stra_idx].name}] \"\n            #print(f\"Time {i:2d}: {computation.instructions[i]}  Strategy: {name} Spec: {spec}\")\n            print(f\"Time {i:2d}: {computation.instructions[i]}  Strategy: {name}\")\n            #if follow_map:\n            #    print(follow_map)\n\n        # Print edge cost\n        for (idx, (i, j)) in enumerate(E):\n            if r[idx][e_val[idx]] > 0:\n                print(f\"Edge cost {(i, j)} : {r[idx][e_val[idx]]}\")\n\n        # Print peak memory\n        print(\"===== Memory Usage =====\")\n        for t in range(N):\n            mem = 0\n            for i in L[t]:\n                mem += m[i][s_val[i]]\n            print(f\"Time {t}, memory: {mem / 1024**2: .2f} MB\")\n\n    return objective\n"
  },
  {
    "path": "playground/auto_sharding_solver/test_cost.py",
    "content": "import numpy as np\n\nfrom cluster_env import ClusterEnvironment\n\ndef s(*shape):\n    return np.prod(shape) * 4\n\nenv = ClusterEnvironment(np.ones((8, 1)), [1, 1], [0.02, 0.02], 0)\n\na = env.all_reduce_cost(s(16, 14, 14, 8192)) + env.all_reduce_cost(s(16, 28, 28, 2048)) + \\\n    env.all_to_all_cost(s(16, 28, 28, 4096))\n\nprint(a)\n\n\nb = env.all_gather_cost(s(16, 28, 28, 4096)) + env.all_gather_cost(s(1, 1, 4096, 8192))\nprint(b)\n\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/test_sharding_spec.py",
    "content": "from hlo import ShardingSpec, ShardingSpecType\nfrom cluster_env import ClusterEnvironment\nfrom common import compute_bytes\n\n\ndef test_tile():\n    cluster_env = ClusterEnvironment([[0, 1, 2], [3, 4, 5]], [1,1], [1,1], None)\n\n    sharding = ShardingSpec.tile((12, 12), [0, 1], [0, 1], cluster_env)\n    assert sharding.tile_assignment_dimensions == (2, 3)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12), [1, 0], [1, 0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (2, 3)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12), [0, 1], [1, 0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (3, 2)\n    assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12), [0], [0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (2, 1, 3)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert sharding.replicate_on_last_tile_dim == True\n\n    sharding = ShardingSpec.tile((12, 12), [0], [1], cluster_env)\n    assert sharding.tile_assignment_dimensions == (3, 1, 2)\n    assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5)\n    assert sharding.replicate_on_last_tile_dim == True\n\n    sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env)\n    assert sharding.tile_assignment_dimensions == (1, 3, 2)\n    assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5)\n    assert sharding.replicate_on_last_tile_dim == True\n\n    sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (1, 2, 3)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert 
sharding.replicate_on_last_tile_dim == True\n\n    sharding = ShardingSpec.tile((12, 12, 12), [0, 1], [0, 1], cluster_env)\n    assert sharding.tile_assignment_dimensions == (2, 3, 1)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12, 12), [0, 1], [1, 0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (3, 2, 1)\n    assert sharding.tile_assignment_devices == (0, 3, 1, 4, 2, 5)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12, 12), [1], [0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (1, 2, 1, 3)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3, 4, 5)\n    assert sharding.replicate_on_last_tile_dim == True\n\n\ndef test_tile2():\n    cluster_env = ClusterEnvironment([[0, 1, 2, 3]], [1,1], [1,1], None)\n    sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env)\n    assert sharding.tile_assignment_dimensions == (1, 4)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env)\n    assert sharding.type == ShardingSpecType.REPLICATED\n\n    cluster_env = ClusterEnvironment([[0], [1], [2], [3]], [1,1], [1,1], None)\n    sharding = ShardingSpec.tile((12, 12), [1], [0], cluster_env)\n    assert sharding.tile_assignment_dimensions == (1, 4)\n    assert sharding.tile_assignment_devices == (0, 1, 2, 3)\n    assert sharding.replicate_on_last_tile_dim == False\n\n    sharding = ShardingSpec.tile((12, 12), [1], [1], cluster_env)\n    assert sharding.type == ShardingSpecType.REPLICATED\n\n\ndef test_tile3():\n    cluster_env = ClusterEnvironment([[0, 1], [2, 3]], [1,1], [1,1], None)\n    shape = (12, 12)\n    src = ShardingSpec.split(shape, 1, cluster_env)\n    dst = ShardingSpec.tile(shape, [0], [0], cluster_env)\n\n    
print(src)\n    print(dst)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n\n    print(cost)\n\n\ndef assert_allclose(x, y):\n    assert abs((x - y) / (y + 1e-8))  < 0.01\n\n\ndef test_resharding_cost():\n    cluster_env = ClusterEnvironment([[0, 1, 2], [3, 4, 5]], [1, 1], [1, 1], None)\n    shape = (128, 128)\n\n    src = ShardingSpec.tile(shape, [0], [0], cluster_env)\n    dst = ShardingSpec.tile(shape, [0], [0], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, 0)\n\n    src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    dst = ShardingSpec.tile(shape, [1, 0], [1, 0], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, 0)\n\n    src = ShardingSpec.tile(shape, [0], [0], cluster_env)\n    dst = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, 0)\n\n    src = ShardingSpec.tile(shape, [0], [0], cluster_env)\n    dst = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, 0)\n\n    src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    dst = ShardingSpec.tile(shape, [0], [0], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, cluster_env.all_gather_cost(compute_bytes(shape), 1))\n\n    src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    dst = ShardingSpec.replicated(cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, cluster_env.all_gather_cost(compute_bytes(shape), 0)\n                        + cluster_env.all_gather_cost(compute_bytes(shape), 1))\n\n\ndef test_resharding_cost2():\n    cluster_env = ClusterEnvironment([[0], [1], [2], [3]], [1,1], [1,1], None)\n    shape = (128, 128)\n\n    src = ShardingSpec.tile(shape, [0, 1], [0, 1], cluster_env)\n    dst = 
ShardingSpec.tile(shape, [0], [0], cluster_env)\n    cost = cluster_env.resharding_cost(shape, src, dst)\n    assert_allclose(cost, 0)\n\n\nif __name__ == \"__main__\":\n    test_tile()\n    test_tile2()\n    #test_tile3()\n    test_resharding_cost()\n    test_resharding_cost2()\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/test_solver_attention.py",
    "content": "\"\"\"\nUsage:\npython3 -m unittest -bv test_solver_attention.py\n\"\"\"\nfrom collections import defaultdict\nfrom enum import Enum\nimport unittest\n\nimport numpy as np\n\nfrom hlo import *\nfrom cluster_env import ClusterEnvironment\nfrom solver import solve_auto_sharding, SolverOption\n\nMB = 1024 ** 2\n\ndef assert_close(x, y):\n    assert abs(x / y - 1) < 0.001, f\"{x} vs. {y}\"\n\n\ndef solve_without_all_gather(computation, mesh_shape):\n    device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n    solver_option = SolverOption()\n    solver_option.force_all_gather_cost = 1e8\n    cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                     memory_per_device=1000 * MB,\n                                     solver_option=solver_option)\n    objective = solve_auto_sharding(computation, cluster_env, solver_option)\n    return objective, cluster_env\n\n\ndef get_attention_forward_computation(batch_size, seq_len, hidden_dim, num_head, force_replicated_output):\n    per_head = hidden_dim // num_head\n    computation = HloComputation()\n\n    with computation:\n        # hidden states\n        hidden_states = HloParameter((batch_size, seq_len, hidden_dim))\n        hidden_states = HloReshape(hidden_states, (batch_size * seq_len, hidden_dim))\n\n        # query matmul\n        weight_query_dense = HloParameter((hidden_dim, num_head, per_head))\n        weight_query_dense_ = HloReshape(weight_query_dense, (hidden_dim, hidden_dim))\n        query = HloDot(hidden_states, weight_query_dense_)\n        query = HloReshape(query, (batch_size, seq_len, num_head, per_head))\n\n        # query bias_add\n        bias_query_dense = HloParameter((num_head, per_head))\n        bias_query_dense_ = HloBroadcast(bias_query_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3))\n        query = HloAdd(query, bias_query_dense_)\n\n        # query normalization\n        c = HloConstant(0.125)\n        c = 
HloBroadcast(c, (batch_size, seq_len, num_head, per_head))\n        query = HloMutiply(c, query)\n        # query transpose\n        query = HloTranspose(query, [0, 2, 1, 3])\n\n        # key matmul\n        weight_key_dense = HloParameter((hidden_dim, num_head, per_head))\n        weight_key_dense_ = HloReshape(weight_key_dense, (hidden_dim, hidden_dim))\n        key = HloDot(hidden_states, weight_key_dense_)\n        key = HloReshape(key, (batch_size, seq_len, num_head, per_head))\n\n        # key bias_add\n        bias_key_dense = HloParameter((num_head, per_head))\n        bias_key_dense_ = HloBroadcast(bias_key_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3))\n        key = HloAdd(key, bias_key_dense_)\n\n        # key transpose\n        key = HloTranspose(key, [0, 2, 3, 1])\n\n        # att_weight\n        att_weight = HloDot(query, key,\n                            lhs_batch_dims=(0,1), lhs_contracting_dims=(3,),\n                            rhs_batch_dims=(0,1), rhs_contracting_dims=(2,))\n\n        # mask\n        mask = HloParameter((batch_size, seq_len))\n\n        # attention_bias_pred\n        zero = HloConstant(0)\n        zero = HloBroadcast(zero, (batch_size, seq_len))\n        pred = HloCompare(mask, zero)\n\n        # all zero\n        zero = HloConstant(0)\n        zero = HloBroadcast(zero, (batch_size, seq_len))\n\n        # all neg-infinity\n        neg_inf = HloConstant(-1e10)\n        neg_inf = HloBroadcast(neg_inf, (batch_size, seq_len))\n\n        # attention bias\n        select = HloSelect(pred, zero, neg_inf)\n\n        # attention bias_add\n        att_bias = HloBroadcast(select, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 3))\n        att_weight = HloAdd(att_weight, att_bias)\n\n        # softmax_max\n        max_reduce = HloReduce(att_weight, dimensions=(3,))\n        max_reduce = HloBroadcast(max_reduce, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 1, 2))\n        diff = 
HloSubtract(att_weight, max_reduce)\n        exp = HloExp(diff)\n        # softmax_sum\n        sum_reduce = HloReduce(exp, dimensions=(3,))\n        sum_reduce = HloBroadcast(sum_reduce, (batch_size, num_head, seq_len, seq_len), dimensions=(0, 1, 2))\n        # softmax_norm\n        softmax = HloDiv(exp, sum_reduce)\n\n        # value matmul\n        weight_value_dense = HloParameter((hidden_dim, num_head, per_head))\n        weight_value_dense_ = HloReshape(weight_value_dense, (hidden_dim, hidden_dim))\n        value = HloDot(hidden_states, weight_value_dense_)\n        value = HloReshape(value, (batch_size, seq_len, num_head, per_head))\n\n        # value bias_add\n        bias_value_dense = HloParameter((num_head, per_head))\n        bias_value_dense_ = HloBroadcast(bias_value_dense, (batch_size, seq_len, num_head, per_head), dimensions=(2, 3))\n        value = HloAdd(value, bias_value_dense_)\n\n        # value transpose\n        value = HloTranspose(value, [0, 2, 3, 1])\n\n        # self attention\n        self_att = HloDot(value, softmax,\n                          lhs_batch_dims=(0, 1), lhs_contracting_dims=(3,),\n                          rhs_batch_dims=(0, 1), rhs_contracting_dims=(3,))\n        self_att = HloTranspose(self_att, [0, 3, 1, 2])\n        self_att = HloReshape(self_att, [batch_size * seq_len, hidden_dim])\n\n        # out matmul\n        weight_out_dense = HloParameter((hidden_dim, num_head, per_head))\n        weight_out_dense_ = HloReshape(weight_out_dense, (hidden_dim, hidden_dim))\n        out = HloDot(self_att, weight_out_dense_)\n        out = HloReshape(out, (batch_size, seq_len, hidden_dim))\n\n        # out bias_add\n        bias_out_dense = HloParameter((hidden_dim,))\n        bias_out_dense_ = HloBroadcast(bias_out_dense, (batch_size, seq_len, hidden_dim), dimensions=(2,))\n        out = HloAdd(out, bias_out_dense_)\n\n        if force_replicated_output:\n            out = HloForceReplicated(out)\n\n        out = HloTuple([out,\n   
                     weight_value_dense, bias_value_dense, \n                        weight_query_dense, bias_query_dense,\n                        weight_key_dense, bias_key_dense,\n                        weight_out_dense, bias_out_dense,\n        ])\n\n    return computation\n\nclass AttentionSolverTest(unittest.TestCase):\n    def test_tranpose(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_1, dim_0))\n            y = HloParameter((dim_0, dim_1))\n            x = HloTranspose(x, [1, 0])\n            y = HloTranspose(y, [1, 0])\n            out = HloDot(x, y)\n            out = HloTranspose(out, [1, 0])\n            out = HloForceReplicated(out)\n            out = HloTuple((out,))\n\n        # Solve\n        mesh_shape = [1, 4]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n    def test_mulit_tranpose(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_1, dim_0))\n            y = HloParameter((dim_0, dim_1))\n            x = HloTranspose(x, [1, 0])\n            y = HloTranspose(y, [1, 0])\n            x = HloTranspose(x, [1, 0])\n            y = HloTranspose(y, [1, 0])\n            x = HloTranspose(x, [1, 0])\n            y = HloTranspose(y, [1, 0])\n            out = HloDot(x, y)\n            out = HloTranspose(out, [1, 0])\n            out = HloTranspose(out, [1, 0])\n            out = HloForceReplicated(out)\n            out = HloTuple((out,))\n\n        # Solve\n        mesh_shape = [4, 1]\n        objective, cluster_env = 
solve_without_all_gather(computation, mesh_shape)\n\n        expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 0)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n\n    def test_reshape(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_0, dim_1 // 2, 2))\n            y = HloParameter((dim_1 // 2, 2, dim_0))\n            x = HloReshape(x, (dim_0, dim_1))\n            y = HloReshape(y, (dim_1, dim_0))\n            out = HloDot(x, y)\n            out = HloForceReplicated(out)\n            out = HloTuple((out,))\n\n        # Solve\n        mesh_shape = [1, 4]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n    def test_mulit_reshape(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_0, dim_1 // 2, 2))\n            y = HloParameter((dim_1 // 2, 2, dim_0))\n            x = HloReshape(x, (dim_0, dim_1))\n            y = HloReshape(y, (dim_1, dim_0))\n            x = HloReshape(x, (dim_0 // 4, 4, dim_1))\n            y = HloReshape(y, (dim_1 // 4, 4, dim_0))\n            x = HloReshape(x, (dim_0, dim_1))\n            y = HloReshape(y, (dim_1, dim_0))\n            out = HloDot(x, y)\n            out = HloReshape(out, (dim_0, 2, dim_0 // 2))\n            out = HloForceReplicated(out)\n            out = HloTuple((out,))\n\n        # Solve\n        mesh_shape = [4, 1]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = 
cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 0)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n    def test_allreduce_simplification(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_0, dim_1))\n            y = HloParameter((dim_1, dim_0))\n            h1 = HloDot(x, y)\n            h2 = HloDot(x, y)\n            out = HloAdd(h1, h2)\n            out = HloForceReplicated(out)\n            out = HloTuple((out,))\n\n        # Solve\n        mesh_shape = [1, 4]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = 2 * cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n    def test_allreduce_simplification_out_reuse(self):\n        # Build Hlo Computation\n        computation = HloComputation()\n        dim_0 = 128\n        dim_1 = 2048\n\n        with computation:\n            x = HloParameter((dim_0, dim_1))\n            y = HloParameter((dim_1, dim_0))\n            z = HloParameter((dim_0 // 4, 4, dim_0))\n            h1 = HloDot(x, y)\n            h2 = HloDot(x, y)\n            h3 = HloDot(x, y)\n            h1 = HloReshape(h1, (dim_0 // 4, 4, dim_0))\n            h2 = HloReshape(h2, (dim_0 // 4, 4, dim_0))\n            h3 = HloReshape(h3, (dim_0 // 4, 4, dim_0))\n            out = z\n            out = HloAdd(out, h1)\n            out = HloAdd(out, h2)\n            out = HloAdd(out, h3)\n            b1 = HloExp(out)\n            b2 = HloExp(out)\n            b3 = HloExp(out)\n            b4 = HloExp(out)\n            b5 = HloExp(out)\n            b6 = HloExp(out)\n            b7 = HloForceReplicated(b6)\n            out = HloTuple((b1, b2, b3, b4, b5, b6, b7))\n\n        # 
Solve\n        mesh_shape = [1, 4]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = 3 * cluster_env.all_reduce_cost(dim_0 * dim_0 * 4, 1)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n    def test_attention_forward(self):\n        # Build Hlo Computation\n        batch_size = 4\n        seq_len = 128\n        hidden_dim = 512\n        num_head = 16\n\n        computation = get_attention_forward_computation(\n            batch_size, seq_len, hidden_dim, num_head, True)\n\n        # Solve\n        for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):\n            objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n            expected = cluster_env.all_reduce_cost(batch_size * seq_len * hidden_dim * 4, i)\n            print(\"Objective:\", objective)\n            print(\"Expected:\", expected)\n            assert_close(objective, expected)\n\n    def test_attention_forward_2d_mesh(self):\n        # Build Hlo Computation\n        batch_size = 4\n        seq_len = 128\n        hidden_dim = 2048\n        num_head = 16\n\n        computation = get_attention_forward_computation(\n            batch_size, seq_len, hidden_dim, num_head, False)\n\n        # Solve\n        mesh_shape = [4, 4]\n        objective, cluster_env = solve_without_all_gather(computation, mesh_shape)\n\n        expected = cluster_env.all_reduce_cost(\n            batch_size * seq_len * hidden_dim * 4 / mesh_shape[0], 1)\n        print(\"Objective:\", objective)\n        print(\"Expected:\", expected)\n        assert_close(objective, expected)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(AttentionSolverTest('test_tranpose'))\n    suite.addTest(AttentionSolverTest('test_mulit_tranpose'))\n    suite.addTest(AttentionSolverTest('test_reshape'))\n    suite.addTest(AttentionSolverTest('test_mulit_reshape'))\n    
suite.addTest(AttentionSolverTest('test_allreduce_simplification'))\n    suite.addTest(AttentionSolverTest('test_allreduce_simplification_out_reuse'))\n    suite.addTest(AttentionSolverTest('test_attention_forward'))\n    suite.addTest(AttentionSolverTest('test_attention_forward_2d_mesh'))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n\n"
  },
  {
    "path": "playground/auto_sharding_solver/test_solver_mlp.py",
    "content": "\"\"\"\nUsage:\npython3 -m unittest -bv test_solver_mlp.py\n\"\"\"\nfrom collections import defaultdict\nfrom enum import Enum\nimport unittest\n\nimport numpy as np\n\nfrom hlo import *\nfrom cluster_env import ClusterEnvironment\nfrom solver import solve_auto_sharding, SolverOption\n\nMB = 1024 ** 2\n\n\ndef assert_close(x, y):\n    assert abs(x / y - 1) < 0.001, f\"{x} vs. {y}\"\n\n\ndef get_mlp_2_layer_computation(batch_size, input_dim, hidden_dim, output_dim):\n    computation = HloComputation()\n    with computation:\n        x = HloParameter((batch_size, input_dim))\n        y = HloParameter((batch_size, output_dim))\n        w1 = HloParameter((input_dim, hidden_dim))\n        w2 = HloParameter((hidden_dim, output_dim))\n\n        ## forward\n        h1 = HloDot(x, w1)\n        h2 = HloDot(h1, w2)\n        loss = HloSubtract(h2, y)\n\n        ## backward\n        coef = HloConstant(2 / batch_size / output_dim)\n        coef = HloBroadcast(coef, (batch_size, output_dim))\n        grad_loss = HloMutiply(loss, coef)\n\n        grad_w2 = HloDot(h1, grad_loss,\n                         lhs_contracting_dims=(0,),\n                         rhs_contracting_dims=(0,),)\n        new_w2 = HloSubtract(w2, grad_w2)\n        grad_h1 = HloDot(grad_loss, w2,\n                         lhs_contracting_dims=(1,),\n                         rhs_contracting_dims=(1,),)\n\n        grad_w1 = HloDot(x, grad_h1,\n                         lhs_contracting_dims=(0,),\n                         rhs_contracting_dims=(0,),)\n        new_w1 = HloSubtract(w1, grad_w1)\n        out = HloTuple((new_w1, new_w2))\n\n        ## alias\n        computation.set_alias([(w1, new_w1), (w2, new_w2)])\n\n        \"\"\"\n         0: parameter.0 (128, 1024) = parameter()\n         1: parameter.1 (128, 1024) = parameter()\n         2: parameter.2 (1024, 1024) = parameter()\n         3: parameter.3 (1024, 1024) = parameter()\n         4: dot.0 (128, 1024) = dot(parameter.0, parameter.2)  
lhs_con_dim=(1,), rhs_con_dim=(0,)\n         5: dot.1 (128, 1024) = dot(dot.0, parameter.3)  lhs_con_dim=(1,), rhs_con_dim=(0,)\n         6: subtract.0 (128, 1024) = subtract(dot.1, parameter.1)\n         7: constant.0 () = constant(1.52587891e-05)\n         8: broadcast.0 (128, 1024) = broadcast(constant.0)\n         9: multiply.0 (128, 1024) = multiply(subtract.0, broadcast.0)\n        10: dot.2 (1024, 1024) = dot(dot.0, multiply.0)  lhs_con_dim=(0,), rhs_con_dim=(0,)\n        11: subtract.1 (1024, 1024) = subtract(parameter.2, dot.2)\n        12: dot.3 (128, 1024) = dot(multiply.0, parameter.3)  lhs_con_dim=(1,), rhs_con_dim=(1,)\n        13: dot.4 (1024, 1024) = dot(parameter.0, dot.3)  lhs_con_dim=(0,), rhs_con_dim=(0,)\n        14: subtract.2 (1024, 1024) = subtract(parameter.2, dot.4)\n        15: tuple.0 () = tuple('subtract.2', 'subtract.1') \n        \"\"\"\n    return computation\n\n\ndef get_mlp_2_layer_bias_computation(batch_size, input_dim, hidden_dim, output_dim):\n    computation = HloComputation()\n    with computation:\n        x = HloParameter((batch_size, input_dim))\n        y = HloParameter((batch_size, output_dim))\n        w1 = HloParameter((input_dim, hidden_dim))\n        w2 = HloParameter((hidden_dim, output_dim))\n        b1 = HloParameter((hidden_dim,))\n        b2 = HloParameter((output_dim,))\n\n        ## forward\n        h1 = HloDot(x, w1)\n        bb1 = HloBroadcast(b1, (batch_size, hidden_dim), dimensions=(1,))\n        h1_add = HloAdd(h1, bb1)\n\n        h2 = HloDot(h1_add, w2)\n        bb2 = HloBroadcast(b2, (batch_size, output_dim), dimensions=(1,))\n        h2_add = HloAdd(h2, bb2)\n\n        loss = HloSubtract(h2_add, y)\n\n        ## backward\n        coef = HloConstant(2 / batch_size / output_dim)\n        coef = HloBroadcast(coef, (batch_size, output_dim))\n        grad_loss = HloMutiply(loss, coef)\n\n        grad_w2 = HloDot(h1_add, grad_loss,\n                         lhs_contracting_dims=(0,),\n                         
rhs_contracting_dims=(0,),)\n        new_w2 = HloSubtract(w2, grad_w2)\n\n        grad_h1 = HloDot(grad_loss, w2,\n                         lhs_contracting_dims=(1,),\n                         rhs_contracting_dims=(1,),)\n\n        grad_w1 = HloDot(x, grad_h1,\n                         lhs_contracting_dims=(0,),\n                         rhs_contracting_dims=(0,),)\n        new_w1 = HloSubtract(w1, grad_w1)\n\n        grad_b1 = HloReduce(grad_h1, dimensions=[0])\n        new_b1 = HloSubtract(b1, grad_b1)\n\n        grad_b2 = HloReduce(grad_loss, dimensions=[0])\n        new_b2 = HloSubtract(b2, grad_b2)\n\n        out = HloTuple((new_w1, new_w2, new_b1, new_b2))\n\n        ## alias\n        computation.set_alias([(w1, new_w1), (w2, new_w2)])\n\n    return computation\n\n\ndef get_mlp_n_layer_computation(num_layers, batch_size, input_dim, hidden_dim, output_dim):\n    computation = HloComputation()\n    with computation:\n        x = HloParameter((batch_size, input_dim))\n        y = HloParameter((batch_size, output_dim))\n        w_first = HloParameter((input_dim, hidden_dim))\n        w_inter = []\n        for i in range(num_layers - 2):\n            manual_strategy = \"S0\" if i % 2 == 0 else \"S1\"\n            w_inter.append(HloParameter((hidden_dim, hidden_dim)))\n        w_last = HloParameter((hidden_dim, output_dim))\n\n        # forward\n        h_first = HloDot(x, w_first)\n        h_now = h_first\n        h_inter = []\n        for i in range(num_layers - 2):\n            h_now = HloDot(h_now, w_inter[i])\n            h_inter.append(h_now)\n        h_last = HloDot(h_now, w_last)\n\n        loss = HloSubtract(h_last, y)\n\n        # backward\n        coef = HloConstant(2 / batch_size / output_dim)\n        coef = HloBroadcast(coef, (batch_size, output_dim))\n        grad_loss = HloMutiply(loss, coef)\n        grad_h_now = grad_loss\n\n        grad_w_last = HloDot(h_inter[-1], grad_h_now,\n                             lhs_contracting_dims=(0,),\n             
                rhs_contracting_dims=(0,),)\n        new_w_last = HloSubtract(w_last, grad_w_last)\n        grad_h_now = HloDot(grad_h_now, w_last,\n                             lhs_contracting_dims=(1,),\n                             rhs_contracting_dims=(1,),)\n\n        new_w_inter = []\n        for i in range(num_layers - 3, -1, -1):\n            grad_w = HloDot(h_inter[i-1], grad_h_now,\n                            lhs_contracting_dims=(0,),\n                            rhs_contracting_dims=(0,),)\n            new_w = HloSubtract(w_inter[i], grad_w)\n            grad_h_now = HloDot(grad_h_now, w_inter[i],\n                                lhs_contracting_dims=(1,),\n                                rhs_contracting_dims=(1,),)\n            new_w_inter.append(new_w)\n\n        grad_w_first = HloDot(x, grad_h_now,\n                              lhs_contracting_dims=(0,),\n                              rhs_contracting_dims=(0,),)\n        new_w_first = HloSubtract(w_first, grad_w_first)\n\n        out = HloTuple([new_w_first] + new_w_inter + [new_w_last])\n\n        # alias\n        alias_list = [(w_first, new_w_first), (w_last, new_w_last)] +\\\n            [(w_old, w_new) for w_old, w_new in zip(w_inter, reversed(new_w_inter))]\n        computation.set_alias(alias_list)\n    return computation\n\n\nclass MLPSolverTest(unittest.TestCase):\n    def test_mlp_2_layer_data_parallel(self):\n        # Build Hlo Computation\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, 
cluster_env)\n\n            # The expected cost is always two all-reduce on weights\n            expected = 2 * cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i)\n            assert_close(objective, expected)\n\n    def test_mlp_2_layer_model_parallel(self):\n        # Build Hlo Computation\n        batch_size = 128\n        hidden_dim = 1024\n\n        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            # The expected cost is always one all-reduce on activations\n            expected = cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)\n            assert_close(objective, expected)\n\n    def test_mlp_n_layer_data_parallel(self):\n        # Build Hlo Computation\n        num_layers = 12\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_n_layer_computation(num_layers, batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = num_layers *\\\n                       cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i)\n            assert_close(objective, expected)\n\n    def test_mlp_n_layer_model_parallel(self):\n        # 
Build Hlo Computation\n        num_layers = 12\n        batch_size = 128\n        hidden_dim = 1024\n\n        computation = get_mlp_n_layer_computation(num_layers, batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([ (4, 1), (1, 4) ]):\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = (num_layers - 1) *\\\n                       cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)\n            assert_close(objective, expected)\n\n    def test_mlp_2_layer_2d_mesh(self):\n        # Build Hlo Computation\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected =\\\n                2 * cluster_env.all_reduce_cost(\n                    hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) +\\\n               cluster_env.all_reduce_cost(batch_size * hidden_dim * 4 / mesh_shape[0], 1)\n            assert_close(objective, expected)\n\n    def test_mlp_n_layer_2d_mesh(self):\n        # Build Hlo Computation\n        num_layers = 12\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_n_layer_computation(num_layers, batch_size, hidden_dim,\n            
hidden_dim, hidden_dim)\n\n        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = \\\n                num_layers * cluster_env.all_reduce_cost(\n                    hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) +\\\n                (num_layers - 1)  * cluster_env.all_reduce_cost(\n                   batch_size * hidden_dim * 4 / mesh_shape[0], 1)\n            assert_close(objective, expected)\n\n    def test_mlp_2_layer_bias_data_parallel(self):\n        # Build Hlo Computation\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = \\\n                cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, i) * 2 +\\\n                cluster_env.all_reduce_cost(hidden_dim * 4, i) * 2\n            assert_close(objective, expected)\n\n    def test_mlp_2_layer_bias_model_parallel(self):\n        # Build Hlo Computation\n        batch_size = 128\n        hidden_dim = 1024\n\n        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = 
np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = cluster_env.all_reduce_cost(batch_size * hidden_dim * 4, i)\n            assert_close(objective, expected)\n\n    def test_mlp_2_layer_bias_2d_mesh(self):\n        # Build Hlo Computation\n        batch_size = 1024\n        hidden_dim = 128\n\n        computation = get_mlp_2_layer_bias_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        for mesh_shape in [(4, 8), (8, 4), (3, 4)]:\n            device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n            cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 0.01],\n                                             memory_per_device=1000 * MB)\n            objective = solve_auto_sharding(computation, cluster_env)\n\n            expected = \\\n                cluster_env.all_reduce_cost(batch_size * hidden_dim * 4 / mesh_shape[0], 1) +\\\n                cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) * 2 +\\\n                cluster_env.all_reduce_cost(hidden_dim * 4, 0) +\\\n                cluster_env.all_reduce_cost(hidden_dim * 4 / mesh_shape[1], 0)\n            assert_close(objective, expected)\n\n\n    def test_mlp_2_layer_force_data_parallel(self):\n        # Build Hlo Computation\n        batch_size = 128\n        hidden_dim = 1024\n\n        computation = get_mlp_2_layer_computation(batch_size, hidden_dim,\n            hidden_dim, hidden_dim)\n\n        # Test different device meshes\n        mesh_shape = [4, 1]\n        device_mesh = np.arange(np.prod(mesh_shape)).reshape(mesh_shape)\n        solver_option = SolverOption()\n        solver_option.force_batch_dim_to_mesh_dim = 0\n        
solver_option.force_all_gather_cost = 1e10\n        cluster_env = ClusterEnvironment(device_mesh, [1, 1], [1, 1],\n                                         memory_per_device=1000 * MB,\n                                         solver_option=solver_option)\n        objective = solve_auto_sharding(computation, cluster_env, solver_option)\n\n        # The expected cost is two all-reduce on weights (data parallelism is forced)\n        expected = 2 * cluster_env.all_reduce_cost(hidden_dim * hidden_dim * 4, 0)\n        assert_close(objective, expected)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_data_parallel'))\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_model_parallel'))\n    suite.addTest(MLPSolverTest('test_mlp_n_layer_data_parallel'))\n    suite.addTest(MLPSolverTest('test_mlp_n_layer_model_parallel'))\n\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_2d_mesh'))\n    suite.addTest(MLPSolverTest('test_mlp_n_layer_2d_mesh'))\n\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_data_parallel'))\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_model_parallel'))\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_bias_2d_mesh'))\n\n    suite.addTest(MLPSolverTest('test_mlp_2_layer_force_data_parallel'))\n\n    return suite\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n\n"
  },
  {
    "path": "playground/jax_basic/slice_jaxpr.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"import jax\\n\",\n    \"import jax.numpy as jnp\\n\",\n    \"from jax import jit, grad, vmap\\n\",\n    \"from jax import random\\n\",\n    \"\\n\",\n    \"from functools import wraps, partial\\n\",\n    \"from jax import core\\n\",\n    \"from jax import lax\\n\",\n    \"from jax._src.util import safe_map\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"foo\\n\",\n      \"=====\\n\",\n      \"invars: [a]\\n\",\n      \"outvars: [b]\\n\",\n      \"constvars: []\\n\",\n      \"equation: [a, 1] add [b] {}\\n\",\n      \"\\n\",\n      \"jaxpr: { lambda  ; a.\\n\",\n      \"  let b = add a 1\\n\",\n      \"  in (b,) }\\n\",\n      \"\\n\",\n      \"bar\\n\",\n      \"=====\\n\",\n      \"invars: [a, b, c]\\n\",\n      \"outvars: [g, c]\\n\",\n      \"constvars: []\\n\",\n      \"equation: [a, c] dot_general [d] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': None}\\n\",\n      \"equation: [d, b] add [e] {}\\n\",\n      \"equation: [1.0] broadcast_in_dim [f] {'shape': (5,), 'broadcast_dimensions': ()}\\n\",\n      \"equation: [e, f] add [g] {}\\n\",\n      \"\\n\",\n      \"jaxpr: { lambda  ; a b c.\\n\",\n      \"  let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n      \"                       precision=None\\n\",\n      \"                       preferred_element_type=None ] a c\\n\",\n      \"      e = add d b\\n\",\n      \"      f = broadcast_in_dim[ broadcast_dimensions=(  )\\n\",\n      \"                            shape=(5,) ] 1.0\\n\",\n      \"      g = add e f\\n\",\n      \"  in (g, c) }\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    
\"def examine_jaxpr(closed_jaxpr):\\n\",\n    \"  jaxpr = closed_jaxpr.jaxpr\\n\",\n    \"  print(\\\"invars:\\\", jaxpr.invars)\\n\",\n    \"  print(\\\"outvars:\\\", jaxpr.outvars)\\n\",\n    \"  print(\\\"constvars:\\\", jaxpr.constvars)\\n\",\n    \"  for eqn in jaxpr.eqns:\\n\",\n    \"    print(\\\"equation:\\\", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)\\n\",\n    \"  print()\\n\",\n    \"  print(\\\"jaxpr:\\\", jaxpr)\\n\",\n    \"\\n\",\n    \"def foo(x):\\n\",\n    \"  return x + 1\\n\",\n    \"print(\\\"foo\\\")\\n\",\n    \"print(\\\"=====\\\")\\n\",\n    \"examine_jaxpr(jax.make_jaxpr(foo)(5))\\n\",\n    \"\\n\",\n    \"print()\\n\",\n    \"\\n\",\n    \"def bar(w, b, x):\\n\",\n    \"  return jnp.dot(w, x) + b + jnp.ones(5), x\\n\",\n    \"print(\\\"bar\\\")\\n\",\n    \"print(\\\"=====\\\")\\n\",\n    \"examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from collections import OrderedDict\\n\",\n    \"\\n\",\n    \"def slice_closed_jaxpr(closed_jaxpr, start=None, end=None):\\n\",\n    \"#     print(\\\"closed_jaxpr.consts:\\\", closed_jaxpr.consts)\\n\",\n    \"#     print(\\\"closed_jaxpr.jaxpr.constvars:\\\", closed_jaxpr.jaxpr.constvars)\\n\",\n    \"#     print(\\\"closed_jaxpr.jaxpr.invars:\\\", closed_jaxpr.jaxpr.invars)\\n\",\n    \"#     print(\\\"closed_jaxpr.jaxpr.outvars:\\\", closed_jaxpr.jaxpr.outvars)\\n\",\n    \"    invars = set(closed_jaxpr.jaxpr.invars)\\n\",\n    \"    consts_dir = OrderedDict(zip(closed_jaxpr.jaxpr.constvars, closed_jaxpr.consts))\\n\",\n    \"    \\n\",\n    \"    pred_intermediate_vars = set()\\n\",\n    \"    \\n\",\n    \"    slice_consts_dir = OrderedDict()\\n\",\n    \"    slice_invars = []\\n\",\n    \"    slice_outvars = []\\n\",\n    \"    slice_eqns = []\\n\",\n    \"    slice_intermediate_vars = set()\\n\",\n    \"\\n\",\n 
   \"    succ_intermediate_vars = set()\\n\",\n    \"    \\n\",\n    \"    start = start if start is not None else 0\\n\",\n    \"    end = end if end is not None else len(closed_jaxpr.jaxpr.eqns)\\n\",\n    \"    \\n\",\n    \"    for index, eqn in enumerate(closed_jaxpr.jaxpr.eqns):\\n\",\n    \"#         print(index, eqn, eqn.invars, eqn.outvars)\\n\",\n    \"        if index < start:\\n\",\n    \"            pred_intermediate_vars.update(eqn.outvars)\\n\",\n    \"        elif start <= index < end:\\n\",\n    \"            slice_eqns.append(eqn)\\n\",\n    \"            for var in eqn.invars:\\n\",\n    \"                if isinstance(var, core.Literal):\\n\",\n    \"                    continue\\n\",\n    \"                elif var in consts_dir:\\n\",\n    \"                    if var not in slice_consts_dir:\\n\",\n    \"                        slice_consts_dir[var] = consts_dir[var]\\n\",\n    \"                elif (var in invars) or (var in pred_intermediate_vars):\\n\",\n    \"                    if var not in slice_invars: # FIXME: this is O(n^2)\\n\",\n    \"                        slice_invars.append(var)\\n\",\n    \"                else:\\n\",\n    \"                    assert var in slice_intermediate_vars\\n\",\n    \"            slice_intermediate_vars.update(eqn.outvars)\\n\",\n    \"        else:  # end <= index\\n\",\n    \"            for var in eqn.invars:\\n\",\n    \"                if isinstance(var, core.Literal):\\n\",\n    \"                    continue\\n\",\n    \"                elif (var in invars) or (var in pred_intermediate_vars):\\n\",\n    \"                    if var not in slice_invars: # FIXME: this is O(n^2)\\n\",\n    \"                        slice_invars.append(var)\\n\",\n    \"                    if var not in slice_outvars: # FIXME: this is O(n^2)\\n\",\n    \"                        slice_outvars.append(var)\\n\",\n    \"                elif var in slice_intermediate_vars:\\n\",\n    \"                    if var not 
in slice_outvars: # FIXME: this is O(n^2)\\n\",\n    \"                        slice_outvars.append(var)                    \\n\",\n    \"                else:\\n\",\n    \"                    assert (var in consts_dir) or (var in succ_intermediate_vars)\\n\",\n    \"            succ_intermediate_vars.update(eqn.outvars)\\n\",\n    \"\\n\",\n    \"    for var in closed_jaxpr.jaxpr.outvars:\\n\",\n    \"        if (var in invars) or (var in pred_intermediate_vars):\\n\",\n    \"            if var not in slice_invars: # FIXME: this is O(n^2)\\n\",\n    \"                slice_invars.append(var)\\n\",\n    \"            if var not in slice_outvars: # FIXME: this is O(n^2)\\n\",\n    \"                slice_outvars.append(var)\\n\",\n    \"        elif var in slice_intermediate_vars:\\n\",\n    \"            if var not in slice_outvars: # FIXME: this is O(n^2)\\n\",\n    \"                slice_outvars.append(var)                    \\n\",\n    \"        else:\\n\",\n    \"            assert (var in consts_dir) or (var in succ_intermediate_vars)\\n\",\n    \"\\n\",\n    \"#     print(\\\"pred_intermediate_vars\\\", pred_intermediate_vars)\\n\",\n    \"#     print(\\\"slice_consts_dir\\\", slice_consts_dir)\\n\",\n    \"#     print(\\\"slice_invars\\\", slice_invars)\\n\",\n    \"#     print(\\\"slice_outvars\\\", slice_outvars)\\n\",\n    \"#     print(\\\"slice_eqns\\\", slice_eqns)\\n\",\n    \"#     print(\\\"slice_intermediate_vars\\\", slice_intermediate_vars)\\n\",\n    \"#     print(\\\"succ_intermediate_vars\\\", succ_intermediate_vars)\\n\",\n    \"    slice_jaxpr = core.Jaxpr(slice_consts_dir.keys(), slice_invars, slice_outvars, slice_eqns)\\n\",\n    \"    slice_closed_jaxpr = core.ClosedJaxpr(slice_jaxpr, slice_consts_dir.values())\\n\",\n    \"    return slice_closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  
; a b.\\n\",\n       \"  let c = broadcast_in_dim[ broadcast_dimensions=(  )\\n\",\n       \"                            shape=(5,) ] 1.0\\n\",\n       \"      d = sin c\\n\",\n       \"      e = tanh a\\n\",\n       \"      f = mul d e\\n\",\n       \"      g = sin f\\n\",\n       \"      h = cos g\\n\",\n       \"      i = exp h\\n\",\n       \"  in (i, b) }\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"def f(x, z):\\n\",\n    \"    y = jnp.sin(jnp.ones_like(x))\\n\",\n    \"    x = y * jnp.tanh(x)\\n\",\n    \"    x = jnp.sin(x)\\n\",\n    \"    x = jnp.cos(x)\\n\",\n    \"    x = jnp.exp(x)\\n\",\n    \"    return x, z\\n\",\n    \"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5), jnp.ones(6))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c = broadcast_in_dim[ broadcast_dimensions=(  )\\n\",\n       \"                            shape=(5,) ] 1.0\\n\",\n       \"      d = sin c\\n\",\n       \"      e = tanh a\\n\",\n       \"      f = mul d e\\n\",\n       \"  in (f, b) }\"\n      ]\n     },\n     \"execution_count\": 5,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"closed_jaxpr_slice1 = slice_closed_jaxpr(closed_jaxpr, start=0, end=4)\\n\",\n    \"closed_jaxpr_slice1\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; f b.\\n\",\n       \"  let g = sin f\\n\",\n       \"      h = cos g\\n\",\n       \"      i = exp h\\n\",\n       \"  in (i, b) }\"\n      ]\n     },\n     \"execution_count\": 6,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   
],\n   \"source\": [\n    \"closed_jaxpr_slice2 = slice_closed_jaxpr(closed_jaxpr, start=4)\\n\",\n    \"closed_jaxpr_slice2\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\\n\",\n       \" DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\"\n      ]\n     },\n     \"execution_count\": 7,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"core.jaxpr_as_fun(closed_jaxpr)(jnp.ones(5), jnp.ones(6))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"[DeviceArray([0.6408594, 0.6408594, 0.6408594, 0.6408594, 0.6408594], dtype=float32), DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\\n\",\n       \" DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\"\n      ]\n     },\n     \"execution_count\": 8,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"intermediate = core.jaxpr_as_fun(closed_jaxpr_slice1)(jnp.ones(5), jnp.ones(6))\\n\",\n    \"print(intermediate)\\n\",\n    \"core.jaxpr_as_fun(closed_jaxpr_slice2)(*intermediate)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"[DeviceArray([2.2853706, 2.2853706, 2.2853706, 2.2853706, 2.2853706], dtype=float32),\\n\",\n       \" DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)]\"\n      ]\n     },\n     \"execution_count\": 9,\n     
\"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"intermediate = jit(core.jaxpr_as_fun(closed_jaxpr_slice1))(jnp.ones(5), jnp.ones(6))\\n\",\n    \"jit(core.jaxpr_as_fun(closed_jaxpr_slice2))(*intermediate)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# TODO: merge with Lianmin's code\\n\",\n    \"# TODO: PyTree inputs\\n\",\n    \"# Q: How about lax.cond & lax.while?\\n\",\n    \"#    Ideally we should inline lax.cond & lax.while\\n\",\n    \"# Q: How about backward?\\n\",\n    \"# Q: How to slice a computation into different stages, given that jaxpr is actually a graph?\\n\",\n    \"# Why JaxPR? Try XLA\\n\",\n    \"# Forward & backward device assignment (very general)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 11,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] a b\\n\",\n       \"      d = exp c\\n\",\n       \"  in (d,) }\"\n      ]\n     },\n     \"execution_count\": 11,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"# @jax.jit\\n\",\n    \"def matmul(w, x):\\n\",\n    \"    return w @ x\\n\",\n    \"\\n\",\n    \"def f(w, x):\\n\",\n    \"    x = matmul(w, x)\\n\",\n    \"    x = jnp.exp(x)\\n\",\n    \"    return x\\n\",\n    \"\\n\",\n    \"closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 12,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c = 
dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] a b\\n\",\n       \"      d = exp c\\n\",\n       \"  in (d,) }\"\n      ]\n     },\n     \"execution_count\": 12,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 13,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from jax import core\\n\",\n    \"from jax.lib import xla_client\\n\",\n    \"from jax.interpreters import xla, ad\\n\",\n    \"\\n\",\n    \"pipeline_start_p = core.Primitive(\\\"pipeline_start\\\")  # Create the primitive\\n\",\n    \"pipeline_start_p.multiple_results = True\\n\",\n    \"pipeline_end_p = core.Primitive(\\\"pipeline_end\\\")  # Create the primitive\\n\",\n    \"pipeline_end_p.multiple_results = True\\n\",\n    \"\\n\",\n    \"def mark_pipeline_start(*args, name):\\n\",\n    \"    return pipeline_start_p.bind(*args, name=name)\\n\",\n    \"\\n\",\n    \"def mark_pipeline_end(*args, name):\\n\",\n    \"    return pipeline_end_p.bind(*args, name=name)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"def pipeline_impl(*args, name):\\n\",\n    \"    if len(args) == 0:\\n\",\n    \"        return (None, )\\n\",\n    \"    else:\\n\",\n    \"        return args\\n\",\n    \"\\n\",\n    \"def pipeline_abstract_eval(*args, name):\\n\",\n    \"    if len(args) == 0:\\n\",\n    \"        return (core.abstract_unit, )\\n\",\n    \"    else:\\n\",\n    \"        return args\\n\",\n    \"\\n\",\n    \"def pipeline_xla_translation(c, *args, name):\\n\",\n    \"    if len(args) == 0:\\n\",\n    \"        return xla_client.ops.Tuple(c, (xla_client.ops.Constant(c, np.float32(0.0)), ))\\n\",\n    \"    
else:\\n\",\n    \"        return xla_client.ops.Tuple(c, args)\\n\",\n    \"\\n\",\n    \"def pipeline_start_value_and_jvp(arg_values, arg_tangents, name):\\n\",\n    \"    primal_outs = mark_pipeline_start(*arg_values, name=name)\\n\",\n    \"    tangent_outs = mark_pipeline_start(*arg_tangents, name=\\\"jvp_\\\" + name)\\n\",\n    \"    return primal_outs, tangent_outs\\n\",\n    \"    \\n\",\n    \"def pipeline_start_transpose(ct, *args, name):\\n\",\n    \"    res = mark_pipeline_end(*ct, name=\\\"vjp_\\\" + name)\\n\",\n    \"    return res\\n\",\n    \"\\n\",\n    \"def pipeline_end_value_and_jvp(arg_values, arg_tangents, name):\\n\",\n    \"    primal_outs = mark_pipeline_end(*arg_values, name=name)\\n\",\n    \"    tangent_outs = mark_pipeline_end(*arg_tangents, name=\\\"jvp_\\\" + name)\\n\",\n    \"    return primal_outs, tangent_outs\\n\",\n    \"    \\n\",\n    \"def pipeline_end_transpose(ct, *args, name):\\n\",\n    \"    res = mark_pipeline_start(*ct, name=\\\"vjp_\\\" + name)\\n\",\n    \"    return res\\n\",\n    \"\\n\",\n    \"    \\n\",\n    \"pipeline_start_p.def_impl(pipeline_impl)\\n\",\n    \"pipeline_start_p.def_abstract_eval(pipeline_abstract_eval)\\n\",\n    \"xla.backend_specific_translations['cpu'][pipeline_start_p] = pipeline_xla_translation\\n\",\n    \"xla.backend_specific_translations['gpu'][pipeline_start_p] = pipeline_xla_translation\\n\",\n    \"xla.backend_specific_translations['tpu'][pipeline_start_p] = pipeline_xla_translation\\n\",\n    \"ad.primitive_jvps[pipeline_start_p] = pipeline_start_value_and_jvp\\n\",\n    \"ad.primitive_transposes[pipeline_start_p] = pipeline_start_transpose\\n\",\n    \"\\n\",\n    \"pipeline_end_p.def_impl(pipeline_impl)\\n\",\n    \"pipeline_end_p.def_abstract_eval(pipeline_abstract_eval)\\n\",\n    \"xla.backend_specific_translations['cpu'][pipeline_end_p] = pipeline_xla_translation\\n\",\n    \"xla.backend_specific_translations['gpu'][pipeline_end_p] = pipeline_xla_translation\\n\",\n    
\"xla.backend_specific_translations['tpu'][pipeline_end_p] = pipeline_xla_translation\\n\",\n    \"ad.primitive_jvps[pipeline_end_p] = pipeline_end_value_and_jvp\\n\",\n    \"ad.primitive_transposes[pipeline_end_p] = pipeline_end_transpose\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 14,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c d = pipeline_start[ name=1 ] a b\\n\",\n       \"      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] c d\\n\",\n       \"      f = pipeline_end[ name=1 ] e\\n\",\n       \"      g = pipeline_start[ name=2 ] f\\n\",\n       \"      h = exp g\\n\",\n       \"      i = reduce_sum[ axes=(0,) ] h\\n\",\n       \"      _ = mul i 7.0\\n\",\n       \"      j = pipeline_end[ name=2 ] i\\n\",\n       \"  in (j,) }\"\n      ]\n     },\n     \"execution_count\": 14,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"def f_original(w, x):\\n\",\n    \"    x = matmul(w, x)\\n\",\n    \"    x = jnp.exp(x)\\n\",\n    \"    x = jnp.sum(x)\\n\",\n    \"    y = 7 * x\\n\",\n    \"    return x\\n\",\n    \"\\n\",\n    \"def f(w, x):\\n\",\n    \"    w, x = mark_pipeline_start(w, x, name=\\\"1\\\")\\n\",\n    \"    x = matmul(w, x)\\n\",\n    \"    x, = mark_pipeline_end(x, name=\\\"1\\\")\\n\",\n    \"    x, = mark_pipeline_start(x, name=\\\"2\\\")\\n\",\n    \"    x = jnp.exp(x)\\n\",\n    \"    x = jnp.sum(x)\\n\",\n    \"    y = 7 * x\\n\",\n    \"    x, = mark_pipeline_end(x, name=\\\"2\\\")\\n\",\n    \"    return x\\n\",\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 15,\n   
\"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"DeviceArray(742.0658, dtype=float32)\"\n      ]\n     },\n     \"execution_count\": 15,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"jax.jit(f)(jnp.ones((5, 5)), jnp.ones(5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 16,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c d = pipeline_start[ name=1 ] a b\\n\",\n       \"      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] c d\\n\",\n       \"      f = pipeline_end[ name=1 ] e\\n\",\n       \"      g = pipeline_start[ name=2 ] f\\n\",\n       \"      h = exp g\\n\",\n       \"      i = reduce_sum[ axes=(0,) ] h\\n\",\n       \"      _ = mul i 7.0\\n\",\n       \"      _ = pipeline_end[ name=2 ] i\\n\",\n       \"      j = pipeline_start[ name=vjp_jvp_2 ] 1.0\\n\",\n       \"      k = broadcast_in_dim[ broadcast_dimensions=(  )\\n\",\n       \"                            shape=(5,) ] j\\n\",\n       \"      l = mul k h\\n\",\n       \"      m = pipeline_end[ name=vjp_jvp_2 ] l\\n\",\n       \"      n = pipeline_start[ name=vjp_jvp_1 ] m\\n\",\n       \"      o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] n c\\n\",\n       \"      p = dot_general[ dimension_numbers=(((), ()), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] n d\\n\",\n       \"      q r = pipeline_end[ name=vjp_jvp_1 ] p o\\n\",\n       \"  in (q, r) }\"\n      ]\n     },\n     \"execution_count\": 16,\n     \"metadata\": {},\n   
  \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(jax.grad(jax.jit(f), argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 17,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"(DeviceArray([[148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\\n\",\n       \"              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\\n\",\n       \"              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\\n\",\n       \"              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316],\\n\",\n       \"              [148.41316, 148.41316, 148.41316, 148.41316, 148.41316]],            dtype=float32),\\n\",\n       \" DeviceArray([742.0658, 742.0658, 742.0658, 742.0658, 742.0658], dtype=float32))\"\n      ]\n     },\n     \"execution_count\": 17,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"(jax.grad(f, argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 18,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b c d.\\n\",\n       \"  let e f = pipeline_start[ name=1 ] a b\\n\",\n       \"      g h = pipeline_start[ name=jvp_1 ] c d\\n\",\n       \"      i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] e f\\n\",\n       \"      j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] g f\\n\",\n       \"      k = dot_general[ dimension_numbers=(((1,), (0,)), ((), 
()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] e h\\n\",\n       \"      l = add_any j k\\n\",\n       \"      m = pipeline_end[ name=1 ] i\\n\",\n       \"      n = pipeline_end[ name=jvp_1 ] l\\n\",\n       \"      o = pipeline_start[ name=2 ] m\\n\",\n       \"      p = pipeline_start[ name=jvp_2 ] n\\n\",\n       \"      q = exp o\\n\",\n       \"      r = mul p q\\n\",\n       \"      s = reduce_sum[ axes=(0,) ] q\\n\",\n       \"      t = reduce_sum[ axes=(0,) ] r\\n\",\n       \"      _ = mul s 7.0\\n\",\n       \"      _ = mul t 7.0\\n\",\n       \"      u = pipeline_end[ name=2 ] s\\n\",\n       \"      v = pipeline_end[ name=jvp_2 ] t\\n\",\n       \"  in (u, v) }\"\n      ]\n     },\n     \"execution_count\": 18,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(partial(jax.jvp, f))((jnp.ones((5, 5)), jnp.ones(5)), (jnp.ones((5, 5)), jnp.ones(5)))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 25,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"pipeline_p = Primitive('pipeline')\\n\",\n    \"pipeline_p.multiple_results = True\\n\",\n    \"\\n\",\n    \"def mark_pipeline(*args, name, mark_type):\\n\",\n    \"    if mark_type not in ('start', 'end', 'jvp_start', 'jvp_end'):\\n\",\n    \"        raise ValueError('Unknown mark type: %s' % mark_type)\\n\",\n    \"    return pipeline_p.bind(*args, name=name, mark_type=mark_type)\\n\",\n    \"\\n\",\n    \"def _pipeline_impl(*args, **kwargs):\\n\",\n    \"    # The pipeline marker acts as an identity function\\n\",\n    \"    return args if len(args) > 0 else (None, )\\n\",\n    \"\\n\",\n    \"def _pipeline_abstract_eval(*args, **kwargs):\\n\",\n    \"    return args if len(args) > 0 else (abstract_unit, )\\n\",\n    \"\\n\",\n    \"def 
_pipeline_xla_translation(c, *args, **kwargs):\\n\",\n    \"    return xc.ops.Tuple(c, args) if len(args) > 0 else xc.ops.Tuple(c, (xc.ops.Constant(c, np.float32(0.0)), ))\\n\",\n    \"\\n\",\n    \"def _pipeline_value_and_jvp(arg_values, arg_tangents, name, mark_type):\\n\",\n    \"    primal_outs = mark_pipeline(*arg_values, name=name, mark_type=mark_type)\\n\",\n    \"    # TODO(zhuohan): Check the semantics here works for higher order gradients.\\n\",\n    \"    if mark_type == \\\"start\\\" or mark_type == \\\"jvp_start\\\":\\n\",\n    \"        tangent_mark_type = \\\"jvp_start\\\"\\n\",\n    \"    elif mark_type == \\\"end\\\" or mark_type == \\\"jvp_end\\\":\\n\",\n    \"        tangent_mark_type = \\\"jvp_end\\\"\\n\",\n    \"    else:\\n\",\n    \"        raise ValueError(\\\"Invalid mark_type\\\")\\n\",\n    \"    tangent_outs = mark_pipeline(*arg_tangents, name=name, mark_type=tangent_mark_type)\\n\",\n    \"    return primal_outs, tangent_outs\\n\",\n    \"\\n\",\n    \"def _pipeline_transpose(ct, *args, name, mark_type):\\n\",\n    \"    # TODO(zhuohan): Check the semantics here works for higher order gradients.\\n\",\n    \"    if mark_type == \\\"start\\\" or mark_type == \\\"jvp_start\\\":\\n\",\n    \"        transposed_mark_type = \\\"end\\\"\\n\",\n    \"    elif mark_type == \\\"end\\\" or mark_type == \\\"jvp_end\\\":\\n\",\n    \"        transposed_mark_type = \\\"start\\\"\\n\",\n    \"    else:\\n\",\n    \"        raise ValueError(\\\"Invalid mark_type\\\")\\n\",\n    \"    res = mark_pipeline(*ct, name=name, mark_type=transposed_mark_type)\\n\",\n    \"    return res\\n\",\n    \"\\n\",\n    \"pipeline_p.def_impl(_pipeline_impl)\\n\",\n    \"pipeline_p.def_abstract_eval(_pipeline_abstract_eval)\\n\",\n    \"xla.translations[pipeline_p] = _pipeline_xla_translation\\n\",\n    \"ad.primitive_jvps[pipeline_p] = _pipeline_value_and_jvp\\n\",\n    \"ad.primitive_transposes[pipeline_p] = _pipeline_transpose\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c d = pipeline[ mark_type=start\\n\",\n       \"                      name=1 ] a b\\n\",\n       \"      e = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] c d\\n\",\n       \"      f = pipeline[ mark_type=end\\n\",\n       \"                    name=1 ] e\\n\",\n       \"      g = pipeline[ mark_type=start\\n\",\n       \"                    name=2 ] f\\n\",\n       \"      h = exp g\\n\",\n       \"      i = reduce_sum[ axes=(0,) ] h\\n\",\n       \"      _ = mul i 7.0\\n\",\n       \"      j = pipeline[ mark_type=end\\n\",\n       \"                    name=2 ] i\\n\",\n       \"  in (j,) }\"\n      ]\n     },\n     \"execution_count\": 26,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"def f(w, x):\\n\",\n    \"    w, x = mark_pipeline(w, x, name=\\\"1\\\", mark_type='start')\\n\",\n    \"    x = matmul(w, x)\\n\",\n    \"    x, = mark_pipeline(x, name=\\\"1\\\", mark_type='end')\\n\",\n    \"    x, = mark_pipeline(x, name=\\\"2\\\", mark_type='start')\\n\",\n    \"    x = jnp.exp(x)\\n\",\n    \"    x = jnp.sum(x)\\n\",\n    \"    y = 7 * x\\n\",\n    \"    x, = mark_pipeline(x, name=\\\"2\\\", mark_type='end')\\n\",\n    \"    return x\\n\",\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(f)(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 23,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b.\\n\",\n       \"  let c d = pipeline[ mark_type=start\\n\",\n       \"                      name=1 ] a b\\n\",\n       \"      e = 
dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] c d\\n\",\n       \"      f = pipeline[ mark_type=end\\n\",\n       \"                    name=1 ] e\\n\",\n       \"      g = pipeline[ mark_type=start\\n\",\n       \"                    name=2 ] f\\n\",\n       \"      h = exp g\\n\",\n       \"      i = reduce_sum[ axes=(0,) ] h\\n\",\n       \"      _ = mul i 7.0\\n\",\n       \"      _ = pipeline[ mark_type=end\\n\",\n       \"                    name=2 ] i\\n\",\n       \"      j = pipeline[ mark_type=start\\n\",\n       \"                    name=2 ] 1.0\\n\",\n       \"      k = broadcast_in_dim[ broadcast_dimensions=(  )\\n\",\n       \"                            shape=(5,) ] j\\n\",\n       \"      l = mul k h\\n\",\n       \"      m = pipeline[ mark_type=end\\n\",\n       \"                    name=2 ] l\\n\",\n       \"      n = pipeline[ mark_type=start\\n\",\n       \"                    name=1 ] m\\n\",\n       \"      o = dot_general[ dimension_numbers=(((0,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] n c\\n\",\n       \"      p = dot_general[ dimension_numbers=(((), ()), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] n d\\n\",\n       \"      q r = pipeline[ mark_type=end\\n\",\n       \"                      name=1 ] p o\\n\",\n       \"  in (q, r) }\"\n      ]\n     },\n     \"execution_count\": 23,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(jax.grad(jax.jit(f), argnums=[0, 1]))(jnp.ones((5, 5)), jnp.ones(5))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 
24,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{ lambda  ; a b c d.\\n\",\n       \"  let e f = pipeline[ mark_type=start\\n\",\n       \"                      name=1 ] a b\\n\",\n       \"      g h = pipeline[ mark_type=jvp_start\\n\",\n       \"                      name=1 ] c d\\n\",\n       \"      i = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] e f\\n\",\n       \"      j = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] g f\\n\",\n       \"      k = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))\\n\",\n       \"                       precision=None\\n\",\n       \"                       preferred_element_type=None ] e h\\n\",\n       \"      l = add_any j k\\n\",\n       \"      m = pipeline[ mark_type=end\\n\",\n       \"                    name=1 ] i\\n\",\n       \"      n = pipeline[ mark_type=jvp_end\\n\",\n       \"                    name=1 ] l\\n\",\n       \"      o = pipeline[ mark_type=start\\n\",\n       \"                    name=2 ] m\\n\",\n       \"      p = pipeline[ mark_type=jvp_start\\n\",\n       \"                    name=2 ] n\\n\",\n       \"      q = exp o\\n\",\n       \"      r = mul p q\\n\",\n       \"      s = reduce_sum[ axes=(0,) ] q\\n\",\n       \"      t = reduce_sum[ axes=(0,) ] r\\n\",\n       \"      _ = mul s 7.0\\n\",\n       \"      _ = mul t 7.0\\n\",\n       \"      u = pipeline[ mark_type=end\\n\",\n       \"                    name=2 ] s\\n\",\n       \"      v = pipeline[ mark_type=jvp_end\\n\",\n       \"                    name=2 ] t\\n\",\n       \"  in (u, v) }\"\n      ]\n     },\n     \"execution_count\": 24,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    
}\n   ],\n   \"source\": [\n    \"with jax.disable_jit():\\n\",\n    \"    closed_jaxpr = jax.make_jaxpr(partial(jax.jvp, f))((jnp.ones((5, 5)), jnp.ones(5)), (jnp.ones((5, 5)), jnp.ones(5)))\\n\",\n    \"closed_jaxpr\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Environment (conda_anaconda3)\",\n   \"language\": \"python\",\n   \"name\": \"conda_anaconda3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.6\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 4\n}\n"
  },
  {
    "path": "playground/jax_basic/test_device_put.py",
    "content": "import time\n\nimport jax\nimport jax.numpy as jnp\n\nimport torch\nimport numpy as np\n\ndef benchmark_func(func):\n    warmup = 1\n    number = 2\n\n    for i in range(warmup):\n        func()\n    jax.local_devices()[0].synchronize_all_activity()\n\n    tic = time.time()\n    for i in range(number):\n        func()\n    toc = time.time()\n\n    return (toc - tic) / number\n\n\n\nif __name__ == \"__main__\":\n    num_samples = 20000\n    batch_size = 2048\n\n    print(\"Init data...\")\n    np.random.seed(0)\n    images = np.ones((num_samples, 224, 224, 3), dtype=np.float32)\n    labels = np.ones((num_samples,), dtype=np.int32)\n    steps_per_epoch = len(images) // batch_size\n\n    devices = jax.devices()\n\n    print(\"Load data...\")\n    shard_size = batch_size // len(devices)\n\n    def np_array_view():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n\n    def np_array_copy():\n        for i in range(steps_per_epoch):\n            batch_images = np.array(images[i * batch_size: (i+1)*batch_size])\n            batch_labels = np.array(labels[i * batch_size: (i+1)*batch_size])\n\n    def jnp_array_copy():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            batch_images = jnp.array(batch_images)\n            batch_labels = jnp.array(batch_labels)\n\n    signal = jnp.ones((1024, 1024))\n\n    def jax_device_put():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            jax.device_put(batch_images)\n            jax.device_put(batch_labels)\n            signal.block_until_ready()\n\n    def jax_device_put2():\n        for i in 
range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            jax.device_put(batch_images)\n            jax.device_put(batch_labels)\n            signal.block_until_ready()\n\n    def jax_device_put_sync():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            x = jax.device_put(batch_images)\n            jax.device_put(batch_labels)\n            x.block_until_ready()\n\n    def jax_device_put_multi_devices():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            for j, d in enumerate(devices):\n                jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)\n                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)\n\n    def jax_device_put_multi_devices_slow():\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n            for j, d in enumerate(devices):\n                jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)\n                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)\n\n    def jax_device_put_multi_devices_sync():\n        arrays = [None] * len(devices)\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n\n            for j, d in enumerate(devices):\n                arrays[j] = jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)\n                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)\n\n  
          for j in range(len(devices)):\n                arrays[j].block_until_ready()\n\n    def jax_device_put_multi_devices_sync_serial():\n        arrays = [None] * len(devices)\n        for i in range(steps_per_epoch):\n            batch_images = images[i * batch_size: (i+1)*batch_size]\n            batch_labels = labels[i * batch_size: (i+1)*batch_size]\n\n            for j, d in enumerate(devices):\n                arrays[j] = jax.device_put(batch_images[j * shard_size:(j+1) * shard_size], d)\n                jax.device_put(batch_labels[j * shard_size:(j+1) * shard_size], d)\n                arrays[j].block_until_ready()\n\n    #time_np_array_view = benchmark_func(np_array_view)\n    #time_np_array_copy = benchmark_func(np_array_copy)\n    #time_jnp_array_copy = benchmark_func(jnp_array_copy)\n    time_jax_device_put = benchmark_func(jax_device_put)\n    time_jax_device_put2 = benchmark_func(jax_device_put2)\n    time_jax_device_put_sync = benchmark_func(jax_device_put_sync)\n    time_jax_device_put_multi_devices = benchmark_func(jax_device_put_multi_devices)\n    time_jax_device_put_multi_devices_slow = benchmark_func(jax_device_put_multi_devices_slow)\n    time_jax_device_put_multi_devices_sync = benchmark_func(jax_device_put_multi_devices_sync)\n    time_jax_device_put_multi_devices_sync_serial = benchmark_func(jax_device_put_multi_devices_sync_serial)\n\n    print(f\"Steps: {steps_per_epoch}\")\n    #print(f\"np_array_view: {time_np_array_view * 1e3:.3f} ms\")\n    #print(f\"np_array_copy: {time_np_array_copy * 1e3:.3f} ms\")\n    #print(f\"jnp_array_copy: {time_jnp_array_copy * 1e3:.3f} ms\")\n    print(f\"jax_device_put: {time_jax_device_put * 1e3:.3f} ms\")\n    print(f\"jax_device_put2: {time_jax_device_put2 * 1e3:.3f} ms\")\n    print(f\"jax_device_put_sync: {time_jax_device_put_sync * 1e3:.3f} ms\")\n    print(f\"jax_device_put_multi_devices: {time_jax_device_put_multi_devices* 1e3:.3f} ms\")\n    print(f\"jax_device_put_multi_devices_slow: 
{time_jax_device_put_multi_devices_slow * 1e3:.3f} ms\")\n    print(f\"jax_device_put_multi_devices_sync: {time_jax_device_put_multi_devices_sync * 1e3:.3f} ms\")\n    print(f\"jax_device_put_multi_devices_sync_serial: {time_jax_device_put_multi_devices_sync_serial * 1e3:.3f} ms\")\n\n"
  },
  {
    "path": "playground/jax_basic/test_flop_count.py",
    "content": "import jax, jax.numpy as jnp\n\ndef func(a, b):\n    c =  jnp.asarray(a, jnp.int32) @ jnp.asarray(b, jnp.int32)\n    #c = a @ b\n    c = c.transpose()\n    c += a\n    return c\n\na = jnp.ones((100, 100))\nb = jnp.ones((100, 100))\n\nm = jax.xla_computation(func)(a, b).as_hlo_module()\nprint(m.to_string())\nr = jax.lib.xla_client._xla.hlo_module_count_flop_dot_conv_only(m)\nprint(r)\n\n"
  },
  {
    "path": "playground/jax_basic/test_jit.py",
    "content": "import numpy as np\nimport jax\nfrom jax import numpy as jnp\n\ndef test_jit_cache():\n\n    @jax.jit\n    def add_one(x):\n        return x + 1\n\n    a = jnp.ones(10)\n\n    print(add_one(a))\n    print(add_one(a))\n    print(add_one(a))\n\n\ndef test_cache_closure():\n    outer_scope = [0]\n\n    @jax.jit\n    def add_one(x):\n        print('call add_one')\n        return x + outer_scope[0]\n\n    a = jnp.ones(10)\n\n    print(add_one(a))\n    print(add_one(a))\n    outer_scope[0] = 1\n    print(add_one(a))\n\n\n\ndef test_non_jit():\n    a = jnp.array(np.ones(10))\n    b = jnp.array(np.ones(10))\n    c = a + b\n    c = a + c\n    c = a + c\n\n    print(c)\n\n\nif __name__ == \"__main__\":\n    #test_jit_cache()\n    test_cache_closure()\n    #test_non_jit()\n\n"
  },
  {
    "path": "playground/jax_basic/test_matmul_pmap.py",
    "content": "from functools import partial\n\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\n\ndef split(a, axis, factor):\n    assert a.shape[axis] % factor == 0\n    new_shape = a.shape[:axis] + (factor, a.shape[axis] // factor) + a.shape[axis+1:]\n    a = a.reshape(new_shape)\n    a = jax.pmap(lambda x: x, in_axes=axis, out_axes=axis)(a)\n    return a\n\ndef replica(a, factor):\n    a = jax.pmap(lambda x, y: x, in_axes=(None, 0), out_axes=None)(a, jnp.ones(factor))\n    return a\n\ndef unsplit(a, axis):\n    new_shape = a.shape[:axis] + (a.shape[axis] * a.shape[axis+1],) + a.shape[axis+2:]\n    return a.reshape(new_shape)\n\n\ndef test_matmul_k_partition():\n    def matmul_k_partition(lhs, rhs):\n        @partial(jax.pmap,\n                 axis_name='k',\n                 in_axes=(1, 0),\n                 out_axes=None)\n        def matmul(lhs, rhs):\n            res = lhs @ rhs\n            return jax.lax.psum(res, axis_name='k')\n\n        return matmul(lhs, rhs)\n\n    a = jnp.ones((1024, 1024))\n    b = jnp.ones((1024, 1024))\n\n    a = split(a, 1)\n    b = split(b, 0)\n    c = matmul_k_partition(a, b)\n\n    print(c.shape, c.sharding_spec)\n\n\ndef test_mlp_forward():\n    @partial(jax.pmap, in_axes=(None, 1), out_axes=1)\n    def matmul_r_s1_s1(x, w):\n        return x @ w\n\n    @partial(jax.pmap, in_axes=(1, 0), out_axes=None, axis_name='k')\n    def matmul_s1_s0_r(x, w):\n        res = x @ w\n        return jax.lax.psum(res, axis_name='k')\n\n    N = 1024\n    D = 1024\n\n    x = jnp.ones((N, D))\n    w1 = jnp.ones((D, D))\n    w2 = jnp.ones((D, D))\n\n    x = replica(x)\n    w1 = split(w1, axis=1)\n    w2 = split(w2, axis=0)\n\n    x = matmul_r_s1_s1(x, w1)\n    x = matmul_s1_s0_r(x, w2)\n\n\n@partial(jax.custom_vjp, nondiff_argnums=(1,))\ndef f_operator(x, axis_name):\n    return x\n\ndef f_operator_fwd(x, axis_name):\n    return f_operator(x), ()\n\ndef f_operator_bwd(axis_name, res, g):\n    return jax.lax.psum(x, 
axis_name=axis_name),\n\nf_operator.defvjp(f_operator_fwd, f_operator_bwd)\n\n@partial(jax.custom_vjp, nondiff_argnums=(1,))\ndef g_operator(x, axis_name):\n    return jax.lax.psum(x, axis_name=axis_name)\n\ndef g_operator_fwd(x, axis_name):\n    return g_operator(x, axis_name), ()\n\ndef g_operator_bwd(axis_name, res, g):\n    return g,\n\ng_operator.defvjp(g_operator_fwd, g_operator_bwd)\n\n\ndef test_mlp_model_parallel():\n    lr = 0.1\n    n_epoch = 1\n\n    def loss_serial(x, y, w1, w2):\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        return ((x - y) ** 2).mean()\n\n    def step_serial(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_serial(x, y, w1, w2):\n        for i in range(n_epoch):\n            w1, w2 = step_serial(x, y, w1, w2)\n        return w1, w2\n\n    def loss_parallel(x, y, w1, w2):\n        x = f_operator(x, axis_name='model_parallel')\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        x = g_operator(x, axis_name='model_parallel')\n        return ((x - y) ** 2).mean()\n\n    @partial(jax.pmap, in_axes=(None, None, 1, 0), out_axes=(1, 0),\n             axis_name='model_parallel')\n    def step_parallel(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_parallel(x, y, w1, w2):\n        model_parallel = len(jax.devices())\n\n        w1 = split(w1, 1, model_parallel)\n        w2 = split(w2, 0, model_parallel)\n\n        for i in range(n_epoch):\n            w1, w2 = step_parallel(x, y, w1, w2)\n\n        return unsplit(w1, 1), unsplit(w2, 0)\n\n    N = 8\n    D = 128\n\n    np.random.seed(0)\n    x = np.random.uniform(size=(N, D))\n    y = np.random.uniform(size=(N, D))\n    w1 = np.random.uniform(size=(D, D))\n    w2 = np.random.uniform(size=(D, D))\n\n    w1_serial, w2_serial = 
train_serial(x, y, w1, w2)\n    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)\n\n    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)\n    np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)\n\n\ndef test_mlp_data_parallel():\n    lr = 0.1\n    n_epoch = 1\n\n    def loss_serial(x, y, w1, w2):\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        return ((x - y) ** 2).mean()\n\n    def step_serial(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_serial(x, y, w1, w2):\n        for i in range(n_epoch):\n            w1, w2 = step_serial(x, y, w1, w2)\n        return w1, w2\n\n    def loss_parallel(x, y, w1, w2):\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        return ((x - y) ** 2).mean()\n\n    @partial(jax.pmap, in_axes=(0, 0, None, None), out_axes=(None, None),\n             axis_name='data_parallel')\n    def step_parallel(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)\n        g_w1 = jax.lax.pmean(g_w1, axis_name='data_parallel')\n        g_w2 = jax.lax.pmean(g_w2, axis_name='data_parallel')\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_parallel(x, y, w1, w2):\n        data_parallel = len(jax.devices())\n\n        x = split(x, 0, data_parallel)\n        y = split(y, 0, data_parallel)\n\n        for i in range(n_epoch):\n            w1, w2 = step_parallel(x, y, w1, w2)\n\n        return w1, w2\n\n    N = 8\n    D = 128\n\n    np.random.seed(0)\n    x = np.random.uniform(size=(N, D))\n    y = np.random.uniform(size=(N, D))\n    w1 = np.random.uniform(size=(D, D))\n    w2 = np.random.uniform(size=(D, D))\n\n    w1_serial, w2_serial = train_serial(x, y, w1, w2)\n    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)\n\n    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)\n    
np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)\n\n\ndef test_mlp_data_model_parallel():\n    lr = 0.1\n    n_epoch = 1\n\n    def loss_serial(x, y, w1, w2):\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        return ((x - y) ** 2).mean()\n\n    def step_serial(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_serial, argnums=(2, 3))(x, y, w1, w2)\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_serial(x, y, w1, w2):\n        for i in range(n_epoch):\n            w1, w2 = step_serial(x, y, w1, w2)\n        return w1, w2\n\n    def loss_parallel(x, y, w1, w2):\n        x = f_operator(x, axis_name='model_parallel')\n        x = x @ w1\n        x = jax.nn.relu(x)\n        x = x @ w2\n        x = g_operator(x, axis_name='model_parallel')\n        return ((x - y) ** 2).mean()\n\n    @partial(jax.pmap, in_axes=(None, None, 1, 0), out_axes=(1, 0),\n             axis_name='model_parallel')\n    def step_model_parallel(x, y, w1, w2):\n        g_w1, g_w2 = jax.grad(loss_parallel, argnums=(2, 3))(x, y, w1, w2)\n        return g_w1, g_w2\n\n    @partial(jax.pmap, in_axes=(0, 0, None, None), out_axes=(None, None),\n             axis_name='data_parallel')\n    def step_data_parallel(x, y, w1, w2):\n        g_w1, g_w2 = step_model_parallel(x, y, w1, w2)\n        g_w1 = jax.lax.pmean(g_w1, axis_name='data_parallel')\n        g_w2 = jax.lax.pmean(g_w2, axis_name='data_parallel')\n        return w1 - lr * g_w1, w2 - lr * g_w2\n\n    def train_parallel(x, y, w1, w2):\n        model_parallel = 2\n        data_parallel = len(jax.devices()) // model_parallel\n\n        x = split(x, 0, data_parallel)\n        y = split(y, 0, data_parallel)\n        w1 = split(w1, 1, model_parallel)\n        w2 = split(w2, 0, model_parallel)\n\n        for i in range(n_epoch):\n            w1, w2 = step_data_parallel(x, y, w1, w2)\n\n        return unsplit(w1, 1), unsplit(w2, 0)\n\n    N = 8\n    D = 128\n\n    np.random.seed(0)\n    x = 
np.random.uniform(size=(N, D))\n    y = np.random.uniform(size=(N, D))\n    w1 = np.random.uniform(size=(D, D))\n    w2 = np.random.uniform(size=(D, D))\n\n    w1_serial, w2_serial = train_serial(x, y, w1, w2)\n    w1_parallel, w2_parallel = train_parallel(x, y, w1, w2)\n\n    np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-4)\n    np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-4)\n\n\nif __name__ == \"__main__\":\n    test_mlp_model_parallel()\n    test_mlp_data_parallel()\n    test_mlp_data_model_parallel()\n\n"
  },
  {
    "path": "playground/jax_basic/test_memory_allocator.py",
    "content": "\nimport os\nimport jax\nfrom jax import numpy as jnp\n\ndef run_cmd(x):\n    os.system(x)\n\ndef test_platform_allocator():\n    os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n    #os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n\n    a = jnp.ones(1 << 30)\n\n    run_cmd(\"nvidia-smi\")\n\n    a = None\n\n    run_cmd(\"nvidia-smi\")\n\n\nif __name__ == \"__main__\":\n    test_platform_allocator()\n\n"
  },
  {
    "path": "playground/jax_basic/test_mixed_precision.py",
    "content": "from flax import optim, linen as nn\nimport jax\nfrom jax import numpy as jnp\n\nimport alpa\nfrom alpa.model.bert_model import FlaxBertLayer, BertConfig\n\n\ndef inspect_params(optimizer):\n    \"\"\"For debug usage.\"\"\"\n    print(jax.tree_util.tree_map(lambda x: (x.shape, x.dtype), optimizer.target))\n\n\ndef test_mlp():\n    batch_size = 16\n    hidden_size = 128\n\n    class Model(nn.Module):\n        @nn.compact\n        def __call__(self, x):\n            x = nn.Dense(features=hidden_size, dtype=jnp.float16)(x)\n            x = nn.relu(x)\n            x = nn.Dense(features=hidden_size, dtype=jnp.float16)(x)\n            return x\n\n    @alpa.parallelize\n    def train_step(optimizer, batch, apply_fn):\n        def loss_func(params):\n            out = apply_fn(params, batch[\"x\"])\n            return jnp.mean((out - batch[\"y\"]) ** 2, dtype=jnp.float16) * 0.1234\n\n        grad = jax.grad(loss_func)(optimizer.target)\n        new_optimizer = optimizer.apply_gradient(grad)\n        return new_optimizer\n\n    x = jnp.ones((batch_size, hidden_size), dtype=jnp.float16)\n    y = jnp.ones((batch_size, hidden_size), dtype=jnp.float16)\n\n    # Init model and optimizer\n    model = Model()\n    rngkey = jax.random.PRNGKey(0)\n    params = model.init(rngkey, x)\n    optimizer = optim.GradientDescent(1e-2).create(params)\n\n    # JIT compile\n    optimizer = train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\n\n\ndef test_bert_layer():\n    batch_size = 64\n    seq_len = 64\n    hidden_size = 768\n\n    hidden_states = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float16)\n    attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n    label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float16)\n\n    # Init model and optimizer\n    model = FlaxBertLayer(BertConfig(\n        hidden_size=hidden_size,\n    ), dtype=jnp.float16)\n    rngkey = jax.random.PRNGKey(0)\n    params = model.init(rngkey, hidden_states, 
attention_mask)\n    optimizer = optim.GradientDescent(1e-2).create(params)\n\n    @alpa.parallelize\n    def train_step(optimizer, batch):\n        def loss_func(params):\n            rngs = {\"dropout\": batch[\"rng\"]}\n            out = model.apply(params,\n                              batch[\"hidden_states\"],\n                              batch[\"attention_mask\"],\n                              rngs=rngs)[0]\n            return jnp.mean((out - batch[\"label\"]) ** 2)\n\n        grad = jax.grad(loss_func)(optimizer.target)\n        new_optimizer = optimizer.apply_gradient(grad)\n        return new_optimizer\n\n    # JIT compile\n    optimizer = train_step(optimizer,\n                           {\"hidden_states\": hidden_states,\n                            \"attention_mask\": attention_mask,\n                            \"label\": label,\n                            \"rng\": rngkey})\n    inspect_params(optimizer)\n\n\nif __name__ == \"__main__\":\n    #test_mlp()\n    test_bert_layer()\n\n"
  },
  {
    "path": "playground/jax_basic/test_pjit.py",
    "content": "from functools import partial\n\nimport numpy as np\n\nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\nfrom jax.nn import relu\nfrom jax.experimental import PartitionSpec as P\nfrom jax.experimental.maps import mesh\nfrom jax.experimental.pjit import pjit, with_sharding_constraint\nfrom jax._src.random import _random_bits, threefry_2x32\nimport flax\nfrom flax import linen as nn\n\nfrom util import benchmark_func\n\ndef test_basic1d():\n    @partial(pjit,\n             in_axis_resources=(P('x'), P('x')),\n             out_axis_resources=None)\n    def f(x, y):\n        return x + y\n\n    x = np.ones((8, 8))\n\n    mesh_devices = np.array(jax.devices()[:2])\n    with mesh(mesh_devices, ('x',)):\n        actual = f(x, x + 1)\n\n\ndef test_matmul():\n    @partial(pjit,\n             in_axis_resources=(P('x', None), P('x', None)),\n             out_axis_resources=P('x', None))\n    def f(x, y):\n        return x @ y\n\n    x = np.random.randn(8, 4).astype(np.float32)\n    y = np.random.randn(4, 8).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:2])\n    with mesh(mesh_devices, ('x',)):\n        out = f(x, y)\n\n    np.testing.assert_allclose(out, x @ y, rtol=1e-5)\n\n\ndef test_failed_matmul_case_1():\n    # Case 1: SR = RR x SR\n    @partial(pjit,\n             in_axis_resources=(P(None, None), P('y', None)),\n             out_axis_resources=P('x', None))\n    def f(x, y):\n        return x @ y\n\n    x = np.random.randn(4, 128).astype(np.float32)\n    y = np.random.randn(128, 4).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape((2, 2))\n    with mesh(mesh_devices, ('x', 'y')):\n        out = f(x, y)\n\n\ndef test_failed_matmul_case_2():\n    # Case 2: SR = SR x SR\n    @partial(pjit,\n             in_axis_resources=(P('x', None), P('y', None)),\n             out_axis_resources=P('x', None))\n    def f(x, y):\n        return x @ y\n\n    x = np.random.randn(8, 4).astype(np.float32)\n    y = 
np.random.randn(4, 8).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape((2, 2))\n    with mesh(mesh_devices, ('x', 'y')):\n        out = f(x, y)\n\n    np.testing.assert_allclose(out, x @ y, rtol=1e-5)\n\n\ndef test_reduce_scatter():\n    @partial(pjit,\n             in_axis_resources=(P(None, 'x'), P('x', None)),\n             out_axis_resources=P('x', None))\n    def f(x, y):\n        return x @ y\n\n    x = np.random.randn(8, 4).astype(np.float32)\n    y = np.random.randn(4, 8).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:2])\n    with mesh(mesh_devices, ('x',)):\n        out = f(x, y)\n\n    np.testing.assert_allclose(np.array(out), x @ y, rtol=1e-5)\n\n\ndef split(a, axis):\n    in_axis_resources = [None] * len(a.shape)\n    in_axis_resources[axis] = 'x'\n\n    split_func = pjit(lambda x: x,\n                      in_axis_resources=P(*in_axis_resources),\n                      out_axis_resources=P(*in_axis_resources))\n\n    with mesh(np.array(jax.devices()), ('x',)):\n        a = split_func(a)\n    return a\n\n\ndef test_matmul_speed():\n    N = M = 1024\n    K = 1 << 19\n    n_devices = len(jax.devices())\n\n    x_jnp = jnp.empty((N, K), dtype=np.float32).block_until_ready()\n    y_jnp = jnp.empty((K, M), dtype=np.float32).block_until_ready()\n\n    @jax.jit\n    def matmul(x, y):\n        return x @ y\n\n    def serial_func():\n        z = matmul(x_jnp, y_jnp)\n        z.block_until_ready()\n\n    costs = benchmark_func(serial_func) * 1000\n    print(\"Mean Cost: %.3f ms (std: %.3f ms)\" % (np.mean(costs), np.std(costs)))\n\n    x_split = split(x_jnp, 1).block_until_ready()\n    y_split = split(y_jnp, 0).block_until_ready()\n\n    parallel_matmul = pjit(matmul,\n                           in_axis_resources=(P(None, 'x'), P('x', None)),\n                           out_axis_resources=None)\n\n    def parallel_func():\n        z = parallel_matmul(x_split, y_split)\n        z.block_until_ready()\n\n    with 
mesh(np.array(jax.devices()), ('x',)):\n        costs = benchmark_func(parallel_func) * 1000\n    print(\"Mean Cost: %.3f ms (std: %.3f ms)\" % (np.mean(costs), np.std(costs)))\n\n\ndef test_dict_arg():\n    @partial(pjit,\n             in_axis_resources=None,\n             out_axis_resources=None)\n    def f(inputs):\n        x = inputs['x']\n        y = inputs['y']\n        return x @ y\n\n    x = np.random.randn(8, 4).astype(np.float32)\n    y = np.random.randn(4, 8).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:2])\n    with mesh(mesh_devices, ('x',)):\n        out = f({\"x\": x, \"y\": y})\n\n    np.testing.assert_allclose(out, x @ y, rtol=1e-5)\n\n\ndef test_mlp_forward():\n    def loss_func(batch, weights):\n        x, y = batch\n        w1, w2 = weights\n\n        x = x @ w1\n        x = relu(x)\n        x = with_sharding_constraint(x, P('data_parallel', 'model_parallel'))\n        x = x @ w2\n        loss = x\n        #x = relu(x)\n        #loss = jnp.mean((x - y) ** 2)\n        return loss\n\n    loss_func_parallel = pjit(\n        loss_func,\n        in_axis_resources=((P('data_parallel', None), P('data_parallel', None)),\n                           (P(None, 'model_parallel'), P('model_parallel', None))),\n        out_axis_resources=None,\n    )\n\n    N = 8\n    D = 128\n\n    np.random.seed(1)\n    x = np.random.uniform(size=(N, D))\n    y = np.random.uniform(size=(N, D))\n    w1 = np.random.uniform(size=(D, D))\n    w2 = np.random.uniform(size=(D, D))\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)\n    with mesh(mesh_devices, ('data_parallel', 'model_parallel')):\n        loss_parallel = loss_func_parallel((x, y), (w1, w2))\n\n    #loss_serial = loss_func((x, y), (w1, w2))\n    #np.testing.assert_allclose(loss_serial, loss_parallel, rtol=1e-5)\n\n\ndef test_mlp_grad():\n    def loss_func(batch, weights):\n        x, y = batch\n        w1, w2 = weights\n\n        x = x @ w1\n        x = with_sharding_constraint(x, 
P('data_parallel', 'model_parallel'))\n        x = x @ w2\n        loss = jnp.mean((x - y) ** 2)\n        return loss\n\n    def step_serial(batch, weights):\n        gradients = jax.grad(loss_func, argnums=1)(batch, weights)\n        return tuple(w - g for w, g in zip(weights, gradients))\n\n    step_parallel = pjit(\n        step_serial,\n        in_axis_resources=((P('data_parallel', None), P('data_parallel', None)),\n                           (P(None, 'model_parallel'), P('model_parallel', None))),\n        out_axis_resources=((P(None, 'model_parallel'), P('model_parallel', None))),\n    )\n\n    step_serail = jax.jit(step_serial)\n\n    lr = 1\n    N = 256\n    D = 8192\n\n    np.random.seed(1)\n    x = np.random.uniform(size=(N, D))\n    y = np.random.uniform(size=(N, D))\n    w1 = np.random.uniform(size=(D, D))\n    w2 = np.random.uniform(size=(D, D))\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)\n    with mesh(mesh_devices, ('data_parallel', 'model_parallel')):\n        w1_parallel, w2_parallel = step_parallel((x, y), (w1, w2))\n\n    #w1_serial, w2_serial = step_serial((x, y), (w1, w2))\n    #np.testing.assert_allclose(w1_serial, w1_parallel, rtol=1e-5)\n    #np.testing.assert_allclose(w2_serial, w2_parallel, rtol=1e-5)\n\n\ndef test_random_bits():\n    @partial(pjit,\n             in_axis_resources=(P('x'), None),\n             out_axis_resources=P('x'))\n    def func(inputs, key):\n      random_uniform = lax.rng_uniform(0.0, 1.0, inputs.shape)\n      ret = inputs * random_uniform\n      return ret\n\n    inputs = jnp.ones((4096,))\n    rngkey = jax.random.PRNGKey(0)\n\n    mesh_devices = np.array(jax.devices()[:4])\n    with mesh(mesh_devices, ('x',)):\n        actual = func(inputs, rngkey)\n        print(actual)\n        actual = func(inputs, rngkey)\n        print(actual)\n\n\n# Monkey patch random generator to use stateful random generator.\n# This can simplify the computational graph\ndef fast_uniform(key, shape, dtype, minval=0.0, 
maxval=1.0):\n    shape = jax.core.as_named_shape(shape)\n    return lax.rng_uniform(minval, maxval, shape.positional)\n\ndef remove_fold_in(key, data):\n    return key\n\njax._src.random.uniform = fast_uniform\njax.random.uniform = fast_uniform\njax._src.random.fold_in = remove_fold_in\njax.random.fold_in = remove_fold_in\n\n\ndef test_dropout():\n    class Model(nn.Module):\n        @nn.compact\n        def __call__(self, x):\n            x = nn.Dropout(0.1, deterministic=False)(x)\n            return x\n\n    model = Model()\n\n    @partial(pjit,\n             in_axis_resources=(P('x', 'y', None), None),\n             out_axis_resources=P('x', 'y', None))\n    def func(inputs, key):\n      ret = model.apply({}, inputs, rngs={\"dropout\": key})\n      return ret\n\n    inputs = jnp.ones((512, 512, 16))\n    rngkey = jax.random.PRNGKey(0)\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)\n    with mesh(mesh_devices, ('x', 'y')):\n        actual = func(inputs, rngkey)\n        #print(actual)\n\n\ndef test_embedding():\n    vocab_size = 8192\n    hidden_size = 768\n    batch_size = 4\n    seq_len = 128\n\n    @partial(pjit,\n             in_axis_resources=(P(None, 'y'), P('x', None)),\n             out_axis_resources=P('x', None, 'y'))\n    def func(embedding, inputs):\n      ret = jnp.take(embedding, inputs, axis=0)\n      return ret\n\n    embedding = jnp.ones((vocab_size, hidden_size), dtype=np.float32)\n    inputs = jnp.ones((batch_size, seq_len), dtype=np.int32)\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)\n    with mesh(mesh_devices, ('x', 'y')):\n        actual = func(embedding, inputs)\n\n\ndef test_all_to_all():\n    @partial(pjit,\n             in_axis_resources=P('x', 'y', None),\n             out_axis_resources=P('x', None, 'y'))\n    def f(x):\n        return x\n\n    x = np.random.randn(2, 2, 4).astype(np.float32)\n\n    mesh_devices = np.array(jax.devices()[:4]).reshape(2, 2)\n    with mesh(mesh_devices, ('x', 'y')):\n 
       out = f(x)\n\n\nif __name__ == \"__main__\":\n    #test_basic1d()\n    #test_matmul()\n    #test_failed_matmul_case_1()\n    #test_failed_matmul_case_2()\n    #test_reduce_scatter()\n    #test_matmul_speed()\n    #test_dict_arg()\n\n    #test_mlp_forward()\n    #test_mlp_grad()\n\n    #test_random_bits()\n    #test_dropout()\n\n    #test_embedding()\n\n    test_all_to_all()\n\n"
  },
  {
    "path": "playground/jax_basic/test_pmap.py",
    "content": "from functools import partial \nimport jax\nfrom jax import lax\nimport jax.numpy as jnp\n\n\ndef debug_pmap():\n    @jax.pmap\n    def func(x, w):\n        return x @ w\n\n    y = func(jnp.ones((2, 4)), jnp.ones((2, 4)))\n    print(y, type(y))\n\n\ndef test_nested_pmap():\n    @partial(jax.pmap, axis_name='a0', in_axes=(0, None), out_axes=0)\n    def add(a, b):\n        # a.shape = (32, 64)\n        # b.shape = (64, 2, 32)\n        @partial(jax.pmap, axis_name='a1', in_axes=(None, 1), out_axes=1)\n        def add_inner(x, y):\n            # x.shape = (32, 64)\n            # y.shape = (64, 32)\n            return x @ y\n\n        # ret.shape = (32, 2, 32)\n        ret = add_inner(a, b)\n        return ret\n\n    a = jnp.ones((2, 32, 64))\n    b = jnp.ones((64, 2, 32))\n\n    #jaxpr = jax.make_jaxpr(add)(a, b)\n    #print(jaxpr)\n    #print(jaxpr.jaxpr.outvars[0].aval.shape)\n\n    c = add(a, b)\n    print(c)\n\n\ndef test_allreduce_sum():\n    @partial(jax.pmap, axis_name='i')\n    def normalize(x):\n        return x / lax.psum(x, 'i')\n\n    print(normalize(jnp.arange(2)))\n\n\nif __name__ == \"__main__\":\n    #debug_pmap()\n    #test_nested_pmap()\n\n    test_allreduce_sum()\n\n"
  },
  {
    "path": "playground/jax_basic/test_scan.py",
    "content": "from functools import partial\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import linen as nn\nfrom flax import optim\n\nbatch_size = 32\nhidden_size = 128\n\nclass Layer(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        return\n\nclass Model(nn.Module):\n    def __call__(self, x):\n        cell = nn.scan(\n            nn.Dense,\n            variable_broadcast=\"params\",\n            in_axes=1,\n            out_axes=1,\n            split_rngs={\"params\": False},\n        )\n\n@partial(jax.jit, static_argnums=(2,))\ndef train_step(optimizer, batch, apply_fn):\n    def loss_func(params):\n        out = apply_fn(params, batch[\"x\"])\n        return jnp.mean((out - batch[\"y\"]) ** 2)\n\n    grad = jax.grad(loss_func)(optimizer.target)\n    new_optimizer = optimizer.apply_gradient(grad)\n    return new_optimizer\n\nx = jnp.ones((batch_size, hidden_size))\ny = jnp.ones((batch_size, hidden_size))\n\n# Init model and optimizer\nmodel = Model()\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x)\noptimizer = optim.GradientDescent(1e-2).create(params)\n\n# JIT compile\noptimizer = train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\n\n"
  },
  {
    "path": "playground/jax_basic/test_sharding_spec.py",
    "content": "from functools import partial\nimport pickle\n\nimport numpy as np\n\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import ShardingSpec, Chunked, NoSharding, Replicated, ShardedAxis\n\n\ndef test_order():\n    a = pxla.ShardingSpec(sharding=(Chunked([2]), NoSharding()),\n                          mesh_mapping=(ShardedAxis(0), Replicated(2)))\n\n    print(\"--\")\n    print(a.indices((4, 4)).flatten()[0])\n    print(a.indices((4, 4)).flatten()[1])\n\n    b = pxla.ShardingSpec(sharding=(Chunked([2]), NoSharding()),\n                          mesh_mapping=(Replicated(2), ShardedAxis(0)))\n\n    print(\"--\")\n    print(b.indices((4, 4)).flatten()[0])\n    print(b.indices((4, 4)).flatten()[1])\n\n\ndef test_equivalent():\n    a = pxla.ShardingSpec(sharding=(Chunked([4]), Chunked([1])),\n                          mesh_mapping=(ShardedAxis(0), ShardedAxis(1)))\n\n    print(\"--\")\n    print(a.indices((4, 4)).flatten()[0])\n    print(a.indices((4, 4)).flatten()[1])\n    print(a.indices((4, 4)).flatten()[2])\n    print(a.indices((4, 4)).flatten()[3])\n\n    a = pxla.ShardingSpec(sharding=(Chunked([4]), NoSharding()),\n                          mesh_mapping=(Replicated(1), ShardedAxis(0)))\n\n    print(\"--\")\n    print(a.indices((4, 4)).flatten()[0])\n    print(a.indices((4, 4)).flatten()[1])\n    print(a.indices((4, 4)).flatten()[2])\n    print(a.indices((4, 4)).flatten()[3])\n\n\ndef test_multiple_chunks():\n    a = pxla.ShardingSpec(sharding=(Chunked([2, 2]),),\n                          mesh_mapping=(ShardedAxis(1), ShardedAxis(0)))\n\n    print(a.indices((4,)).flatten()[0])\n    print(a.indices((4,)).flatten()[1])\n    print(a.indices((4,)).flatten()[2])\n    print(a.indices((4,)).flatten()[3])\n\n\ndef test_pickle():\n    a = pxla.ShardingSpec(sharding=(Chunked([2, 2]),),\n                          mesh_mapping=(ShardedAxis(1), ShardedAxis(0)))\n\n    pickle.dump(a, open(\"tmp.pkl\", \"wb\"))\n\n    b = pickle.load(open(\"tmp.pkl\", 
\"rb\"))\n\n    assert a == b\n\n\ndef sharding_spec_getstate(self):\n    sharding = []\n    for x in self.sharding:\n        if isinstance(x, pxla.NoSharding):\n            sharding.append((0,))\n        elif isinstance(x, pxla.Chunked):\n            sharding.append((1, x.chunks))\n        elif isinstance(x, pxla.Unstacked):\n            sharding.append((2, x.size))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n    mesh_mapping = []\n    for x in self.mesh_mapping:\n        if isinstance(x, pxla.ShardedAxis):\n            mesh_mapping.append((0, x.axis))\n        elif isinstance(x, pxla.Replicated):\n            mesh_mapping.append((1, x.replicas))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n    return (sharding, mesh_mapping)\n\n\ndef sharding_spec_setstate(self, state_tuple):\n    sharding_encoding, mesh_mapping_encoding = state_tuple\n\n    sharding = []\n    for x in sharding_encoding:\n        if x[0] == 0:\n            sharding.append(pxla.NoSharding())\n        elif x[0] == 1:\n            sharding.append(pxla.Chunked(x[1]))\n        elif x[0] == 2:\n            sharding.append(pxla.Unstacked(x[1]))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n\n    mesh_mapping = []\n    for x in mesh_mapping_encoding:\n        if x[0] == 0:\n            mesh_mapping.append(pxla.ShardedAxis(x[1]))\n        elif x[0] == 1:\n            mesh_mapping.append(pxla.Replicated(x[1]))\n        else:\n            raise ValueError(f\"Invalid sharding: {x}\")\n\n    self.__init__(\n        sharding=sharding,\n        mesh_mapping=mesh_mapping,\n    )\n\n\nsetattr(pxla.ShardingSpec, \"__getstate__\", sharding_spec_getstate)\nsetattr(pxla.ShardingSpec, \"__setstate__\", sharding_spec_setstate)\n\nif __name__ == \"__main__\":\n    #test_order()\n    #test_equivalent()\n    #test_multiple_chunks()\n    test_pickle()\n"
  },
  {
    "path": "playground/jax_basic/test_tuple_args.py",
    "content": "import jax\nfrom jax import numpy as jnp\n\n\n@jax.pmap\ndef many_args(*args):\n    x = 0\n    for i in range(len(args)):\n        x += args[i]\n    return x\n\nN = 110\n\nargs = [\n  jnp.ones((4, 10)) for _ in range(N)\n]\n\nout = many_args(*args)\nprint(out)\n\n"
  },
  {
    "path": "playground/jax_basic/test_while.py",
    "content": "from functools import partial\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import linen as nn\nfrom flax import optim\n\nbatch_size = 32\nhidden_size = 128\n\nclass Model(nn.Module):\n    def setup(self):\n        self.weight = self.param(\"weight\",\n                                 jax.nn.initializers.zeros, (hidden_size, hidden_size))\n\n    def __call__(self, x):\n        def cond_func(args):\n            counter = args[0]\n            return counter < 5\n\n        def body_func(args):\n            counter, x = args \n            return [counter + 1, x @ self.weight]\n\n        return jax.lax.while_loop(cond_func, body_func, [0, x])[1]\n\n@partial(jax.jit, static_argnums=(2,))\ndef train_step(optimizer, batch, apply_fn):\n    def loss_func(params):\n        out = apply_fn(params, batch[\"x\"])\n        return jnp.mean((out - batch[\"y\"]) ** 2)\n\n    grad = jax.grad(loss_func)(optimizer.target)\n    new_optimizer = optimizer.apply_gradient(grad)\n    return new_optimizer\n\nx = jnp.ones((batch_size, hidden_size))\ny = jnp.ones((batch_size, hidden_size))\n\n# Init model and optimizer\nmodel = Model()\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x)\noptimizer = optim.GradientDescent(1e-2).create(params)\n\n# JIT compile\noptimizer = train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\n\n"
  },
  {
    "path": "playground/jax_basic/test_xmap.py",
    "content": "from functools import partial \n\nimport numpy as np\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.experimental.maps import Mesh, mesh, xmap\nfrom jax.lax import pdot, pmean, psum\nfrom jax.nn import relu\n\n\ndef test_dist_matmul():\n    func = xmap(\n        jnp.vdot,\n        in_axes=({0: 'left'}, {1: 'right'}),\n        out_axes=['left', 'right', ...],\n        axis_resources={'left': 'x', 'right': 'y'})\n\n    devices = np.array(jax.devices())[:4].reshape((2, 2))\n    with mesh(devices, ('x', 'y')):  # declare a 2D mesh with axes 'x' and 'y'\n        x = jnp.arange(20).reshape((4, 5))\n        out = func(x, x.T)\n\n        print(out.sharding_spec)\n\n\ndef test_collective_pdot():\n    def f(x, y):\n        return pdot(x, y, 'k')\n\n    x = jnp.ones((3, 4))\n    y = jnp.ones((4, 5))\n    z = jax.pmap(f, axis_name='k', in_axes=(1, 0), out_axes=None)(x, y)\n\n    print(z.sharding_spec)\n\n\ndef test_mlp():\n    def loss_func(x, y, w1, w2):\n        x = relu(pdot(x, w1, 'model'))\n        x = relu(pdot(x, w2, 'hidden'))\n        loss = (x - y) ** 2\n        loss = psum(loss, 'model')\n        loss = pmean(loss, 'batch')\n        return loss\n\n    serial_step = xmap(\n        loss_func,\n        in_axes=({0: 'batch', 1: 'model'},\n                 {0: 'batch', 1: 'model'},\n                 {0: 'model', 1: 'hidden'},\n                 {0: 'model', 1: 'hidden'},),\n        out_axes={})\n\n    parallel_step = xmap(\n        loss_func,\n        in_axes=({0: 'batch', 1: 'model'},\n                 {0: 'batch', 1: 'model'},\n                 {0: 'model', 1: 'hidden'},\n                 {0: 'model', 1: 'hidden'},),\n        out_axes={},\n        axis_resources={'batch': 'data_parallel',\n                        'hidden': 'model_parallel'})\n\n    x  = jnp.ones((8, 256))\n    y  = jnp.ones((8, 256))\n    w1 = jnp.ones((256, 1024))\n    w2 = jnp.ones((256, 1024))\n\n    serial_out = serial_step(x, y, w1, w2)\n\n    data_parallel = 2\n    model_parallel 
= 2\n    devices = np.array(jax.devices())[:4].reshape((data_parallel, model_parallel))\n    with mesh(devices, ('data_parallel', 'model_parallel')):\n        parallel_out = parallel_step(x, y, w1, w2)\n\n        print(parallel_out.sharding_spec)\n\n\ndef test_grad():\n    def loss(x, y):\n        loss = (x - y) ** 2\n        loss = pmean(loss, 'batch')\n        return loss\n\n\n    loss_parallel = xmap(\n        loss,\n        in_axes=({0: 'batch'},\n                 {0: 'batch'},),\n        out_axes={},\n        axis_resources={'batch': 'i'})\n\n    x = jnp.ones((16,))\n    y = jnp.ones((16,))\n\n    devices = np.array(jax.devices()[:4])\n    with mesh(devices, ('i',)):\n        # out = loss_parallel(x, y)\n        # print(out.sharding_spec)\n\n        grad = jax.grad(loss_parallel)(x, y)\n\n\nif __name__ == \"__main__\":\n    test_dist_matmul()\n    #test_collective_pdot()\n    #test_mlp()\n    #test_grad()\n\n"
  },
  {
    "path": "playground/jax_basic/util.py",
    "content": "import time\n\nimport numpy as np\n\ndef benchmark_func(func, warmup=1, repeat=3):\n    for i in range(warmup):\n        func()\n\n    costs = []\n    for i in range(repeat):\n        tic = time.time()\n        func()\n        costs.append(time.time() - tic)\n\n    return np.array(costs)\n\n"
  },
  {
    "path": "playground/other/input_pipeline.py",
    "content": "# Copyright 2022 The Flax Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\"\"\"ImageNet input pipeline.\n\"\"\"\n\nimport jax\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\n\n\nIMAGE_SIZE = 224\nCROP_PADDING = 32\nMEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255]\nSTDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255]\n\n\ndef distorted_bounding_box_crop(image_bytes,\n                                bbox,\n                                min_object_covered=0.1,\n                                aspect_ratio_range=(0.75, 1.33),\n                                area_range=(0.05, 1.0),\n                                max_attempts=100):\n  \"\"\"Generates cropped_image using one of the bboxes randomly distorted.\n\n  See `tf.image.sample_distorted_bounding_box` for more documentation.\n\n  Args:\n    image_bytes: `Tensor` of binary image data.\n    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`\n        where each coordinate is [0, 1) and the coordinates are arranged\n        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole\n        image.\n    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped\n        area of the image must contain at least this fraction of any bounding\n        box supplied.\n    aspect_ratio_range: An optional list of `float`s. 
The cropped area of the\n        image must have an aspect ratio = width / height within this range.\n    area_range: An optional list of `float`s. The cropped area of the image\n        must contain a fraction of the supplied image within this range.\n    max_attempts: An optional `int`. Number of attempts at generating a cropped\n        region of the image of the specified constraints. After `max_attempts`\n        failures, return the entire image.\n  Returns:\n    cropped image `Tensor`\n  \"\"\"\n  shape = tf.io.extract_jpeg_shape(image_bytes)\n  sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(\n      shape,\n      bounding_boxes=bbox,\n      min_object_covered=min_object_covered,\n      aspect_ratio_range=aspect_ratio_range,\n      area_range=area_range,\n      max_attempts=max_attempts,\n      use_image_if_no_bounding_boxes=True)\n  bbox_begin, bbox_size, _ = sample_distorted_bounding_box\n\n  # Crop the image to the specified bounding box.\n  offset_y, offset_x, _ = tf.unstack(bbox_begin)\n  target_height, target_width, _ = tf.unstack(bbox_size)\n  crop_window = tf.stack([offset_y, offset_x, target_height, target_width])\n  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)\n\n  return image\n\n\ndef _resize(image, image_size):\n  return tf.image.resize([image], [image_size, image_size],\n                         method=tf.image.ResizeMethod.BICUBIC)[0]\n\n\ndef _at_least_x_are_equal(a, b, x):\n  \"\"\"At least `x` of `a` and `b` `Tensors` are equal.\"\"\"\n  match = tf.equal(a, b)\n  match = tf.cast(match, tf.int32)\n  return tf.greater_equal(tf.reduce_sum(match), x)\n\n\ndef _decode_and_random_crop(image_bytes, image_size):\n  \"\"\"Make a random crop of image_size.\"\"\"\n  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])\n  image = distorted_bounding_box_crop(\n      image_bytes,\n      bbox,\n      min_object_covered=0.1,\n      aspect_ratio_range=(3. / 4, 4. 
/ 3.),\n      area_range=(0.08, 1.0),\n      max_attempts=10)\n  original_shape = tf.io.extract_jpeg_shape(image_bytes)\n  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)\n\n  image = tf.cond(\n      bad,\n      lambda: _decode_and_center_crop(image_bytes, image_size),\n      lambda: _resize(image, image_size))\n\n  return image\n\n\ndef _decode_and_center_crop(image_bytes, image_size):\n  \"\"\"Crops to center of image with padding then scales image_size.\"\"\"\n  shape = tf.io.extract_jpeg_shape(image_bytes)\n  image_height = shape[0]\n  image_width = shape[1]\n\n  padded_center_crop_size = tf.cast(\n      ((image_size / (image_size + CROP_PADDING)) *\n       tf.cast(tf.minimum(image_height, image_width), tf.float32)),\n      tf.int32)\n\n  offset_height = ((image_height - padded_center_crop_size) + 1) // 2\n  offset_width = ((image_width - padded_center_crop_size) + 1) // 2\n  crop_window = tf.stack([offset_height, offset_width,\n                          padded_center_crop_size, padded_center_crop_size])\n  image = tf.io.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)\n  image = _resize(image, image_size)\n\n  return image\n\n\ndef normalize_image(image):\n  image -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)\n  image /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)\n  return image\n\n\ndef preprocess_for_train(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):\n  \"\"\"Preprocesses the given image for training.\n\n  Args:\n    image_bytes: `Tensor` representing an image binary of arbitrary size.\n    dtype: data type of the image.\n    image_size: image size.\n\n  Returns:\n    A preprocessed image `Tensor`.\n  \"\"\"\n  image = _decode_and_random_crop(image_bytes, image_size)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  image = tf.image.random_flip_left_right(image)\n  image = normalize_image(image)\n  image = tf.image.convert_image_dtype(image, dtype=dtype)\n  return image\n\n\ndef 
preprocess_for_eval(image_bytes, dtype=tf.float32, image_size=IMAGE_SIZE):\n  \"\"\"Preprocesses the given image for evaluation.\n\n  Args:\n    image_bytes: `Tensor` representing an image binary of arbitrary size.\n    dtype: data type of the image.\n    image_size: image size.\n\n  Returns:\n    A preprocessed image `Tensor`.\n  \"\"\"\n  image = _decode_and_center_crop(image_bytes, image_size)\n  image = tf.reshape(image, [image_size, image_size, 3])\n  image = normalize_image(image)\n  image = tf.image.convert_image_dtype(image, dtype=dtype)\n  return image\n\n\ndef create_split(dataset_builder, batch_size, train, dtype=tf.float32,\n                 image_size=IMAGE_SIZE, cache=False):\n  \"\"\"Creates a split from the ImageNet dataset using TensorFlow Datasets.\n\n  Args:\n    dataset_builder: TFDS dataset builder for ImageNet.\n    batch_size: the batch size returned by the data pipeline.\n    train: Whether to load the train or evaluation split.\n    dtype: data type of the image.\n    image_size: The target size of the images.\n    cache: Whether to cache the dataset.\n  Returns:\n    A `tf.data.Dataset`.\n  \"\"\"\n  if train:\n    train_examples = dataset_builder.info.splits['train'].num_examples\n    split_size = train_examples // jax.process_count()\n    start = jax.process_index() * split_size\n    split = 'train[{}:{}]'.format(start, start + split_size)\n  else:\n    validate_examples = dataset_builder.info.splits['validation'].num_examples\n    split_size = validate_examples // jax.process_count()\n    start = jax.process_index() * split_size\n    split = 'validation[{}:{}]'.format(start, start + split_size)\n\n  def decode_example(example):\n    if train:\n      image = preprocess_for_train(example['image'], dtype, image_size)\n    else:\n      image = preprocess_for_eval(example['image'], dtype, image_size)\n    return {'image': image, 'label': example['label']}\n\n  ds = dataset_builder.as_dataset(split=split, decoders={\n      'image': 
tfds.decode.SkipDecoding(),\n  })\n  options = tf.data.Options()\n  options.experimental_threading.private_threadpool_size = 48\n  ds = ds.with_options(options)\n\n  if cache:\n    ds = ds.cache()\n\n  if train:\n    ds = ds.repeat()\n    ds = ds.shuffle(16 * batch_size, seed=0)\n\n  ds = ds.map(decode_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)\n  ds = ds.batch(batch_size, drop_remainder=True)\n\n  if not train:\n    ds = ds.repeat()\n\n  ds = ds.prefetch(10)\n\n  return ds\n"
  },
  {
    "path": "playground/other/test_cupy_partial_transfer.py",
    "content": "import time\n\nimport cupy as cp\nfrom cupy.cuda import nccl\nimport numpy as np\nimport ray\n\n\n# tensor = cp.random.normal(size=[2, 1025, 1536])\n# print(tensor)\n#\n# row_major = True\n# print(tensor.data.ptr + 2)\n# print(tensor.data.ptr + 2)\n\nMB = 1 << 20\nGB = 1 << 30\n\n\ndef do_send_recv(comm, buf, is_sender):\n    if is_sender:\n        comm.send(buf[2,:].data.ptr, buf.size / 2, nccl.NCCL_FLOAT32,\n                  1, cp.cuda.Stream.null.ptr)\n    else:\n        comm.recv(buf[2,:].data.ptr, buf.size / 2, nccl.NCCL_FLOAT32,\n                  0, cp.cuda.Stream.null.ptr)\n\n\n@ray.remote(num_gpus=1)\nclass GpuHost:\n    def __init__(self, rank, world_size, nccl_uuid_list):\n        self.rank = rank\n        self.world_size = world_size\n        self.nccl_uuid_list = nccl_uuid_list\n        self.ct = 0\n\n    def init_communicator(self, groups):\n        comm = None\n        for group in groups:\n            nccl_uuid = self.nccl_uuid_list[self.ct]\n            self.ct += 1\n            for device_id in group:\n                if self.rank == device_id:\n                    assert comm is None\n                    comm = cp.cuda.nccl.NcclCommunicator(\n                        len(group), nccl_uuid, group.index(self.rank))\n\n\n        cp.cuda.Device(0).synchronize()\n        return comm\n\n    def profile_send_recv(self, size, dtype, from_rank, to_rank):\n        groups = [[from_rank, to_rank]]\n        comm = self.init_communicator(groups)\n        if comm is None:\n            return\n\n        if self.rank == from_rank:\n            buf = cp.zeros((size, size), dtype)\n        else:\n            buf = cp.ones((size, size), dtype)\n\n        if self.rank == to_rank:\n            print(\"Before send/recv: \", buf)\n        do_send_recv(comm, buf, self.rank == from_rank)\n\n        number = min(max(10, int((1 << 30) / (size * dtype().nbytes))), 1 << 13)\n        cp.cuda.Device(0).synchronize()\n        tic = time.time()\n        for i in 
range(number):\n            do_send_recv(comm, buf, self.rank == from_rank)\n        cp.cuda.Device(0).synchronize()\n        toc = time.time()\n\n        if self.rank == from_rank:\n            time_cost = (toc - tic) / number\n            array_size = size * dtype().nbytes\n            communication_size = array_size\n            bandwidth = communication_size / time_cost\n            print(f\"SendRecv: {groups}\\tBytes: {array_size / GB:.5f} GB\\t\"\n                  f\"Time: {time_cost:.5f} s\\tBandwidth: {bandwidth / (1<<30):.2f} GB/s\")\n        if self.rank == to_rank:\n            print(\"After send/recv: \", buf)\n\n    def profile(self):\n        # All-reduce\n\n        # Send-recv\n        # for i in range(5, 6):\n        self.profile_send_recv(1 << 3, cp.float32, 0, 1)\n        self.profile_send_recv(1 << 3, cp.float32, 0, self.world_size - 1)\n\n\n    def sync(self):\n        return\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\")\n\n    num_gpus = int(ray.cluster_resources()[\"GPU\"])\n\n    nccl_uuid_list = [cp.cuda.nccl.get_unique_id() for _ in range(500)]\n\n    workers = []\n    for i in range(num_gpus):\n        env_vars = {\n            #\"NCCL_SOCKET_NTHREADS\": \"4\",\n            #\"NCCL_NSOCKS_PERTHREAD\": \"8\",\n            #\"NCCL_ALGO\": \"tree\",\n            #\"NCCL_DEBUG\": \"INFO\",\n        }\n        workers.append(GpuHost.options(runtime_env={\"env_vars\": env_vars}) \\\n                       .remote(i, num_gpus, nccl_uuid_list))\n\n    ray.get([w.profile.remote() for w in workers])\n    ray.get([w.sync.remote() for w in workers])\n\n"
  },
  {
    "path": "playground/other/test_ray_dataloader.py",
    "content": "import ray\nimport jax\n\nimport input_pipeline\n\n@ray.remote\nclass Worker:\n    def __init__(self):\n        self.generator = None\n\n    def register_generator(self, func):\n        self.generator = iter(func())\n\n    def get_next(self):\n        return next(self.generator)\n\n\ndef make_generator():\n    import tensorflow as tf\n    import tensorflow_datasets as tfds\n\n    dataset_builder = tfds.builder('imagenet2012:5.*.*')\n    batch_size = 64\n    image_size = 224\n    dtype = tf.float32\n    train = True\n    cache = True\n\n    ds = input_pipeline.create_split(\n        dataset_builder, batch_size, image_size=image_size, dtype=dtype,\n        train=train, cache=cache)\n    it = map(lambda xs: jax.tree_map(lambda x: x._numpy(), xs), ds)\n    return it\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\")\n\n    worker = Worker.remote()\n\n    worker.register_generator.remote(make_generator)\n\n    x = ray.get(worker.get_next.remote())\n    print(x.keys())\n    print(x['image'].shape)\n"
  },
  {
    "path": "playground/other/test_ray_put.py",
    "content": "import time\n\nimport jax\nimport ray\nimport numpy as np\n\nMB = 1024**2\nGB = 1024**3\n\n\ndef benchmark_ray(x):\n    array = np.ones((x,), dtype=np.float32)\n    warmup = 0\n    number = 1\n\n    # warm up\n    for i in range(warmup):\n        ray.put(array)\n\n    # benchmark\n    tic = time.time()\n    for i in range(number):\n        ray.put(array)\n    cost = time.time() - tic\n\n    size = np.prod(array.shape) * array.dtype.itemsize\n    bandwidth = size / (cost / number)\n    print(f\"size: {size/MB:.2f} MB, bandwidth: {bandwidth/MB:.2f} MB\")\n\n\ndef benchmark_jax_put(x):\n    batch = np.ones((x,), dtype=np.float32)\n\n    # warm up\n    for i in range(2):\n        tmp = jax.device_put(batch)\n    tmp.block_until_ready()\n\n    # benchmark\n    tic = time.time()\n    y = [None] * 10\n    for i in range(10):\n        y[i] = jax.device_put(batch)\n        #y[i] = None\n        #y[i].block_until_ready()\n    print(f\"size: {x}, time: {time.time() - tic:.2f}\")\n\n\nfor i in [1, 64, 128, 512, 1024]:\n    benchmark_ray(i * MB)\nfor i in [1, 64, 128, 512, 1024]:\n    benchmark_ray(i * MB)\nfor i in [1, 64, 128, 512, 1024]:\n    benchmark_ray(i * MB)\n\n#for i in range(10):\n#    benchmark_jax_put(8192 * 28 * 28 * 1)\n"
  },
  {
    "path": "playground/other/test_remote_call_cost.py",
    "content": "import time\n\nfrom alpa.device_mesh import Mesh\nimport numpy as np\nimport ray\n\nray.init(address=\"auto\")\nworker = ray.remote(num_gpus=1)(Worker).remote()\n\nlatencies = []\nfor i in range(1000):\n    tic = time.time()\n    ray.get(worker.check_alive.remote())\n    latency = time.time() - tic\n    print(f\"{i}, latency: {latency * 1e3:.2f} ms\")\n"
  },
  {
    "path": "playground/other/test_torch_ddp.py",
    "content": "\"\"\"\nUsage:\npython3 -m torch.distributed.launch --nproc_per_node 1 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 11000 test_torch_ddp.py\n\"\"\"\nimport torch\nimport torch.optim as optim\nfrom torch import nn\nfrom torch.nn.parallel.distributed import DistributedDataParallel as torchDDP\n#from torch.nn.parallel import DataParallel as torchDDP\n\nclass Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.net1 = nn.Linear(1 << 10, 1 << 19)\n        self.net2 = nn.Linear(1 << 19, 1)\n\n    def forward(self, x):\n        return self.net2(self.net1(x))\n\n\nGB = 1024 ** 3\n\ndef get_memory_usage(print_info=False):\n    \"\"\"Get accurate gpu memory usage by querying torch runtime\"\"\"\n    rank = torch.distributed.get_rank()\n    device = rank % torch.cuda.device_count()\n    allocated = torch.cuda.memory_allocated(device)\n    reserved = torch.cuda.memory_reserved(device)\n    if print_info:\n        print(\"allocated: %.2f GB\" % (allocated / GB), flush=True)\n        print(\"reserved:  %.2f GB\" % (reserved / GB), flush=True)\n    return allocated\n\ntorch.distributed.init_process_group(backend=\"nccl\", world_size=1)\n\nraw_model = Net().cuda()\n\nprint(\"After init model\", get_memory_usage() / GB)\nmodel = torchDDP(raw_model, device_ids=[0], output_device=0, gradient_as_bucket_view=True)\noptimizer = optim.SGD(model.parameters(), lr=0.001)\n\nprint(\"After torchDDP\", get_memory_usage() / GB)\n\ndata = torch.ones((1, 1<<10)).cuda()\nlabel = torch.ones((1,)).cuda()\n\noptimizer.zero_grad()\nloss = torch.square(model(data) - label).sum()\nloss.backward()\noptimizer.step()\n\nprint(\"After first backward\", get_memory_usage() / GB)\n\noptimizer.zero_grad()\nloss = torch.square(model(data) - label).sum()\nloss.backward()\noptimizer.step()\nprint(\"After second backward\", get_memory_usage() / GB)\n\n"
  },
  {
    "path": "playground/other/test_torch_trace.py",
    "content": "import torch\n\nN = 2\nH = 4\n\nloss_func = torch.nn.MSELoss()\nmodel = torch.nn.Linear(H, H)\n\ndef func(data, target, *params):\n    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)\n\n    y = model(data)\n    loss = loss_func(y, target)\n\n    print(y)\n\n    loss.backward()\n    return loss\n\ndata = torch.ones((N, H))\ntarget = torch.ones((N, H))\n\nmodel_params = tuple(model.parameters())\nfunc(*((data, target,) + model_params))\nmodel_grads = tuple(x.grad for x in model_params)\n\ngraph, output = torch.jit._get_trace_graph(func, (data, target) + model_params + model_grads)\n"
  },
  {
    "path": "playground/pipeline/auto_pipeline_slicing_dp.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import copy\\n\",\n    \"import itertools\\n\",\n    \"import time\\n\",\n    \"import math\\n\",\n    \"import numpy as np\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# simplification\\n\",\n    \"def f(L, maxl, cost, k, B):\\n\",\n    \"    if k == 1:\\n\",\n    \"        return ([L], B*max(0, L-maxl))\\n\",\n    \"    if k == L:\\n\",\n    \"        cost_ = max(1, maxl) * B\\n\",\n    \"        for i in range(k-1):\\n\",\n    \"         #   cost_ += cost[i][i]\\n\",\n    \"            cost_ += cost[i]\\n\",\n    \"        return ([1] * L, cost_)\\n\",\n    \"    \\n\",\n    \"    cost_best = float(\\\"inf\\\")\\n\",\n    \"    S_best = []\\n\",\n    \"    for i in reversed(range(k, L)):\\n\",\n    \"        S, cost_ = f(i, max(L-i, maxl), cost, k-1, B)\\n\",\n    \"        cost_ += max(0, L-i-maxl)*B\\n\",\n    \"        cost_ += cost[i-1]\\n\",\n    \"        if cost_ < cost_best:\\n\",\n    \"            cost_best = cost_\\n\",\n    \"            S.append(L-i)\\n\",\n    \"            S_best = S\\n\",\n    \"    return S_best, cost_best\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"L = 12\\n\",\n    \"k = 8\\n\",\n    \"cost = [2,1,1,3] * 12\\n\",\n    \"f(L, 0, cost, k, 3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def pipe_dp(L, cost_e, cost_c, k, B):\\n\",\n    \"    # Generate all possible max length\\n\",\n    \"    possible = [0]\\n\",\n    \"    \\n\",\n    \"    for i in range(1, L+1):\\n\",\n    \"        ptr = 0\\n\",\n    \"        while ptr + i <= L:\\n\",\n    \"            
possible.append(sum(cost_e[ptr:ptr+i]))\\n\",\n    \"            ptr += 1\\n\",\n    \"    \\n\",\n    \"    possible = sorted(list(set(possible)))\\n\",\n    \"    # print(possible)\\n\",\n    \"    # trace will be a 3D list\\n\",\n    \"    trace = []\\n\",\n    \"    for i in range(L):\\n\",\n    \"        outer = []\\n\",\n    \"        for j in range(k):\\n\",\n    \"            inner = []\\n\",\n    \"            for m in range(len(possible)):\\n\",\n    \"                inner.append(([],np.infty))\\n\",\n    \"            outer.append(inner)\\n\",\n    \"        trace.append(outer)\\n\",\n    \"    \\n\",\n    \"    # i: layer id, starting from 0\\n\",\n    \"    # j: number of cut (=GPU-1)\\n\",\n    \"    for i in range(L):\\n\",\n    \"        for j in range(k):\\n\",\n    \"            for m in range(len(possible)):\\n\",\n    \"                if i+1 <= j: # invalid\\n\",\n    \"                    pass\\n\",\n    \"                else:\\n\",\n    \"                    if j == 0: # base case: 0 cut\\n\",\n    \"                        cur_sum = sum(cost_e[:i+1])\\n\",\n    \"                        assert cur_sum in possible\\n\",\n    \"                        trace[i][j][m] = ([i+1], (B-1) * max(0, cur_sum - possible[m]))\\n\",\n    \"                    else:\\n\",\n    \"                        cost_best = np.infty\\n\",\n    \"                        S_best = []\\n\",\n    \"                        for cut in range(j-1, i):\\n\",\n    \"                            cur_sum = sum(cost_e[cut+1:i+1])\\n\",\n    \"                            assert cur_sum in possible\\n\",\n    \"                            S, cost_ = trace[cut][j-1][possible.index(max(cur_sum, possible[m]))]\\n\",\n    \"                            #print(S, cost_)\\n\",\n    \"                            cost_ += (B-1) * max(0, cur_sum - possible[m])\\n\",\n    \"                            cost_ += cost_c[cut][j-1]\\n\",\n    \"                            if cost_ < 
cost_best:\\n\",\n    \"                                cost_best = cost_\\n\",\n    \"                                S_ = copy.deepcopy(S)\\n\",\n    \"                                S_.append(i-cut)\\n\",\n    \"                                S_best = S_\\n\",\n    \"                        trace[i][j][m] = (S_best, cost_best)\\n\",\n    \"                        \\n\",\n    \"    for i in range(L):\\n\",\n    \"        for j in range(k):\\n\",\n    \"            pass\\n\",\n    \"            #print(i, j, trace[i][j])\\n\",\n    \"    return trace[L-1][k-1][0]\\n\",\n    \"\\n\",\n    \"def brute_force(L, cost_e, cost_c, k, B):\\n\",\n    \"    best_S = []\\n\",\n    \"    best_cost = np.infty\\n\",\n    \"    ptr_done = [0] * (k-1)\\n\",\n    \"    possible = list(itertools.combinations(range(L-1), k-1))\\n\",\n    \"    for p in possible:\\n\",\n    \"        p = list(p)\\n\",\n    \"        p.append(L-1)\\n\",\n    \"        lens = [sum(cost_e[:p[0]+1])]\\n\",\n    \"        s = [p[0]+1]\\n\",\n    \"        for i in range(len(p)-1):\\n\",\n    \"            lens.append(sum(cost_e[p[i]+1:p[i+1]+1]))\\n\",\n    \"            s.append(p[i+1]-p[i])     \\n\",\n    \"        max_l = max(lens)\\n\",\n    \"        cost = (B-1) * max_l\\n\",\n    \"        for i in range(k-1):\\n\",\n    \"            cost += cost_c[p[i]][i]\\n\",\n    \"        if cost < best_cost:\\n\",\n    \"            best_cost = cost\\n\",\n    \"            best_S = s\\n\",\n    \"    return best_S, best_cost\\n\",\n    \"\\n\",\n    \"def uniform_split(L, cost_e, cost_c, k, B):\\n\",\n    \"    per_stage = L // k\\n\",\n    \"    \\n\",\n    \"    s = [per_stage] * (k-1)\\n\",\n    \"    s.append(L-sum(s))\\n\",\n    \"    p = [s[0]-1]\\n\",\n    \"    for i in range(1, k):\\n\",\n    \"        p.append(p[i-1] + s[i])\\n\",\n    \"    lens = [sum(cost_e[:p[0]+1])]\\n\",\n    \"    for i in range(len(s)-1):\\n\",\n    \"        lens.append(sum(cost_e[p[i]+1:p[i+1]+1]))\\n\",\n    \"    
max_l = max(lens)\\n\",\n    \"    cost = (B-1) * max_l\\n\",\n    \"    for i in range(k-1):\\n\",\n    \"        cost += cost_c[p[i]][i]\\n\",\n    \"    return s, cost\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"0 0 [([1], 2), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0), ([1], 0)]\\n\",\n      \"0 1 [([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf), ([], inf)]\\n\",\n      \"1 0 [([2], 8), ([2], 6), ([2], 4), ([2], 2), ([2], 0), ([2], 0), ([2], 0), ([2], 0), ([2], 0), ([2], 0)]\\n\",\n      \"1 1 [([1, 1], 8.0), ([1, 1], 6.0), ([1, 1], 4.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0), ([1, 1], 2.0)]\\n\",\n      \"2 0 [([3], 12), ([3], 10), ([3], 8), ([3], 6), ([3], 4), ([3], 2), ([3], 0), ([3], 0), ([3], 0), ([3], 0)]\\n\",\n      \"2 1 [([2, 1], 10.0), ([2, 1], 8.0), ([2, 1], 6.0), ([2, 1], 4.0), ([2, 1], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0), ([1, 2], 2.0)]\\n\",\n      \"3 0 [([4], 22), ([4], 20), ([4], 18), ([4], 16), ([4], 14), ([4], 12), ([4], 10), ([4], 8), ([4], 2), ([4], 0)]\\n\",\n      \"3 1 [([3, 1], 14.0), ([3, 1], 12.0), ([3, 1], 10.0), ([3, 1], 8.0), ([3, 1], 6.0), ([3, 1], 4.0), ([3, 1], 2.0), ([2, 2], 2.0), ([1, 3], 2.0), ([1, 3], 2.0)]\\n\"\n     ]\n    },\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"([3, 1], 14.0)\"\n      ]\n     },\n     \"execution_count\": 3,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"L = 4\\n\",\n    \"k = 2\\n\",\n    \"cost_e = [1,3,2,5]\\n\",\n    \"cost_c = np.ones((L-1, k-1)) * 2\\n\",\n    \"pipe_dp(L, cost_e, cost_c, k, 3)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": 
{},\n   \"outputs\": [],\n   \"source\": [\n    \"test_list = [(12, 4), (24, 4), (24,8), (24, 12), (36, 8)]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"homo dp L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.011948347091674805\\n\",\n      \"homo bf L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.0019943714141845703\\n\",\n      \"homo us L=12 k=4 is [3, 3, 3, 3], minimum cost 12.0. Took time 0.0\\n\",\n      \"homo dp L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.10673046112060547\\n\",\n      \"homo bf L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.01792764663696289\\n\",\n      \"homo us L=24 k=4 is [6, 6, 6, 6], minimum cost 18.0. Took time 0.0\\n\",\n      \"homo dp L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.21442461013793945\\n\",\n      \"homo bf L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 4.285534381866455\\n\",\n      \"homo us L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.0\\n\",\n      \"homo dp L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 0.27722954750061035\\n\",\n      \"homo bf L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 32.76035165786743\\n\",\n      \"homo us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 26.0. Took time 0.0\\n\",\n      \"homo dp L=36 k=8 is [1, 5, 5, 5, 5, 5, 5, 5], minimum cost 24.0. Took time 0.872692346572876\\n\",\n      \"homo bf L=36 k=8 is [1, 5, 5, 5, 5, 5, 5, 5], minimum cost 24.0. Took time 127.84894752502441\\n\",\n      \"homo us L=36 k=8 is [4, 4, 4, 4, 4, 4, 4, 8], minimum cost 30.0. 
Took time 0.0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# homogeneous test\\n\",\n    \"for L, k in test_list:\\n\",\n    \"    cost_e = np.ones(L)\\n\",\n    \"    cost_c = np.ones((L-1, k-1)) * 2\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = pipe_dp(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"homo dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = brute_force(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"homo bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = uniform_split(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"homo us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 10,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"hete dp L=12 k=4 is [3, 3, 2, 4], minimum cost 65. Took time 0.046866655349731445\\n\",\n      \"hete bf L=12 k=4 is [3, 3, 2, 4], minimum cost 65. Took time 0.001994609832763672\\n\",\n      \"hete us L=12 k=4 is [3, 3, 3, 3], minimum cost 65. Took time 0.0\\n\",\n      \"hete dp L=24 k=4 is [6, 7, 7, 4], minimum cost 109. Took time 0.6502325534820557\\n\",\n      \"hete bf L=24 k=4 is [6, 7, 7, 4], minimum cost 109. Took time 0.017981767654418945\\n\",\n      \"hete us L=24 k=4 is [6, 6, 6, 6], minimum cost 114. Took time 0.0\\n\",\n      \"hete dp L=24 k=8 is [3, 3, 2, 3, 3, 3, 4, 3], minimum cost 93. Took time 1.4241876602172852\\n\",\n      \"hete bf L=24 k=8 is [3, 3, 2, 3, 3, 3, 4, 3], minimum cost 93. Took time 4.182834148406982\\n\",\n      \"hete us L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 98. 
Took time 0.0\\n\",\n      \"hete dp L=24 k=12 is [2, 3, 1, 1, 2, 1, 2, 2, 3, 3, 1, 3], minimum cost 104. Took time 1.7802371978759766\\n\",\n      \"hete bf L=24 k=12 is [2, 3, 1, 1, 2, 1, 2, 2, 3, 3, 1, 3], minimum cost 104. Took time 31.874720811843872\\n\",\n      \"hete us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 114. Took time 0.0\\n\",\n      \"hete dp L=36 k=8 is [4, 4, 5, 5, 5, 4, 4, 5], minimum cost 114. Took time 6.4348156452178955\\n\",\n      \"hete bf L=36 k=8 is [4, 4, 5, 5, 5, 4, 4, 5], minimum cost 114. Took time 120.12648391723633\\n\",\n      \"hete us L=36 k=8 is [4, 4, 4, 4, 4, 4, 4, 8], minimum cost 165. Took time 0.0\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"# heterogeneous test\\n\",\n    \"for L, k in test_list:\\n\",\n    \"    cost_e = np.random.randint(low=5,high=10,size=L)\\n\",\n    \"    cost_c = np.random.randint(low=5,high=10,size=(L-1,k-1))\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = pipe_dp(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"hete dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = brute_force(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"hete bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = uniform_split(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"hete us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"hete dp L=12 k=4 is [2, 3, 3, 4], minimum cost 66. Took time 0.04785466194152832\\n\",\n      \"hete us L=12 k=4 is [3, 3, 3, 3], minimum cost 70. 
Took time 0.000997304916381836\\n\",\n      \"hete dp L=24 k=12 is [3, 3, 1, 3, 1, 2, 1, 3, 3, 1, 1, 2], minimum cost 102. Took time 1.8829903602600098\\n\",\n      \"hete us L=24 k=12 is [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], minimum cost 107. Took time 0.0\\n\"\n     ]\n    },\n    {\n     \"ename\": \"KeyboardInterrupt\",\n     \"evalue\": \"\",\n     \"output_type\": \"error\",\n     \"traceback\": [\n      \"\\u001b[1;31m---------------------------------------------------------------------------\\u001b[0m\",\n      \"\\u001b[1;31mKeyboardInterrupt\\u001b[0m                         Traceback (most recent call last)\",\n      \"\\u001b[1;32m<ipython-input-7-abad37587b85>\\u001b[0m in \\u001b[0;36m<module>\\u001b[1;34m\\u001b[0m\\n\\u001b[0;32m      4\\u001b[0m     \\u001b[0mcost_c\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[0mnp\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mrandom\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mrandint\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mlow\\u001b[0m\\u001b[1;33m=\\u001b[0m\\u001b[1;36m5\\u001b[0m\\u001b[1;33m,\\u001b[0m\\u001b[0mhigh\\u001b[0m\\u001b[1;33m=\\u001b[0m\\u001b[1;36m10\\u001b[0m\\u001b[1;33m,\\u001b[0m\\u001b[0msize\\u001b[0m\\u001b[1;33m=\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mL\\u001b[0m\\u001b[1;33m-\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m,\\u001b[0m\\u001b[0mk\\u001b[0m\\u001b[1;33m-\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m      5\\u001b[0m     \\u001b[0mtime_s\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[0mtime\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mtime\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[1;32m----> 6\\u001b[1;33m     \\u001b[0mres\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[0mpipe_dp\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mL\\u001b[0m\\u001b[1;33m,\\u001b[0m 
\\u001b[0mcost_e\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[0mcost_c\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[0mk\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[1;36m3\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0m\\u001b[0;32m      7\\u001b[0m     \\u001b[0mprint\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[1;34mf\\\"hete dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\"\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m      8\\u001b[0m     \\u001b[0mtime_s\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[0mtime\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mtime\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\",\n      \"\\u001b[1;32m<ipython-input-6-696464c8afad>\\u001b[0m in \\u001b[0;36mpipe_dp\\u001b[1;34m(L, cost_e, cost_c, k, B)\\u001b[0m\\n\\u001b[0;32m     38\\u001b[0m                         \\u001b[0mS_best\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[1;33m[\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m     39\\u001b[0m                         \\u001b[1;32mfor\\u001b[0m \\u001b[0mcut\\u001b[0m \\u001b[1;32min\\u001b[0m \\u001b[0mrange\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mj\\u001b[0m\\u001b[1;33m-\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[0mi\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m:\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[1;32m---> 40\\u001b[1;33m                             \\u001b[0mcur_sum\\u001b[0m \\u001b[1;33m=\\u001b[0m 
\\u001b[0msum\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mcost_e\\u001b[0m\\u001b[1;33m[\\u001b[0m\\u001b[0mcut\\u001b[0m\\u001b[1;33m+\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m:\\u001b[0m\\u001b[0mi\\u001b[0m\\u001b[1;33m+\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0m\\u001b[0;32m     41\\u001b[0m                             \\u001b[1;32massert\\u001b[0m \\u001b[0mcur_sum\\u001b[0m \\u001b[1;32min\\u001b[0m \\u001b[0mpossible\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m     42\\u001b[0m                             \\u001b[0mS\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[0mcost_\\u001b[0m \\u001b[1;33m=\\u001b[0m \\u001b[0mtrace\\u001b[0m\\u001b[1;33m[\\u001b[0m\\u001b[0mcut\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m[\\u001b[0m\\u001b[0mj\\u001b[0m\\u001b[1;33m-\\u001b[0m\\u001b[1;36m1\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m[\\u001b[0m\\u001b[0mpossible\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mindex\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mmax\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[0mcur_sum\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[0mpossible\\u001b[0m\\u001b[1;33m[\\u001b[0m\\u001b[0mm\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m]\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\",\n      \"\\u001b[1;31mKeyboardInterrupt\\u001b[0m: \"\n     ]\n    }\n   ],\n   \"source\": [\n    \"test_list_large = [(12, 4), (24, 12), (36, 8), (36, 12), (48,12), (48, 24), (64, 12), (64, 16), (128, 32), (128, 12), (128, 50)]\\n\",\n    \"for L, k in test_list_large:\\n\",\n    \"    cost_e = np.random.randint(low=5,high=10,size=L)\\n\",\n    \"    cost_c = np.random.randint(low=5,high=10,size=(L-1,k-1))\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = pipe_dp(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"hete dp L={L} k={k} is {res[0]}, 
minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = uniform_split(L, cost_e, cost_c, k, 3)\\n\",\n    \"    print(f\\\"hete us L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time.time() - time_s}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 19,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"homo dp L=16 k=8 is [2, 2, 2, 2, 2, 2, 2, 2], minimum cost 18.0. Took time 0.05189323425292969\\n\",\n      \"homo bf L=16 k=8 is [2, 2, 2, 2, 2, 2, 2, 2], minimum cost 18.0. Took time 0.1096792221069336\\n\",\n      \"homo dp L=17 k=8 is [1, 1, 1, 2, 3, 3, 3, 3], minimum cost 20.0. Took time 0.06781816482543945\\n\",\n      \"homo bf L=17 k=8 is [1, 1, 1, 2, 3, 3, 3, 3], minimum cost 20.0. Took time 0.20744705200195312\\n\",\n      \"homo dp L=18 k=8 is [1, 1, 1, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.08078145980834961\\n\",\n      \"homo bf L=18 k=8 is [1, 1, 1, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.34108781814575195\\n\",\n      \"homo dp L=19 k=8 is [1, 1, 2, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.08978819847106934\\n\",\n      \"homo bf L=19 k=8 is [1, 1, 2, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.5295546054840088\\n\",\n      \"homo dp L=20 k=8 is [1, 1, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.11272788047790527\\n\",\n      \"homo bf L=20 k=8 is [1, 1, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.8706696033477783\\n\",\n      \"homo dp L=21 k=8 is [1, 2, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.1266329288482666\\n\",\n      \"homo bf L=21 k=8 is [1, 2, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 1.3324649333953857\\n\",\n      \"homo dp L=22 k=8 is [1, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.14860153198242188\\n\",\n      \"homo bf L=22 k=8 is [1, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. 
Took time 1.997645616531372\\n\",\n      \"homo dp L=23 k=8 is [2, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.17852044105529785\\n\",\n      \"homo bf L=23 k=8 is [2, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 3.0099191665649414\\n\",\n      \"homo dp L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 0.20644736289978027\\n\",\n      \"homo bf L=24 k=8 is [3, 3, 3, 3, 3, 3, 3, 3], minimum cost 20.0. Took time 4.319443702697754\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from matplotlib import pyplot as plt\\n\",\n    \"\\n\",\n    \"test_list = [(16,8), (17, 8), (18,8), (19,8), (20, 8), (21,8), (22,8), (23, 8),(24,8)]\\n\",\n    \"dp_time = []\\n\",\n    \"bf_time = []\\n\",\n    \"\\n\",\n    \"# homogeneous test\\n\",\n    \"for L, k in test_list:\\n\",\n    \"    cost_e = np.ones(L)\\n\",\n    \"    cost_c = np.ones((L-1, k-1)) * 2\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = pipe_dp(L, cost_e, cost_c, k, 3)\\n\",\n    \"    time_elapsed = time.time() - time_s\\n\",\n    \"    dp_time.append(time_elapsed)\\n\",\n    \"    print(f\\\"homo dp L={L} k={k} is {res[0]}, minimum cost {res[1]}. Took time {time_elapsed}\\\")\\n\",\n    \"    time_s = time.time()\\n\",\n    \"    res = brute_force(L, cost_e, cost_c, k, 3)\\n\",\n    \"    time_elapsed = time.time() - time_s\\n\",\n    \"    bf_time.append(time_elapsed)\\n\",\n    \"    print(f\\\"homo bf L={L} k={k} is {res[0]}, minimum cost {res[1]}. 
Took time {time_elapsed}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 26,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"<matplotlib.legend.Legend at 0x2489528b8e0>\"\n      ]\n     },\n     \"execution_count\": 26,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    },\n    {\n     \"data\": {\n      \"image/png\": \"iVBORw0KGgoAAAANSUhEUgAAAXgAAAEGCAYAAABvtY4XAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deXhU1f3H8fd3srIEkEVFAYNaEcUQKIuKCgrVulGxWnfhR62t3bRWWndbW1vbqlVrXSvuC+JChWpFWxesCoKiqGwurCqr7CQkM+f3x7lJJjEhE8jMnZl8Xs8zz9y5987cT0b85uTcc8815xwiIpJ9ImEHEBGR5FCBFxHJUirwIiJZSgVeRCRLqcCLiGSp3LADxOvcubMrLi4OO4aISMaYNWvWaudcl/q2pVWBLy4uZubMmWHHEBHJGGa2uKFt6qIREclSKvAiIllKBV5EJEulVR98fSoqKli2bBllZWVhR8lqhYWFdOvWjby8vLCjiEgzSfsCv2zZMoqKiiguLsbMwo6TlZxzrFmzhmXLltGzZ8+w44hIM0n7LpqysjI6deqk4p5EZkanTp30V5JIlkn7Ag+ouKeAvmOR7JMRBV5EJGt9+iq8dSfEos3+0WnfBy8ikrUqtsLkC8Ei8M0xEMlp1o9XgU+iyspKcnP1FYtIA167Ab76DM59FvIKm/3j1UWTgJtuuok+ffrQp08fbr75ZhYtWkSfPn2qt99www385je/AWDYsGFcfvnlDB06lFtuuYWJEyfSp08f+vbtyxFHHBHSTyAiaWflXPjfzdD3DNh7aFIOkVHNy99O/pCPPt/QrJ95wB7tuObEAxvcPmvWLO677z6mT5+Oc47BgwczdOj2/2OsW7eOV199FYCDDjqIF154gT333JN169Y1a3YRyVCxmO+aKWgHR1+XtMOoBd+I119/nVGjRtGmTRvatm3LySefzLRp07b7ntNOO616eciQIYwZM4Z77rmHaLT5T6KISAZ65wFYOh2OuQ7adEraYTKqBb+9lnay1HdT8nXr1hGLxapf1x0/3qZNm+rlO++8k+nTp/Ovf/2L0tJSZs+eTadOyfsPKiJpbuMKePEaKD7cd88kkVrwjTjiiCOYNGkSW7ZsYfPmzTzzzDMce+yxrFy5kjVr1lBeXs6UKVMafP8nn3zC4MGDufbaa+ncuTNLly5NYXoRSTsvXAaVW+GEmyHJ159kVAs+DP3792fMmDEMGjQIgPPOO4+BAwdy9dVXM3jwYHr27Mn+++/f4PvHjRvHwoULcc4xfPhw+vbtm6roIpJuFr4EHzwFwy6Hzvsm/XBWXxdEWAYMGODq3vBj7ty59O7dO6RELYu+a5Ek2rYFbh8MuYXwo9cht6BZPtbMZjnnBtS3TS14EZFUePV6WLcExjzXbMW9MeqDFxFJti/nwBu3Qb9zoHhIyg6rAi8ikkyxKEy+CFrtAt+6NqWHTnqBN7McM3vXzBoeaiIikq1mjoflM+Hbf4TWHVN66FS04C8E5qbgOCIi6WXDF/DSb2HvI+GgU1N++KQWeDPrBhwP/COZxx
ERSUvP/wpiFXDCTUkf816fZLfgbwZ+BcQa2sHMzjezmWY2c9WqVUmO03R1JxbbEbNnz+a5555r8vvGjRvHgQceyLhx43bq+CISgvnPw9xnYeivoOPeoURI2jBJMzsBWOmcm2Vmwxrazzl3N3A3+HHwycqTbNFolJyc+udynj17NjNnzuS4445r0mfeddddrFq1ioKCxIZUaXpikTRRvgmeGwddesMhPwstRjJb8EOAkWa2CHgcOMrMHk7i8ZKmsrKS0aNHU1JSwimnnMKWLVsAKC4u5tprr+Wwww5j4sSJDBs2jKoLtVavXk1xcTHbtm3j6quvZsKECZSWljJhwgQ2b97M2LFjGThwIP369eOf//zn1445cuRINm/ezODBg5kwYQKLFy9m+PDhlJSUMHz4cJYsWQLAmDFjuPjiiznyyCP59a9/zccff8yIESPo27cv/fv355NPPgHgL3/5CwMHDqSkpIRrrrkmRd+cSAv18h9g/VI48RbIzQ8tRtKae865y4DLAIIW/CXOubN36kOfv9SPJ21Oux8Ex16/3V3mz5/Pvffey5AhQxg7diy33347l1xyCQCFhYW8/vrrgJ9YrK78/HyuvfZaZs6cyW233QbA5ZdfzlFHHcX48eNZt24dgwYNYsSIEbUmKXv22Wdp27Yts2fPBuDEE0/k3HPPZfTo0YwfP56f//znTJo0CYAFCxbw0ksvkZOTw+DBg7n00ksZNWoUZWVlxGIxpk6dysKFC5kxYwbOOUaOHMlrr72m+elFkuHz2TD9DhgwFnoMDjWKxsEnoHv37gwZ4i9OOPvss6sLOtSeGjhRU6dO5frrr6e0tJRhw4ZRVlZW3SJvyJtvvsmZZ54JwDnnnFMrw6mnnkpOTg4bN25k+fLljBo1CvC/fFq3bs3UqVOZOnUq/fr1o3///sybN4+FCxc2ObeINCJa6ed5b9MFhof/l3JKOmydc68Ar+z0BzXS0k4Wq3P2O/51fKs7Nze3ehrhulMIx3PO8dRTT9GrV69myVSVoaF5hZxzXHbZZfzwhz/c4eOJSALevge+mA2n3AetOoSdRi34RCxZsoQ333wTgMcee4zDDjus3v2Ki4uZNWsWAE8++WT1+qKiIjZu3Fj9+phjjuFvf/tbdUF+9913G81w6KGH8vjjjwPwyCOP1JuhXbt2dOvWrbrrpry8nC1btnDMMccwfvx4Nm3aBMDy5ctZuXJlo8cUkSZYvwz++3vY91tw4Kiw0wAq8Anp3bs3DzzwACUlJaxdu5YLLrig3v0uueQS7rjjDg499FBWr15dvf7II4/ko48+qj7JetVVV1FRUUFJSQl9+vThqquuajTDrbfeyn333UdJSQkPPfQQt9xyS737PfTQQ9x6662UlJRw6KGH8uWXX3L00Udz5plncsghh3DQQQdxyimn1PqFIyLN4Llf+WkJjr8xlDHv9dF0wVJN37XIDpo7GSac7eeaGXJhSg+9vemC1YIXEdkZZRt86323g+DgH4edphZdFSMisjP++3vY+AWc9jDk5IWdppaMaMGnUzdSttJ3LLIDls2CGXfDoB9At2+GneZr0r7AFxYWsmbNGhWgJHLOsWbNGgoLC8OOIpI5qsa8F+0ORzU+UCIMad9F061bN5YtW0Y6TkSWTQoLC+nWrVvYMUQyx1u3w4o58L2HoLBd2GnqlfYFPi8vj549e4YdQ0SkxleL4ZU/Qq/joPeJYadpUNp30YiIpBXn4LlLAINj/5w2Y97rowIvItIUHz4DC6fCUVdCh+5hp9kuFXgRkURtXQf/vhS6lsLg9J/bKe374EVE0sZ/fgubV8GZT0Ck/hv8pBO14EVEErF0BswcD4MvgD1Kw06TEBV4EZHGRCv8mPd23eDIy8NOkzB10YiINOaNW2HlR3DG41DQNuw0CVMLXkRke9Z+Cq/+2Y9373Vs2GmaRAVeRKQhzsGUiyGS58e8Zxh10YiINGTOk/Dpy3DcDdBuj7DTNJla8CIi9d
my1o9533MADBgbdpodoha8iEh9XroGtn4F507KiDHv9VELXkSkrsVvwDsPwiE/gd0PCjvNDlOBFxGJV1nux7x36AHDLg07zU5RF42ISLz/3QKrF8BZT0J+m7DT7BS14EVEqqz+GF67AQ48Gb7xrbDT7DQVeBERCMa8XwS5hfDt68NO0yzURSMiAvDeY7BoGpzwVyjaLew0zUIteBGRzWvghSug+2DoPybsNM1GBV5EZOqVUL4BTrgZItlTFrPnJxER2RGfvgrvPQpDLoTdDgg7TbNSgReRlquiDKb8AnbpCUeMCztNs9NJVhFpuabdCGs/gXMmQV6rsNM0O7XgRaRlWjUfXv8rlJwG+xwZdpqkUIEXkZYnFoPJF/m7Mx19XdhpkkZdNCLS8rz7ECx5A0beBm27hJ0madSCF5GWZdNKePEq2GsI9Ds77DRJpQIvIi3LC5fDti1+zLtZ2GmSKmkF3swKzWyGmb1nZh+a2W+TdSwRkYR8/B+YMxEOvxi67Bd2mqRLZh98OXCUc26TmeUBr5vZ8865t5J4TBGR+m3bAv+6GDrtC4ddHHaalEhagXfOOWBT8DIveLhkHU9EZLtevg6+WgSjp0BeYdhpUiKpffBmlmNms4GVwIvOuen17HO+mc00s5mrVq1KZhwRaanmPAlv3gYDz4Oeh4edJmWSWuCdc1HnXCnQDRhkZn3q2edu59wA59yALl2yd7iSiITk89nwz59Cj0PgmD+GnSalUjKKxjm3DngF+HYqjiciAsCmVfD4WdC6E3zvQcjNDztRSiVzFE0XM+sQLLcCRgDzknU8EZFaKrfBE+fCltVw+iPQdtewE6VcMkfRdAUeMLMc/C+SJ5xzU5J4PBGRGv/+tb9a9bv3wh6lYacJRTJH0bwP9EvW54uINGjmeP8YciEcdErYaUKjK1lFJLssfgOeGwf7joDh14SdJlQq8CKSPdYthQnnQIe9fNdMJCfsRKFSgReR7LBtC0w4CyrL4YzHoFWHsBOFTtMFi0jmcw4m/xy+eN8X9y69wk6UFtSCF5HM98atfhKxo66EXseGnSZtqMCLSGZb+BK8eA0ccBIc/suw06QVFXgRyVyrP4Ynx8JuB8JJt2f9/O5NpQIvIpmpbAM8foYfKXP6o5DfJuxEaUcnWUUk88Ri8PQPYM0ncO4/YZe9wk6UllTgRSTzvHwdLPg3HPuXFjX9b1Opi0ZEMsuHz8C0G6DfOTDoB2GnSWsq8CKSOb6cA5N+DN0GwfE36qRqI1TgRSQzbF4Dj50JhR3gtIcgtyDsRGlPffAikv6iFTBxNGxaAWOfh6Ldw06UEVTgRST9vXAFLJoGJ90Je34z7DQZQ100IpLe3nkQZtwFh/wUSs8IO01GSajAm9luZnavmT0fvD7AzL6f3Ggi0uItmQ5TLoa9j4QRvw07TcZJtAV/P/ACsEfwegFwUTICiYgAsOFzeOIcaL8nnDIectSj3FSJFvjOzrkngBiAc64SiCYtlYi0bBVl8PhZsG0znPE4tO4YdqKMlOivxM1m1glwAGZ2MLA+aalEpOVyDiZfCJ+/A6c9Arv2DjtRxkq0wF8MPAvsY2b/A7oALfdOtiKSPG/dDu8/DsMug94nhJ0moyVU4J1z75jZUKAXYMB851xFUpOJSMvzyX9h6pWw/wlwxK/CTpPxEirwZpYDHAcUB+852sxwzt2UxGwi0pKs/RQm/h902R9G3QkRjeLeWYl20UwGyoA5BCdaRUSaTflGPw2BmZ/bvaAo7ERZIdEC3805V5LUJCLSMsVi8MyPYPV8OPtp6Ngz7ERZI9G/gZ43s6OTmkREWqZX/wTzpsDR18E+R4adJqsk2oJ/C3jGzCJABf5Eq3POtUtaMhHJfnMnw6vXQ98z4eALwk6TdRIt8DcChwBznHMuiXlEpKVY8SE8/UM/edgJf9Xc7kmQaBfNQuADFXcRaRZb1sJjZ0BBW38xU15h2ImyUqIt+C+AV4LJxsqrVmqYpIg0Wb
QSJo6BjV/AmH9Bu65hJ8paiRb4z4JHfvAQEdkxL14Nn70K3/k7dB8UdpqsluiVrJqnU0R23uxH4a2/w+AfQb+zw06T9bZb4M3sZufcRWY2mWCisXjOuZFJSyYi2WXZTJh8ERQfDkf/Puw0LUJjLfiHgucbkh1ERLLYxi9hwtlQtBuc+gDk5IWdqEXYboF3zs0KFkudc7fEbzOzC4FXkxVMRLJEZbkv7mXr4fsvQptOYSdqMRIdJjm6nnVjmjGHiGQj5/wt95a9DSfdAbv3CTtRi9JYH/wZwJlATzN7Nm5TEbAmmcFEJAvMuBtmPwxHjIMDTwo7TYvTWB/8G/gx8J3xV7NW2Qi8n6xQIpIFPn0V/n0Z7HcsDLs87DQtUmN98IuBxfhpCprEzLoDDwK746cYvrtuP76IZKmvFvmLmTrtCyffrbndQ5LQt25mJ5vZQjNbb2YbzGyjmW1o5G2VwC+dc72Bg4GfmNkBOxtYRNJc+SZ/w2wXhTMeg0LNSRiWRK9k/TNwonNubqIf7Jz7At+9g3Nuo5nNBfYEPmpyShHJDFvX+Zb7yo/grInQaZ+wE7VoiRb4FU0p7nWZWTHQD5hez7bzgfMBevTosaOHEJGwrZrvJxBbtxhOvBX2HRF2ohYv0QI/08wmAJOoPdnY04290czaAk8BFznnvtat45y7G7gbYMCAAZqtUiQTzXsOnj7fzwo5egrs1eTTdpIEiRb4dsAWIP6uTg7YboE3szx8cX8kkV8GIpJhYjGYdgO8fB10LYXTH4H23cJOJYFEJxv7v6Z+sJkZcC8wV9MKi2Sh8k0w6QKY+yyUnAYn3gJ5rcJOJXESKvBmdh/1TzY2djtvGwKcA8wxs9nBusudc881OaWIpJe1n/mRMqvm+onDDvmp7siUhhLtopkSt1wIjAI+394bnHOv4+/dKiLZ5NNX/EgZF4OznoR9h4edSBqQaBfNU/Gvzewx4KWkJBKR9OQcvHUHTL0SOn8DTn9UwyDTXKIt+Lq+AWhMo0hLUVEGU34B7z0K+58Ao+6EgqKwU0kjGi3wwcnSKLApbvWXwK+TFUpE0siGz/10v8tnwdBLYeivNfVAhmi0wDvnnJnNds71T0UgEUkjS2f44l6+CU57GHqfGHYiaYJEfw2/YWYDk5pERNLLOw/C/cf7oY/nvaTinoES7YM/CrjAzBYBm/GjY5xzriRZwUQkJNEKeOFyP5f73kfCKeOhdcewU8kOSLTAH5vUFCKSHjav9kMgF03zY9tH/BZydnQshoQt0WGSi5MdRERC9sX7/uKlTStg1N3Q97SwE8lO0q9mEYEPnoJJP/FdMWP/DXtqTEU2UIEXacliUfjv7+D1v0L3g+F7D0LRbmGnkmaiAi/SUm1dB0+dBx+/CN8cA8f+BXLzw04lzUgFXqQlWrUAHj/D3zv1+Jtg4PfDTiRJoAIv0tLM/zc8/QPIyYdzn4XiIWEnkiRRgRdpKZyDaTfCf38PXUvgtEegQ/ewU0kSqcCLtATbNsOkH8NHk6DPKTDyb5DfOuxUkmQq8CLZ7qtFfnz7yo/gW9fCoT/XzTlaCBV4kWz22WvwxGhwUThrIuw7IuxEkkKa81MkGzkH0++CB0+CNl3gBy+ruLdAasGLZJvKcphyMcx+GHodB6PugsJ2YaeSEKjAi2STDV8EN+eY6W/MMfRS3ZyjBVOBF8kWS98Obs6xEb73EBwwMuxEEjIVeJFs8O7D/p6pRV3hnKdhtwPDTiRpQAVeJJNFK+CFK2DGXdBzKJx6v27OIdVU4EUy1eY1MHG0vznHwT/xY9x1cw6Jo38NIpkmWgnv3A8v/8HfDPukO6H0jLBTSRpSgRfJFM7Bxy/B1Cth1TzY6zD49h/9vDIi9VCBF8kEKz70hf2T/0LHvf1EYfsfrykHZLtU4EXS2aaV8PJ18M6DUNAOjvkjDDxPN+aQhKjAi6Sjiq3w1u0w7SaoLINBP4Shv9IIGWkSFXiRdOKcvwH2S7+B9Uuh1/
F+dEznfcNOJhlIBV4kXSyZDi9c7qcZ2L0ETroDeh4edirJYCrwImFb+5lvsX80yV+JetIdUHK65pCRnaYCLxKWsvXw2g0w/U6I5MKwy+DQn0F+m7CTSZZQgRdJtWglzLoPXvkjbFkLpWfCUVdCuz3CTiZZRgVeJFWcg4Uv+vHsq+dD8eFw9O9hj9Kwk0mWUoEXSYUVH/pJwT59GTruA6c/6m/GoQuVJIlU4EWSaeMKf6HSuw/5C5W+fT0M+L4uVJKUSFqBN7PxwAnASudcn2QdRyQtVWyFN/8Or//VX6g0+EdwxDhdqCQplcwW/P3AbcCDSTyGSHqJxWouVNqwDPY/wV+o1GmfsJNJC5S0Au+ce83MipP1+SJpZ8lbwYVKs/yFSqPu1IVKEqrQ++DN7HzgfIAePXqEnEZkB+hCJUlToRd459zdwN0AAwYMcCHHEUnc1nUw7QaYfldwodLlcOhPdaGSpI3QC7xIxolWwKz7/R2Vtn4FpWcFFyp1DTuZSC0q8CKJcg4WTg0uVFrgL1Q65jro2jfsZCL1SuYwyceAYUBnM1sGXOOcuzdZxxNJqi8/gKlXwKevBBcqPQa9jtWFSpLWkjmKRncBlszmHCx/B2bcBXMmBhcq/QkGjNWFSpIR1EUjUte2LX4s+9v/gC9mQ35bOPjHcPgvdaGSZBQVeJEqaz6BmePh3YehbB106Q3H3QB9T4eCorDTiTSZCry0bLEoLHjBt9Y/+Y8f7tj7RBj4A9jrUPWxS0ZTgZeWadMqePdBmHmfv/dp0R5w5BXQ/1wo2j3sdCLNQgVeWg7nYOkMePse+HASxCqg51A45g9+6t4c/e8g2UX/oiX7lW/yo2DevhdWzPGjYQZ+30/b22W/sNOJJI0KvGSvVQtg5r0w+1Eo3wC79YETboaDToWCtmGnE0k6FXjJLtFKmP+c74b57DWI5MGBJ/mTpt0H6aSptCgq8JIdNn4Jsx7wc8Rs/Bzad4fhV0O/c6Ftl7DTiYRCBV4yl3Ow+H9+iOPcyRCrhH2Gw/E3wn7HQCQn7IQioVKBl8xTtgHen+BPmq6aC4Xt/S3xBozVnZNE4qjAS+ZY8ZE/afre47Btk5/FceRt0Oe7kN867HQiaUcFXtJb5TaYN8V3wyz+H+QUQJ+T/UnTPfvrpKnIdqjAS3pav9yfMH3nAdi0Ajrs5W9eXXo2tOkUdjqRjKACL+nDOfjsVd9an/ccuBh842gYeB7sO1wnTUWaSAVewhOthBUfwNLp/rHkLdiwHFp19Pc2HTAWdikOO6VIxlKBl9QpWw/L3oYl02HpW7BsFlRs9tuK9oAeg+Ebx8CBoyCvMNysIllABV6SwzlYt7immC+ZDis/AhxYBHY7EErPhB4HQ/fB0L6bTpiKNDMVeGkeldvgyzlBMX/Ld7lsWuG35RdB94FwwEhfzLsN0A00RFJABV52zJa1QXdLUMyXvwOVW/22Dj38NLzdB/kW+q4H6ASpSAhU4KVxzvnb2S2N625ZPd9vi+TC7iUw4P98Qe9+MLTrGm5eEQFU4KU+FWX+ZtNLpwd96NNhy2q/rbC972Yp+Z5vne/RX1eRiqQpFXjxt6+rGqq4dDp8/i5Et/ltHff2Y9F7DPaFvXMviETCzSsiCVGBb2kqymDNQlg+q6Z1vvYTvy0nH7qWwuAf+q6W7oM11a5IBlOBz1blm2D1Alg13/eXr5oPq+bBV4v8FaIArTv5Qt7/XN/d0rVU489FsogKfKbb+pW/Nd2qeUFBn+eL+fqlNftE8qDTvrD7Qf52dZ3388W80z4aey6SxVTgM4FzsHl1ULzrFPKqseYAuYW+ePc4GLqMhi77+z7zjj0hJy+8/CISChX4dOIcbPi8pnjHd61s/apmv/wi6NIL9h3hn7vs7wt7hx4aby4i1VTgwxCL+sv4Vy2oU8wXwLaNNfu16uiL9wEnBYU8KOZFXdW1IiKNUoFPpmgFrP00aIUHLfHV82H1Qqgsq9mvqKtvgZeeCV
3280W8y/7QpnN42UUkZWIxRyTS/I02FfidUbYB1i/zJzTXL4V1S+NeL4ONX9SMWAHfhdK5l7+Mv8v+vkXeeT9o1SG8n0FEqjnnqIg6yiqjlFfEKKuIUl4ZpawiVv3s1/nnuuur3lffe77+vhjlwXt2aZ3PjCtGNPvPowLfkFjMn8BsqHivX+qnv40XyYP2e0L77r6It+/mR6902c8X8vw24fwsIhmqquBurYj6YlgRY2tQLMsqosGyL6RbtwXrK2N+OSi2VctVhbVqn/JaBbdmOeZ2PG9+ToSCvAiFeTkU5kUoyPXPhbk5tMrPYZfW+RTm5VAQvy0vh/atkjMIouUW+IqtNYW6unjHtcbXL4dYRe33FLaH9j18Ae9xCHTo7ot4+x7+ue1uuspTslIs5tgW9a3ObZUxKqL+eVvwXHd9VbHcGlc8awpwlK3bYkEBrinSVfuXV9R+744W3PzcCIW5EVrl5/iCGxTUgqCgFhQVBOt9kS3I/XphLsjLqb0tt2pdJPi8mm0FuZGkdLPsjOws8M752Q7XL/FFu7qAx72umlulikV8X3j77rDnAH/Tifji3b4bFLYL5+eRrOOcozLmiMYcFdEY0Zh/XRl1VMZiwXoXrI8F64PX0Vj1clWBjS+2FXGFuNb2qm111m2LxhXsqmIdrV2wK3emWRuIGEEBzaFV0IqtatkW5kXo0CqPwvyaQtwq2LewukVc897a6yJx62sKdE6aFdswZH6Bj8Vg2o1fL+ZVU9dWyW0VtLi7+9kPq5bbB63wdntorHgGc652C7Nu0Suv9H/G111ftVxeEfv6+6PR6vVVhTe+0NYqzMG2+NfRqKOiTlGuem+0GQpmIvJyjPycCPm5EfKC5/zciO9KCJYLciMUFeZW71e1Li8nUmtd1ftqPTewriCuSFe1ePNzIphGf6VU5hf4SATe/JufR6V9N9i1t58cq7r7JCjirTtm/dDCqtZeVZGJxrf64tcHLcWY849orOrZf4Zzjmid9THniMVqr4/FIBq3Phar2kb151atjzniPs+vd8G6+PWVURdXlKP1t0TjCnL8uuZSb/HKjZAbMXIiRm5OzXJBXoTWEf86N2Lk5hg5kQh51fsauZGIX67z3rxg39zq/YLX1ctGXk7j7y2oU4DzqnLnpF+XgaRW5hd4gEsWQm7BdneJxRwV0ahvgUUdFUGLq6L6z90YFXW2VUZjVAStr4q4FlzVe+LXV723MlZ7W01RjcUVW7+9bsGttxDHF+hoA+uD1y41jcKdYgY5ZkQi5p8Nvxy8zs2xuMKaU1282hbkkt+6dgvT75dTe11wkquhIl0Q95l1W6xV69TKlGyR1AJvZt8GbgFygH84565PxnGOv30GW7ZFfeGtVXBrCnSK/iIG/J/FuUFLLL4FVvs5EtfCq3ldkJdb/35xLcJ618e1/qpbfBEjJ67VV/VctRwx/8iJxBXcCDXrrKbwNrTejOrPrNlO9efVXa/iKZI6SSvwZpYD/B34FrAMeNvMnnXOfdTcx9pvtyIqY468qoKZEwmWgyIbV/A/J2cAAAhJSURBVGyr/tStLsKNvqd2sc4N/myu+tM7L6fmT+e8nIiKmIikjWS24AcBHzvnPgUws8eB7wDNXuD/elppc3+kiEjGS+ag7T2BuDlrWRasq8XMzjezmWY2c9WqVUmMIyLSsiSzwNfXT/G1nnDn3N3OuQHOuQFduujuQSIizSWZBX4Z0D3udTfg8yQeT0RE4iSzwL8NfMPMeppZPnA68GwSjyciInGSdpLVOVdpZj8FXsAPkxzvnPswWccTEZHakjoO3jn3HPBcMo8hIiL109SHIiJZSgVeRCRLmUujCUzMbBWweAff3hlY3eheqadcTaNcTaNcTZONufZyztU7xjytCvzOMLOZzrkBYeeoS7maRrmaRrmapqXlUheNiEiWUoEXEclS2VTg7w47QAOUq2mUq2mUq2laVK6s6YMXEZHasqkFLyIicVTgRUSyVEYWeD
Mbb2YrzeyDOut/ZmbzzexDM/tzOuQyswlmNjt4LDKz2WmSq9TM3gpyzTSzQWmSq6+ZvWlmc8xsspm1S3Gm7mb2spnNDf4dXRis72hmL5rZwuB5lzTJdWrwOmZmKR/+t51cfzGzeWb2vpk9Y2Yd0iTX74JMs81sqpntkcpc28sWt/0SM3Nm1nmnD+acy7gHcATQH/ggbt2RwEtAQfB613TIVWf7jcDV6ZALmAocGywfB7ySJrneBoYGy2OB36U4U1egf7BcBCwADgD+DFwarL8U+FOa5OoN9AJeAQaE8N+woVxHA7nB+j+l0ffVLm6fnwN3pst3Frzujp+gcTHQeWePlZEteOfca8DaOqsvAK53zpUH+6xMk1wAmL9R6/eAx1IaigZzOaCqddyeEObqbyBXL+C1YPlF4LspzvSFc+6dYHkjMBd/J7LvAA8Euz0AnJQOuZxzc51z81OZJcFcU51zlcFub+HvB5EOuTbE7daGem5CFFa2YPNfgV81V66MLPAN2A843Mymm9mrZjYw7EB1HA6scM4tDDtI4CLgL2a2FLgBuCzkPFU+AEYGy6dS+6YxKWVmxUA/YDqwm3PuC/D/gwK7pkmutLGdXGOB51Odp0rdXGZ2XfDv/izg6rByBVmKCbKZ2UhguXPuveb6/Gwq8LnALsDBwDjgiaDVnC7OIITW+3ZcAPzCOdcd+AVwb8h5qowFfmJms/B/vm4LI4SZtQWeAi6q0+oLVablMrMrgErgkXTJ5Zy7Ivh3/wjw0zBy1c2G/46uoJl/4WRTgV8GPO28GUAMP4FP6MwsFzgZmBB2ljijgaeD5YlAyk+y1sc5N885d7Rz7pv4X4ifpDqDmeXh/8d7xDlX9R2tMLOuwfauQMq7ABvIFbqGcpnZaOAE4CwXdDCnQ644j5LiLsAq9WTbB+gJvGdmi/BdWu+Y2e47c5xsKvCTgKMAzGw/IJ/0mTVuBDDPObcs7CBxPgeGBstHAWnRdWRmuwbPEeBK4M4UH9/wf83Mdc7dFLfpWfwvRYLnf6ZJrlA1lMvMvg38GhjpnNuSRrm+EbfbSGBeOmRzzs1xzu3qnCt2zhXjG6z9nXNf7tTBUn0GuZnOQj8GfAFUBF/E9/EF/WF8H+47wFHpkCtYfz/wozT7vg4DZgHv4fsmv5kmuS7EjypYAFxPcLV1CjMdhj/B9T4wO3gcB3QC/oP/RfgfoGOa5BoVfHflwArghTTJ9TGwNG5dSkerbCfXU0GNeB+YjD/xmrJc28tWZ59FNMMoGk1VICKSpbKpi0ZEROKowIuIZCkVeBGRLKUCLyKSpVTgRUSylAq8ZCUzeyUVsyua2c+DWQEfqbN+mJlNSfbxRbYnN+wAIunGzHJdzURZjfkxflbOz5KZqa4mZpQWSi14CY2ZFQet33uCebGnmlmrYFt1C9zMOgeXb2NmY8xsUjBX/Gdm9lMzu9jM3jU/v33HuEOcbWZvmNkHFsx3b2ZtzM9D/3bwnu/Efe5EM5uMn0q5btaLg8/5wMwuCtbdCewNPGtmv9jOzzkoyPFu8NwrWD/NzErj9vufmZUkmtHMuprZa8Hc5h+Y2eE7/l9DslKqr+LSQ4+qB1CMn2SpNHj9BHB2sPwKwfzm+DmFFgXLY/BXSRYBXYD1BFcJ46davSju/fcEy0cQzDkP/CHuGB3wV8y2CT53GfVcoQp8E5gT7NcW+BDoF2xbRD1XHALDgCnBcjtq5kYfATwVLI8Gbg6W9wNmNiUj8EvgimA5BygK+7+pHun1UBeNhO0z51zVXa5m4Yt+Y152fh7tjWa2Hn/JOfgiXBK332Pg5503s3bm7yp0NDDSzC4J9ikEegTLLzrn6pvP/zDgGefcZgAzexo//fO7ifyA+Pn2HwjmQXFAXrB+InCVmY3Dz6J5f7A+0YxvA+ODiasmxX2PIoC6aCR85XHLUWrOC1VS8++zcDvvicW9jlH7vFLdeTgcYMB3nXOlwaOHc25usH1zAxl3dtrp3+F/Kf
UBTiT4eZyfhOtF/M1Evoef3bDqeI1mdP6GKUcAy4GHzOzcncwpWUYFXtLVInzXCMApO/gZpwGY2WHAeufcevzt0H5Wda8AM+uXwOe8BpxkZq3NrA1+gq9pTcjRHl+EwXezxPsHcCvwdlzLPKGMZrYXsNI5dw9+dsL+TcgkLYAKvKSrG4ALzOwNdnxe/6+C99+Jn6kSfGs6D3jf/M2+f9fYhzh/e7X7gRn4mTf/4ZxLtHsG/P1c/2hm/8P3lcd/9ixgA3Bf3OpEMw4DZpvZu/h5zW9pQiZpATSbpEiIzGwP/Anh/Z1zsZDjSJZRC14kJEGf+XT8SBgVd2l2asGLiGQpteBFRLKUCryISJZSgRcRyVIq8CIiWUoFXkQkS/0/QP3a2iPYzbUAAAAASUVORK5CYII=\\n\",\n      \"text/plain\": [\n       \"<Figure size 432x288 with 1 Axes>\"\n      ]\n     },\n     \"metadata\": {\n      \"needs_background\": \"light\"\n     },\n     \"output_type\": \"display_data\"\n    }\n   ],\n   \"source\": [\n    \"plt.plot([16,17, 18, 19, 20, 21, 22, 23, 24], dp_time, label=\\\"ours\\\")\\n\",\n    \"plt.plot([16,17, 18, 19, 20, 21, 22, 23, 24], bf_time, label=\\\"brute force\\\")\\n\",\n    \"plt.xlabel(\\\"number of layers\\\")\\n\",\n    \"plt.ylabel(\\\"runtime\\\")\\n\",\n    \"plt.legend(loc=\\\"best\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "playground/pipeline/jax_array_slicing.py",
    "content": "import jax\nimport numpy\nfrom jax import core, xla\nfrom jax._src.util import (partial, unzip3)\nfrom jax.abstract_arrays import array_types\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import (ShardingSpec, Chunked, NoSharding, Replicated,\n                                   ShardedAxis, _as_slice_indices, _hashable_index, ShardedDeviceArray)\nimport numpy as np\nfrom jax.lib import xla_client, xla_bridge\nimport jax.numpy as jnp\nfrom alpa.util import jax_buffer_set, jax_buffer_set_v2\n\n\noffset = [0, 4]\nm = jnp.zeros([10, 10], dtype=np.float32)\nprint(m.__cuda_array_interface__)\nn = jnp.ones([2, 2], dtype=np.float32)\nprint(n.__cuda_array_interface__)\nk = jax_buffer_set_v2(m, n, tuple(offset))\nprint(k.__cuda_array_interface__)\nprint(k)\n"
  },
  {
    "path": "playground/pipeline/mesh_slicing.ipynb",
    "content": "{\r\n \"cells\": [\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"outputs\": [],\r\n   \"source\": [\r\n    \"import time\\n\",\r\n    \"\\n\",\r\n    \"import copy\\n\",\r\n    \"import numpy as np\\n\"\r\n   ],\r\n   \"metadata\": {\r\n    \"collapsed\": false,\r\n    \"pycharm\": {\r\n     \"name\": \"#%%\\n\"\r\n    }\r\n   }\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"outputs\": [],\r\n   \"source\": [\r\n    \"def draw_fill(puzzle, patternLength, patternWidth, start, count, solList):\\n\",\r\n    \"    count += 1\\n\",\r\n    \"    puzzleLength, puzzleWidth = puzzle.shape\\n\",\r\n    \"    patternNum = (puzzleWidth*puzzleLength)/(patternWidth*patternLength)\\n\",\r\n    \"    \\n\",\r\n    \"    horizonal = False\\n\",\r\n    \"    if start[0] + patternLength <= puzzleLength and start[1] + patternWidth <= puzzleWidth:\\n\",\r\n    \"        horizonal = True\\n\",\r\n    \"        #if (puzzle[start[0]:start[0]+patternLength, start[1]:start[1]+patternWidth] != 0).any():\\n\",\r\n    \"        for i in range(start[0], start[0]+patternLength):\\n\",\r\n    \"             for j in range(start[1], start[1]+patternWidth):\\n\",\r\n    \"                 if puzzle[i][j] != 0:\\n\",\r\n    \"                     horizonal = False\\n\",\r\n    \"    if horizonal:\\n\",\r\n    \"        newPuzzle = copy.deepcopy(puzzle)\\n\",\r\n    \"        for i in range(start[0], start[0]+patternLength):\\n\",\r\n    \"            for j in range(start[1], start[1]+patternWidth):\\n\",\r\n    \"                newPuzzle[i][j] = count\\n\",\r\n    \"        if count == patternNum:\\n\",\r\n    \"            solList.append(newPuzzle)\\n\",\r\n    \"            return\\n\",\r\n    \"        for i in range(start[0], puzzleLength):\\n\",\r\n    \"            for j in range(0, puzzleWidth):\\n\",\r\n    \"                if newPuzzle[i][j] == 0:\\n\",\r\n    \"                    newStart = (i, 
j)\\n\",\r\n    \"                    break\\n\",\r\n    \"            else:\\n\",\r\n    \"                continue\\n\",\r\n    \"            break\\n\",\r\n    \"        draw_fill(newPuzzle, patternLength, patternWidth, newStart, count, solList)\\n\",\r\n    \"\\n\",\r\n    \"    vertical = False\\n\",\r\n    \"    if patternLength != patternWidth and start[0]+patternWidth <= puzzleLength and start[1]+patternLength <= puzzleWidth:\\n\",\r\n    \"        vertical = True\\n\",\r\n    \"        for i in range(start[0], start[0]+patternWidth):\\n\",\r\n    \"            for j in range(start[1], start[1]+patternLength):\\n\",\r\n    \"                if puzzle[i][j] != 0:\\n\",\r\n    \"                    vertical = False\\n\",\r\n    \"    if vertical:\\n\",\r\n    \"        newPuzzle = copy.deepcopy(puzzle)\\n\",\r\n    \"        for i in range(start[0], start[0]+patternWidth):\\n\",\r\n    \"            for j in range(start[1], start[1]+patternLength):\\n\",\r\n    \"                newPuzzle[i][j] = count\\n\",\r\n    \"        if count == patternNum:\\n\",\r\n    \"            solList.append(newPuzzle)\\n\",\r\n    \"            return\\n\",\r\n    \"        for i in range(start[0], puzzleLength):\\n\",\r\n    \"            for j in range(0, puzzleWidth):\\n\",\r\n    \"                if newPuzzle[i][j] == 0:\\n\",\r\n    \"                    newStart = (i, j)\\n\",\r\n    \"                    break\\n\",\r\n    \"            else:\\n\",\r\n    \"                continue\\n\",\r\n    \"            break\\n\",\r\n    \"        draw_fill(newPuzzle, patternLength, patternWidth, newStart, count, solList)\\n\",\r\n    \"\\n\",\r\n    \"def backtrack(puzzleLength, puzzleWidth, patternLength, patternWidth):\\n\",\r\n    \"    patternNum = (puzzleWidth*puzzleLength)/(patternWidth*patternLength)\\n\",\r\n    \"    solList = []\\n\",\r\n    \"    if patternNum%1 == 0:\\n\",\r\n    \"        inputPuzzle = np.zeros((puzzleLength, puzzleWidth))\\n\",\r\n    \"        
draw_fill(inputPuzzle, patternLength, patternWidth, (0, 0), 0, solList)\\n\",\r\n    \"    #solList = np.asarray(solList).reshape((puzzleLength, puzzleWidth))\\n\",\r\n    \"    return solList\"\r\n   ],\r\n   \"metadata\": {\r\n    \"collapsed\": false,\r\n    \"pycharm\": {\r\n     \"name\": \"#%%\\n\"\r\n    }\r\n   }\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": 76,\r\n   \"metadata\": {\r\n    \"scrolled\": true\r\n   },\r\n   \"outputs\": [\r\n    {\r\n     \"data\": {\r\n      \"text/plain\": [\r\n       \"([array([[1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.],\\n\",\r\n       \"         [1., 1., 1., 1.]]),\\n\",\r\n       \"  array([[1., 1., 1., 1.],\\n\",\r\n       \"         [2., 2., 2., 2.],\\n\",\r\n       \"         [3., 3., 3., 3.],\\n\",\r\n       \"         [4., 4., 4., 4.],\\n\",\r\n       \"         [5., 5., 5., 5.],\\n\",\r\n       \"         [6., 6., 6., 6.],\\n\",\r\n       \"         [7., 7., 7., 7.],\\n\",\r\n       \"         [8., 8., 8., 8.]])],\\n\",\r\n       \" array([1., 1.]))\"\r\n      ]\r\n     },\r\n     \"execution_count\": 76,\r\n     \"metadata\": {},\r\n     \"output_type\": \"execute_result\"\r\n    }\r\n   ],\r\n   \"source\": [\r\n    \"def get_cost_c(conf, L, cluster_info=None):\\n\",\r\n    \"    # homogeneous setting; in real setting, we access cluster to get cost_c\\n\",\r\n    \"    num_stages = int(np.max(conf))\\n\",\r\n    \"    stage_cost = []\\n\",\r\n    \"    for i in range(1, num_stages):\\n\",\r\n    \"        b = np.where(conf == i)\\n\",\r\n    \"        c = np.where(conf == i+1)\\n\",\r\n    \"        # All pairs of GPU in the same node\\n\",\r\n    \"        if (b[1] == c[1]).all():\\n\",\r\n    \"            stage_cost.append(0)\\n\",\r\n    \" 
       else:\\n\",\r\n    \"            stage_cost.append(1)\\n\",\r\n    \"    stage_cost = np.asarray(stage_cost).reshape((1,-1))\\n\",\r\n    \"    ret = copy.deepcopy(stage_cost)\\n\",\r\n    \"    for i in range(L-1):\\n\",\r\n    \"        ret = np.concatenate((ret, stage_cost), axis=0)\\n\",\r\n    \"    return ret\\n\",\r\n    \"\\n\",\r\n    \"def get_cost_e(conf, L, cluster_info=None):\\n\",\r\n    \"    # homogeneous setting; in real setting, we access cluster to get cost_e\\n\",\r\n    \"    # return amp_simulator()\\n\",\r\n    \"    #print(conf.shape[0] * conf.shape[1])\\n\",\r\n    \"    num_gpus_per_pipeline = conf.shape[0] * conf.shape[1] / np.max(conf)\\n\",\r\n    \"    return np.ones(L) / num_gpus_per_pipeline\\n\",\r\n    \"\\n\",\r\n    \"def generate_initial(M, N, threads=2):\\n\",\r\n    \"    h_w_list = []\\n\",\r\n    \"    \\n\",\r\n    \"    h_w_list.append((M, 1))\\n\",\r\n    \"    h_w_list.append((1, N))\\n\",\r\n    \"    known = {}\\n\",\r\n    \"    \\n\",\r\n    \"    configs = []\\n\",\r\n    \"    for (h, w) in h_w_list:\\n\",\r\n    \"        solution = backtrack(M, N, h, w)\\n\",\r\n    \"        \\n\",\r\n    \"        assert len(solution) > 0\\n\",\r\n    \"        config_idx = np.random.choice(len(solution), size=1)[0]\\n\",\r\n    \"        config = solution[config_idx]\\n\",\r\n    \"        configs.append(config)\\n\",\r\n    \"        \\n\",\r\n    \"        solution.pop(config_idx)\\n\",\r\n    \"        \\n\",\r\n    \"        known[(h, w)] = solution\\n\",\r\n    \"        \\n\",\r\n    \"    #print(np.asarray(configs[0]))\\n\",\r\n    \"    return h_w_list, configs, known\\n\",\r\n    \"    \\n\",\r\n    \"\\n\",\r\n    \"def cool_down(iter, max_iter, init_temp):\\n\",\r\n    \"    return init_temp * (1 - iter / max_iter)\\n\",\r\n    \"\\n\",\r\n    \"def neighbor(cur, known, M, N, maximum_try = 10):\\n\",\r\n    \"    h, w = cur\\n\",\r\n    \"    \\n\",\r\n    \"    time_s = time.time()\\n\",\r\n    \"    while 
time.time() - time_s < 10:\\n\",\r\n    \"        index = np.random.choice([0,1], size=1)[0]\\n\",\r\n    \"        if index == 0:\\n\",\r\n    \"            valid = []\\n\",\r\n    \"            upper = min(M, N)\\n\",\r\n    \"            upper = min((M*N) // w, upper) + 1\\n\",\r\n    \"            \\n\",\r\n    \"            for i in range(1, upper):\\n\",\r\n    \"                if (i, w) in known.keys():\\n\",\r\n    \"                    solution = known[(i, w)]\\n\",\r\n    \"                else:\\n\",\r\n    \"                    solution = backtrack(M, N, i, w)\\n\",\r\n    \"                    known[(i, w)] = solution\\n\",\r\n    \"\\n\",\r\n    \"                if len(solution) > 0:\\n\",\r\n    \"                    valid.append(i)\\n\",\r\n    \"\\n\",\r\n    \"            if len(valid) == 0:\\n\",\r\n    \"                continue\\n\",\r\n    \"                #return\\n\",\r\n    \"                \\n\",\r\n    \"            new_h = np.random.choice(valid, size=1)[0]\\n\",\r\n    \"            \\n\",\r\n    \"            # TODO\\n\",\r\n    \"            new_config_idx = np.random.choice(len(known[(new_h, w)]), size=1)[0]\\n\",\r\n    \"            ret = known[(new_h, w)].pop(new_config_idx)\\n\",\r\n    \"            return new_h, w, ret\\n\",\r\n    \"\\n\",\r\n    \"        else:\\n\",\r\n    \"            valid = []\\n\",\r\n    \"            upper = min(M, N)\\n\",\r\n    \"            upper = min((M*N) // h, upper) + 1\\n\",\r\n    \"            for i in range(1, upper):\\n\",\r\n    \"                if (h, i) in known.keys():\\n\",\r\n    \"                    solution = known[(h, i)]\\n\",\r\n    \"                else:\\n\",\r\n    \"                    solution = backtrack(M, N, h, i)\\n\",\r\n    \"                    known[(h, i)] = solution\\n\",\r\n    \"\\n\",\r\n    \"                if len(solution) > 0:\\n\",\r\n    \"                    valid.append(i)\\n\",\r\n    \"\\n\",\r\n    \"            if len(valid) == 0:\\n\",\r\n 
   \"                continue\\n\",\r\n    \"\\n\",\r\n    \"            new_w = np.random.choice(valid, size=1)[0]\\n\",\r\n    \"            new_config_idx = np.random.choice(len(known[(h, new_w)]), size=1)[0]\\n\",\r\n    \"            ret = known[(h, new_w)].pop(new_config_idx)  \\n\",\r\n    \"            return h, new_w, ret\\n\",\r\n    \"    return None\\n\",\r\n    \"    \\n\",\r\n    \"def predict(configs, L, B):\\n\",\r\n    \"    costs = []\\n\",\r\n    \"    for i in range(len(configs)):\\n\",\r\n    \"        config = configs[i]\\n\",\r\n    \"        config = np.asarray(config)\\n\",\r\n    \"        #config = config.reshape((config.shape[0], config.shape[2]))\\n\",\r\n    \"        cost_e = get_cost_e(config, L)\\n\",\r\n    \"        cost_c = get_cost_c(config, L)\\n\",\r\n    \"        k = int(np.max(config))\\n\",\r\n    \"\\n\",\r\n    \"        # refer to pipeling slicing\\n\",\r\n    \"        cost = pipe_dp(L, cost_e, cost_c, k, B)[1]\\n\",\r\n    \"        costs.append(cost)\\n\",\r\n    \"    return np.asarray(costs)\\n\",\r\n    \"\\n\",\r\n    \"# number of GPU per node\\n\",\r\n    \"M = 8\\n\",\r\n    \"# \\n\",\r\n    \"N = 4\\n\",\r\n    \"num_iter = 500\\n\",\r\n    \"init_t = 1\\n\",\r\n    \"\\n\",\r\n    \"# 16 layers network, 3 microbatches\\n\",\r\n    \"L = 16\\n\",\r\n    \"B = 3\\n\",\r\n    \"\\n\",\r\n    \"h_w, configs, known = generate_initial(M, N)\\n\",\r\n    \"costs = predict(configs, L, B)\\n\",\r\n    \"\\n\",\r\n    \"for i in range(num_iter):\\n\",\r\n    \"    cur_t = cool_down(i, num_iter, init_t)   \\n\",\r\n    \"    \\n\",\r\n    \"    new_configs = []\\n\",\r\n    \"    new_h_w = []\\n\",\r\n    \"    \\n\",\r\n    \"    for (h, w) in h_w:\\n\",\r\n    \"        step = neighbor((h, w), known, M, N)\\n\",\r\n    \"        if step is None:\\n\",\r\n    \"            new_h, new_w, new_config = (h, w, configs[h_w.index((h,w))])\\n\",\r\n    \"            \\n\",\r\n    \"        else:\\n\",\r\n    \"            
new_h, new_w, new_config = step\\n\",\r\n    \"        if step is None:\\n\",\r\n    \"            assert False\\n\",\r\n    \"        else:\\n\",\r\n    \"            pass\\n\",\r\n    \"            #print(step)\\n\",\r\n    \"        new_h_w.append((new_h, new_w))\\n\",\r\n    \"        new_configs.append(new_config)\\n\",\r\n    \"        \\n\",\r\n    \"    new_costs = predict(new_configs, L, B)\\n\",\r\n    \"    \\n\",\r\n    \"    acc_prob = np.exp(np.minimum((costs - new_costs)/ (cur_t+1e-5) , 0))\\n\",\r\n    \"    \\n\",\r\n    \"    acc_index = (np.random.random(len(acc_prob)) < acc_prob)\\n\",\r\n    \"    \\n\",\r\n    \"    for j in range(len(configs)):\\n\",\r\n    \"        if acc_index[j]:\\n\",\r\n    \"            configs[j] = new_configs[j]\\n\",\r\n    \"            costs[j] = new_costs[j]\\n\",\r\n    \"\\n\",\r\n    \"configs, costs\"\r\n   ]\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": [\r\n    \"# Scratch code below\"\r\n   ]\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   
\"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": []\r\n  },\r\n  {\r\n   \"cell_type\": \"code\",\r\n   \"execution_count\": null,\r\n   \"metadata\": {},\r\n   \"outputs\": [],\r\n   \"source\": [\r\n    \"def placement_reachable(M, N, m, n, s_joint):\\n\",\r\n    \"    #horizontal_tile = np.asarray(list(range(m * n))).reshape((m, n))\\n\",\r\n    \"    #vertical_tile = np.transpose(horizontal_tile)\\n\",\r\n    \"    horizontal_tile = np.ones((m,n))\\n\",\r\n    \"    vertical_tile = np.ones((n,m))\\n\",\r\n    \"    vertical_tile[0] = 0\\n\",\r\n    \"    \\n\",\r\n    \"    t = True\\n\",\r\n    \"    i = 0\\n\",\r\n    \"    while i < N:\\n\",\r\n    \"        match = False\\n\",\r\n    \"        # Check whether horizontal \\n\",\r\n    \"        if i <= N - n:\\n\",\r\n    \"            for j in range(n-m, n):\\n\",\r\n    \"                #print(s_joint[j:, i:i+n])\\n\",\r\n    \"                match_height = n-j\\n\",\r\n    \"   #            print(match_height)\\n\",\r\n    \"                if (s_joint[j:, i:i+n] == horizontal_tile[:match_height,:]).all():\\n\",\r\n    \"   #                 print(i, j, \\\"h\\\", s_joint[j:, i:i+n], horizontal_tile[:match_height,:], match_height)\\n\",\r\n    \"                    i += n\\n\",\r\n    \"                    if j != n-m:\\n\",\r\n    \"                        t = False\\n\",\r\n    \"                    match = True\\n\",\r\n    \"                    break\\n\",\r\n    \"        \\n\",\r\n    \"        if i <= N - m:\\n\",\r\n    \"            for j in range(n):\\n\",\r\n    \"                #print(s_joint,j,i,m, s_joint[j:, i:i+m])\\n\",\r\n    \"                match_height = n-j\\n\",\r\n    \"                if (s_joint[j:, i:i+m] == vertical_tile[:match_height,:]).all():\\n\",\r\n    \"  #                  print(i, j, \\\"v\\\", s_joint[j:, i:i+n], vertical_tile[:match_height,:], match_height)\\n\",\r\n    \"                    i += m\\n\",\r\n    \"        
            if j != 0:\\n\",\r\n    \"                        t = False\\n\",\r\n    \"                    match = True\\n\",\r\n    \"                    break\\n\",\r\n    \"        \\n\",\r\n    \"        if not match:\\n\",\r\n    \"            return False, _\\n\",\r\n    \"    return True, t\\n\",\r\n    \"\\n\",\r\n    \"# ! Always assume m < n\\n\",\r\n    \"def init(M, N, m, n, s_array):\\n\",\r\n    \"    h, w = s_array.shape\\n\",\r\n    \"    checked = np.zeros((h, w))\\n\",\r\n    \"    i = 0\\n\",\r\n    \"    j = 0\\n\",\r\n    \"#     horizontal_tile = np.asarray(list(range(m * n))).reshape((m, n))\\n\",\r\n    \"#     vertical_tile = np.transpose(horizontal_tile)\\n\",\r\n    \"    horizontal_tile = np.ones((m,n))\\n\",\r\n    \"    vertical_tile = np.ones((n,m))\\n\",\r\n    \"    vertical_tile[0] = 0\\n\",\r\n    \"    \\n\",\r\n    \"    \\n\",\r\n    \"    #print(s_array)\\n\",\r\n    \"    terminate = True\\n\",\r\n    \"    for i in range(h):\\n\",\r\n    \"        for j in range(w):\\n\",\r\n    \"            if checked[i][j] == 1:\\n\",\r\n    \"                continue\\n\",\r\n    \"                \\n\",\r\n    \"            # Check horizontal\\n\",\r\n    \"            if i <= M - m and j <= N - n:\\n\",\r\n    \"                match_height = min(h-i, m)\\n\",\r\n    \"                if (s_array[i:i+match_height, j:j+n] == horizontal_tile[:match_height,:]).all() and (checked[i:i+m, j:j+n] != 1).all():\\n\",\r\n    \"                    checked[i:i + m, j: j + n] = 1\\n\",\r\n    \"                    if match_height != m:\\n\",\r\n    \"                        terminate = False\\n\",\r\n    \"                    continue\\n\",\r\n    \"            \\n\",\r\n    \"            # Check vertical\\n\",\r\n    \"            if i <= M - n and j <= N - m:\\n\",\r\n    \"                match_height = min(h-i, n)\\n\",\r\n    \"                if (s_array[i:i+match_height, j:j+m] == vertical_tile[:match_height,:]).all() and (checked[i:i+n, 
j:j+m] != 1).all():\\n\",\r\n    \"                    checked[i:i + n, j: j + m] = 1\\n\",\r\n    \"                    if match_height != n:\\n\",\r\n    \"                        terminate = False\\n\",\r\n    \"                    continue\\n\",\r\n    \"            #print(i, j, s_array, checked)\\n\",\r\n    \"            return False, _\\n\",\r\n    \"    return True, terminate\\n\",\r\n    \"        \\n\",\r\n    \"# returns all possible pipe group configurations\\n\",\r\n    \"def generate_placement(grid, len_1, len_2):\\n\",\r\n    \"    tot_len = len_1 * len_2\\n\",\r\n    \"    # possible configuration number for a row\\n\",\r\n    \"    from itertools import product\\n\",\r\n    \"    #possible_s = list(product(range(tot_len),repeat = grid.shape[1]*(len_2-1)))\\n\",\r\n    \"    #single_possible_s = list(product(list(range(tot_len)),repeat = grid.shape[1]))\\n\",\r\n    \"    \\n\",\r\n    \"    possible_s = list(product(range(2),repeat = grid.shape[1]*(len_2-1)))\\n\",\r\n    \"    single_possible_s = list(product(list(range(2)),repeat = grid.shape[1]))\\n\",\r\n    \"    \\n\",\r\n    \"    #print(possible_s, single_possible_s)\\n\",\r\n    \"    for i in range(len(possible_s)):\\n\",\r\n    \"        possible_s[i] = np.asarray(list(possible_s[i])).reshape(1,-1)\\n\",\r\n    \"    \\n\",\r\n    \"    for i in range(len(single_possible_s)):\\n\",\r\n    \"        single_possible_s[i] = np.asarray(list(single_possible_s[i])).reshape(1,-1)\\n\",\r\n    \"    \\n\",\r\n    \"    \\n\",\r\n    \"    # the solution will be the union of all possible configurations in the last row\\n\",\r\n    \"    dp = [[None for j in range(len(possible_s))] for i in range(grid.shape[0])]\\n\",\r\n    \"    \\n\",\r\n    \"    # initialize the first (len_1 -1) row\\n\",\r\n    \"    for s_index in range(len(possible_s)):\\n\",\r\n    \"        valid, terminate = init(grid.shape[0], grid.shape[1], len_1, len_2, possible_s[s_index].reshape(-1, grid.shape[1]))\\n\",\r\n    \"  
      if valid:\\n\",\r\n    \"            dp[0][s_index] =  [(possible_s[s_index].reshape(-1, grid.shape[1]), terminate)]\\n\",\r\n    \"            #print(possible_s[s_index])\\n\",\r\n    \"    print(dp[0])\\n\",\r\n    \"    # dp by row index\\n\",\r\n    \"    for i in range(len_2-1, grid.shape[0]):\\n\",\r\n    \"        print(\\\" \\\")\\n\",\r\n    \"        print(dp[i-1], i)\\n\",\r\n    \"        print(\\\" \\\")\\n\",\r\n    \"        # iterate through all possibly reachable row?\\n\",\r\n    \"        #j = i - 1\\n\",\r\n    \"        for s_index_1 in range(len(possible_s)):\\n\",\r\n    \"      #      print(\\\"haha\\\", s_index_1, len(possible_s))\\n\",\r\n    \"            for s_index_2 in range(len(single_possible_s)):\\n\",\r\n    \"                s_1 = possible_s[s_index_1]\\n\",\r\n    \"                s_2 = single_possible_s[s_index_2]\\n\",\r\n    \"               # print(s_1, s_2)\\n\",\r\n    \"                s_joint = np.concatenate((s_1, s_2), axis=0)\\n\",\r\n    \"                # early return if the last rows themselves are not possible\\n\",\r\n    \"                #print(s_joint, valid)\\n\",\r\n    \"                if dp[i-1][s_index_1] is None:\\n\",\r\n    \"                    print(i-1, s_index_1)\\n\",\r\n    \"                    continue\\n\",\r\n    \"                    \\n\",\r\n    \"                #valid, terminate =  placement_reachable(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\\n\",\r\n    \"                #valid, terminate =  init(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\\n\",\r\n    \"                valid, terminate =  placement_reachable(grid.shape[0], grid.shape[1], len_1, len_2, s_joint)\\n\",\r\n    \"     #           print(s_joint, valid)\\n\",\r\n    \"                if valid:\\n\",\r\n    \"                    if dp[i][s_index_2] is None:\\n\",\r\n    \"                        dp[i][s_index_2] = []\\n\",\r\n    \"                    for solution in dp[i-1][s_index_1]:\\n\",\r\n   
 \"                        #print(i-1,solution)\\n\",\r\n    \"                        sol, _ = solution\\n\",\r\n    \"                        s_joint_sol = np.concatenate((copy.deepcopy(sol), s_2), axis=0)\\n\",\r\n    \"                        dp[i][s_index_2].append((s_joint_sol, terminate))\\n\",\r\n    \"#     print(dp[0])\\n\",\r\n    \"#     print(dp[1])\\n\",\r\n    \"#     print(dp[2])\\n\",\r\n    \"    ret_sol = []\\n\",\r\n    \"    for i in range(len(single_possible_s)):\\n\",\r\n    \"        s = possible_s[i]\\n\",\r\n    \"        if dp[grid.shape[0]-1][i] is None:\\n\",\r\n    \"            continue\\n\",\r\n    \"        for (sol, t) in dp[grid.shape[0]-1][i]:\\n\",\r\n    \"            if t:\\n\",\r\n    \"                ret_sol.append(sol)\\n\",\r\n    \"    return ret_sol\\n\",\r\n    \"\\n\",\r\n    \"# for len_1 in factors:\\n\",\r\n    \"#     # Genarate all possible configuratinos\\n\",\r\n    \"#     remain = num_gpu / len_1\\n\",\r\n    \"#     factors_2 = []\\n\",\r\n    \"#     for i in range(1, min(cluster_shape) + 1):\\n\",\r\n    \"#         if remain % i == 0:\\n\",\r\n    \"#             factors_2.append(i)\\n\",\r\n    \"#         for len_2 in factors_2:\\n\",\r\n    \"#             num_cut = num_gpu / (len_1*len_2)\\n\",\r\n    \"#             confs = generate_placement(grid, len_1. 
len_2)\\n\",\r\n    \"#             for conf in confs:\\n\",\r\n    \"#                 cost_c = get_cost_c(conf)\\n\",\r\n    \"#                 cost_e = get_cost_e(conf)\\n\",\r\n    \"#                 opt_pipe = pipe_dp(L, cost_e, cost_c, num_cut, B)\\n\",\r\n    \"#                 cost = amp_simulator(conf, opt_pipe)\"\r\n   ]\r\n  }\r\n ],\r\n \"metadata\": {\r\n  \"kernelspec\": {\r\n   \"display_name\": \"Python 3\",\r\n   \"language\": \"python\",\r\n   \"name\": \"python3\"\r\n  },\r\n  \"language_info\": {\r\n   \"codemirror_mode\": {\r\n    \"name\": \"ipython\",\r\n    \"version\": 3\r\n   },\r\n   \"file_extension\": \".py\",\r\n   \"mimetype\": \"text/x-python\",\r\n   \"name\": \"python\",\r\n   \"nbconvert_exporter\": \"python\",\r\n   \"pygments_lexer\": \"ipython3\",\r\n   \"version\": \"3.8.3\"\r\n  }\r\n },\r\n \"nbformat\": 4,\r\n \"nbformat_minor\": 2\r\n}"
  },
  {
    "path": "playground/pipeline/profile_compilation.py",
    "content": "import numpy as np\nfrom time import time\nfrom flax import linen as nn, optim\nimport jax\nfrom jax._src.api import make_jaxpr\nimport jax.numpy as jnp\nimport ray\n\nfrom alpa import DeviceCluster, manual_layer_slicing, mark_pipeline\nfrom alpa.device_mesh import VirtualPhysicalMesh\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer\nfrom alpa.pipeline_parallel.three_d_parallel import (\n    split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks,\n    mark_missing_vars_in_backward_computation_pipeline_marks)\nfrom alpa.pipeline_parallel.stage_construction import get_submesh_choices, dp, get_sliced_virtual_submeshes, get_compute_cost, get_stage_and_mesh_assignments\n\nray.init(address=\"auto\")\njax.config.update('jax_platform_name', 'cpu')\nvirtual_mesh = DeviceCluster().get_virtual_physical_mesh()\n\n\nN = 10\nclass BertLayer_Model(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layers = [FlaxBertLayer(config=self.config, dtype=self.dtype) for _ in range(N)]\n\n    def __call__(self, x, attention_mask):\n        for i in range(N):\n            mark_pipeline(name=str(i), mark_type='start')\n            layer_outputs = self.layers[i](x, attention_mask)\n            x = layer_outputs[0]\n            if i != N - 1:\n                mark_pipeline(name=str(i), mark_type='end')\n        return x\n\n\ndef train_step(optimizer, batch, apply_fn):\n\n    @manual_layer_slicing\n    def loss_func(params, x, y, attention_mask):\n        out = apply_fn(params, x, attention_mask)\n        loss = jnp.mean((out - y)**2)\n        mark_pipeline(name=str(N - 1), mark_type='end')\n        return loss\n\n    grad_param = jax.grad(loss_func)(optimizer.target, batch['x'], batch['y'],\n                                     batch['attention_mask'])\n\n    # new_optimizer = optimizer.apply_gradient(grad_param)\n    return grad_param\n\n\nbatch_size = 4\nseq_len = 64\nhidden_size = 
256\nnum_heads = 1\nx = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32)\ny = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23\nattention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32)\n\nmodel = BertLayer_Model(config=BertConfig(hidden_size=hidden_size,\n                                          intermediate_size=hidden_size * 4,\n                                          num_attention_heads=num_heads))\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x, attention_mask)\noptimizer = optim.GradientDescent(1e-2).create(params)\nbatch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n\n\norigin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch,\n                                                           model.apply)\ncompute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr)\nstages = slice_closed_jaxpr_by_full_pipeline_marks(compute_jaxpr)\nstages = mark_missing_vars_in_backward_computation_pipeline_marks(stages, compute_jaxpr.jaxpr.invars,\n                                                                  compute_jaxpr.jaxpr.outvars)\n\n\ndonation_mapping = {}\nglobal_invars = compute_jaxpr.jaxpr.invars\nglobal_outvars = compute_jaxpr.jaxpr.outvars\nall_invars = [set(stage.invars) for stage in stages]\nprint(compute_jaxpr)\nprint(all_invars)\n\nvirtual_mesh = DeviceCluster().get_virtual_physical_mesh()\n\nsubmesh_choices = get_submesh_choices(virtual_mesh)\n\nM = len(submesh_choices)\ncompute_cost = np.full((N, N, M), np.inf)\n\ncompute_cost = get_compute_cost(virtual_mesh, submesh_choices, stages, donation_mapping, global_outvars)\n\nprint(\"profiled compute cost\", compute_cost)\n\ncompute_cost = np.array(\n[[[0.00112862, 0.00207896, 0.00304582, 0.00409389, 0.00481757, 0.0058842 , 0.00729934, 0.00901646, 0.01083485, 0.01064126],\n [    np.inf, 0.00105063, 0.00192263, 0.00338936, 0.00393539, 0.00490199, 0.00584266, 0.0072612 , 0.00946384, 0.01016763],\n [    np.inf,   
  np.inf, 0.00129975, 0.00242482, 0.00291726, 0.00394379, 0.00500327, 0.00620286, 0.0075642 , 0.00776463],\n [    np.inf,     np.inf,     np.inf, 0.00107974, 0.00194375, 0.00296365, 0.00394927, 0.00489317, 0.0060268 , 0.00686378],\n [    np.inf,     np.inf,     np.inf,     np.inf, 0.00113273, 0.00208476, 0.00312124, 0.00414051, 0.00488673, 0.00603056],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00115853, 0.00214725, 0.00309205, 0.00406925, 0.00486824],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.0011634 , 0.00212847, 0.00300874, 0.00403778],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00113964, 0.00209594, 0.00295475],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00112536, 0.00208275],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00113214],],\n[[0.0030249 , 0.00583315, 0.00871592, 0.01152415, 0.01424082, 0.01615058, 0.01970495, 0.02182685, 0.02624578, 0.02759846],\n [    np.inf, 0.00283125, 0.00541072, 0.00810671, 0.0113883 , 0.0142146 , 0.01630463, 0.01949045, 0.02265135, 0.02431562],\n [    np.inf,     np.inf, 0.00275834, 0.00543684, 0.00856792, 0.01125206, 0.01419446, 0.01846258, 0.01882169, 0.02256897],\n [    np.inf,     np.inf,     np.inf, 0.00282031, 0.00544018, 0.00806549, 0.01151021, 0.01445823, 0.01596944, 0.01954889],\n [    np.inf,     np.inf,     np.inf,     np.inf, 0.00288251, 0.00546715, 0.00849128, 0.01137638, 0.01331025, 0.01597357],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00281795, 0.00563383, 0.00851236, 0.01133339, 0.01377805],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.0027566 , 0.00544667, 0.00806091, 0.01041269],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00283482, 0.00553597, 0.00840436],\n [    np.inf,     np.inf,     
np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00294116, 0.00520253],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00248777],],\n[[0.00318106, 0.00561643, 0.00816067, 0.01074386, 0.01330863, 0.01584069, 0.01861776, 0.02112714, 0.02398107, 0.02674866],\n [    np.inf, 0.00313836, 0.00568464, 0.00836942, 0.01092143, 0.01332755, 0.015868  , 0.01875334, 0.0215208 , 0.02460371],\n [    np.inf,     np.inf, 0.00307181, 0.00560925, 0.00822319, 0.01079559, 0.01324073, 0.0162802 , 0.01885197, 0.02085225],\n [    np.inf,     np.inf,     np.inf, 0.00309396, 0.00569873, 0.00842341, 0.01113261, 0.01343475, 0.01580254, 0.01800921],\n [    np.inf,     np.inf,     np.inf,     np.inf, 0.00313062, 0.00563579, 0.00816891, 0.01091221, 0.01354008, 0.01555475],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00304008, 0.00569354, 0.00829389, 0.01103203, 0.01338752],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00318387, 0.00579458, 0.00826253, 0.01069681],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00314818, 0.00580152, 0.00824009],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00310455, 0.005536  ],\n [    np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf,     np.inf, 0.00285437],]]\n).transpose((1, 2, 0))\n\nprint(compute_cost.shape, (N, N, M))\nprint(\"previously tested compute cost\", compute_cost)\n\ncost, solution = dp(N, virtual_mesh.num_devices, batch_size, submesh_choices, compute_cost)\nprint(\"-\" * 30, \"Solution\", \"-\" * 30)\nprint(\"Cost:\", cost)\nprint(solution)\n\nsliced_meshes = get_sliced_virtual_submeshes(virtual_mesh, submesh_choices, solution)\nprint(\"sliced_meshes\", sliced_meshes)\n\nsolution, sliced_meshes = get_stage_and_mesh_assignments(virtual_mesh, stages, donation_mapping, 
global_outvars, batch_size)\nprint(\"solution, sliced_meshes\", solution, sliced_meshes)\n\nray.shutdown()\n"
  },
  {
    "path": "playground/pipeline/test_acc_grad.py",
    "content": "import jax\nfrom jax import jit, grad, tree_flatten\nfrom jax._src.api import make_jaxpr\nfrom jax.core import DropVar, jaxpr_as_fun, gensym\nimport jax.numpy as jnp\nimport numpy as np\n\nimport alpa\nfrom alpa.pipeline_parallel.manual_layer_slicing import manual_layer_slicing\nfrom alpa.pipeline_parallel.computation import (\n    apply_grad_add_marker, compute_grad_to_accumulate_grad, apply_grad_get_mean,\n    get_var_mapping, slice_closed_jaxpr_by_full_pipeline_marks,\n    mark_missing_vars_in_backward_computation_pipeline_marks, mark_gradvar_to_mesh, slice_apply_gradient,\n    replace_all_with)\nfrom alpa.pipeline_parallel.three_d_parallel import split_compute_grad_and_apply_grad, split_donate_invars\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline\n\nfrom flax import linen as nn, optim\n\nfrom copy import copy\n\n\nclass MLP_Model(nn.Module):\n    hidden_dim: int\n    output_dim: int\n\n    @nn.compact\n    def __call__(self, x):\n        mark_pipeline(name='1', mark_type='start')\n        x = nn.Dense(features=self.hidden_dim, use_bias=False)(x)\n        x = nn.relu(x)\n        mark_pipeline(name='1', mark_type='end')\n        mark_pipeline(name='2', mark_type='start')\n        x = nn.Dense(features=self.output_dim, use_bias=False)(x)\n        return x\n\n\nbatch_size = 4\nhidden_dim = 3\ninput_dim = output_dim = hidden_dim\nmodel = MLP_Model(hidden_dim=hidden_dim, output_dim=output_dim)\nx = jnp.array(np.random.rand(batch_size, output_dim))\ny = jnp.array(np.random.rand(batch_size, output_dim))\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x)\noptimizer = optim.GradientDescent(1e-2).create(params)\nbatch = {\"x\": x, \"y\": y}\ngrad_in_to_out = None\n\n\n@manual_layer_slicing\ndef loss_func(params, x, y):\n    out = model.apply(params, x)\n    loss = jnp.mean((out - y)**2)\n    mark_pipeline(name='2', mark_type='end')\n    return loss\n\n\ndef train_step(optimizer, batch):\n    grad_param, _x, _y = 
alpa.grad(loss_func,\n                                    argnums=(0, 1, 2))(optimizer.target,\n                                                       batch['x'], batch['y'])\n    new_optimizer = optimizer.apply_gradient(grad_param)\n    return new_optimizer\n\n\ndef test_compute_to_accumulate():\n    compute_grad = grad(loss_func, argnums=(0, 1, 2))\n    params = optimizer.target\n    compute_grad_jaxpr = make_jaxpr(compute_grad)(params, x, y)\n    gensym_fn = gensym([compute_grad_jaxpr.jaxpr])\n    flatten_args, _ = tree_flatten((params, x, y))\n    reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars)\n    acc_grad_jaxpr, grad_outs, _ = compute_grad_to_accumulate_grad(compute_grad_jaxpr,\n                                                                   reduction_vector,\n                                                                   gensym_fn)\n    grad_zeros = [jnp.zeros_like(val) for val in acc_grad_jaxpr.out_avals]\n    # donate_argnums = [\n    #     i for i in range(len(donated_invars)) if donated_invars[i]\n    # ]\n    args = params, x, y\n    new_args = flatten_args + grad_zeros\n    jitted_fn = jit(jaxpr_as_fun(acc_grad_jaxpr))\n    outs = jitted_fn(*new_args)\n\n    new_args = flatten_args + list(outs)\n    double_outs = jitted_fn(*new_args)\n\n    correct = map(lambda x: 2 * x, tree_flatten(compute_grad(*args))[0])\n    for test, corr in zip(double_outs, correct):\n        assert jnp.allclose(test, corr)\n\n\ndef get_invals_from_env(closed_jaxpr, env, batch_num=0):\n    vars = closed_jaxpr.jaxpr.invars\n    if batch_num == 0:\n        return [env[batch_num][repr(var)] for var in vars]\n    vals = []\n    for var in vars:\n        if var in grad_in_to_out:\n            vals.append(env[batch_num - 1][grad_in_to_out[var]])\n        else:\n            vals.append(env[batch_num][repr(var)])\n    return vals\n\n\ndef get_vals_from_env(vars, env, batch_num=0):\n    return [env[batch_num][repr(var)] for var in vars]\n\n\ndef 
record_values(vars, avals, env, batch_num=0):\n    for var, aval in zip(vars, avals):\n        if isinstance(var, DropVar):\n            continue\n        key = repr(var)\n        if key in env[batch_num]:\n            assert jnp.allclose(env[batch_num][key], aval)\n        env[batch_num][key] = aval\n\n\ndef get_and_set(closed_jaxpr, env, batch_num=0, donate_argnums=()):\n    outs = jax.jit(jaxpr_as_fun(closed_jaxpr), donate_argnums=donate_argnums)(\n        *get_invals_from_env(closed_jaxpr, env, batch_num))\n    record_values(closed_jaxpr.jaxpr.outvars, outs, env, batch_num)\n\n\ndef test_compute_and_apply_basic():\n    closed_jaxpr = make_jaxpr(train_step)(optimizer, batch)\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    compute_grad_jaxpr, old_apply_grad_jaxpr, barrier = split_compute_grad_and_apply_grad(\n        closed_jaxpr)\n    # compute grad to accumulate grad\n    reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars)\n    acc_grad_jaxpr, acc_grad_dict, _ = compute_grad_to_accumulate_grad(\n        compute_grad_jaxpr, reduction_vector, gensym_func)\n    # apply-grad\n    mask = {\n        outv: acc_grad_dict[inv]\n        for outv, inv in zip(barrier.outvars, barrier.invars)\n        if (not isinstance(outv, DropVar) and\n            outv in old_apply_grad_jaxpr.jaxpr.invars)\n    }\n    # change invars of apply grad to output of accumulate grad\n    apply_grad_jaxpr = replace_all_with(old_apply_grad_jaxpr, mask)\n\n    # Simulation:\n    # correct result:\n    args, _ = tree_flatten((optimizer, batch))\n    env = [dict()]\n    record_values(closed_jaxpr.jaxpr.invars, args, env)\n    correct = jaxpr_as_fun(closed_jaxpr)(\n        *get_invals_from_env(closed_jaxpr, env))\n    # Test 1: split compute and apply\n    env_1 = copy(env)\n    get_and_set(compute_grad_jaxpr, env_1)\n    for inv, outv in zip(barrier.invars, barrier.outvars):\n        if isinstance(outv, DropVar):\n            continue\n        key = repr(inv)\n        if key in 
env_1[0]:\n            env_1[0][repr(outv)] = env_1[0][key]\n    get_and_set(old_apply_grad_jaxpr, env_1)\n    outs = get_vals_from_env(closed_jaxpr.jaxpr.outvars, env_1)\n    for t, c in zip(outs, correct):\n        assert jnp.allclose(t, c)\n    del env_1\n    # Test 2: accumulate and apply\n    env_2 = copy(env)\n    grad_num = len(acc_grad_jaxpr.out_avals)\n    grad_invars = set(acc_grad_jaxpr.jaxpr.invars[-1 * grad_num:])\n    for inv in acc_grad_jaxpr.jaxpr.invars:\n        key = repr(inv)\n        if key not in env_2[0]:\n            assert inv in grad_invars\n            env_2[0][key] = jnp.zeros_like(inv.aval)\n    get_and_set(acc_grad_jaxpr, env_2)\n    get_and_set(apply_grad_jaxpr, env_2)\n    outs = get_vals_from_env(closed_jaxpr.jaxpr.outvars, env_2)\n    for t, c in zip(outs, correct):\n        assert jnp.allclose(t, c)\n\n\ndef donate_invars_to_argnums(donate_invars):\n    return [i for i, d in enumerate(donate_invars) if d]\n\n\ndef test_compute_and_apply(microbatches):\n    closed_jaxpr = make_jaxpr(train_step)(optimizer, batch)\n    gensym_func = gensym([closed_jaxpr.jaxpr])\n    compute_grad_jaxpr, apply_grad_jaxpr, barrier = split_compute_grad_and_apply_grad(\n        closed_jaxpr)\n    # compute grad to accumulate grad\n    global grad_in_to_out\n    reduction_vector = [True] * len(compute_grad_jaxpr.jaxpr.outvars)\n    acc_grad_jaxpr, acc_grad_dict, grad_glob_in = compute_grad_to_accumulate_grad(\n        compute_grad_jaxpr, reduction_vector, gensym_func)\n    grad_in_to_out = grad_glob_in\n    # slice accumulate grad\n    acc_invars = acc_grad_jaxpr.jaxpr.invars\n    acc_outvars = acc_grad_jaxpr.jaxpr.outvars\n    jax_pipeline_stages = slice_closed_jaxpr_by_full_pipeline_marks(\n        acc_grad_jaxpr)\n    jax_pipeline_stages = mark_missing_vars_in_backward_computation_pipeline_marks(\n        jax_pipeline_stages, acc_invars, acc_outvars)\n    # delete the two lines below in auto mesh version\n    stage_num = len(jax_pipeline_stages)\n    
assert stage_num % 2 == 0\n    stage_to_mesh = {\n        i: (i if i < stage_num / 2 else stage_num - i - 1)\n        for i, _ in enumerate(jax_pipeline_stages)\n    }\n    mesh_num = int(stage_num / 2)\n    # apply-grad\n    mask = {\n        outv: acc_grad_dict[inv]\n        for outv, inv in zip(barrier.outvars, barrier.invars)\n        if not isinstance(outv, DropVar)\n    }\n    # slice apply-grad stages\n    global_outvars = closed_jaxpr.jaxpr.outvars\n    grad_mesh = mark_gradvar_to_mesh(apply_grad_jaxpr.jaxpr.invars,\n                                     jax_pipeline_stages, stage_to_mesh, mask)\n    gradients = [g for g in barrier.outvars if not isinstance(g, DropVar)]\n    apply_grad_jaxpr, global_outvars = apply_grad_get_mean(apply_grad_jaxpr,\n                                                       gradients,\n                                                       gensym_func,\n                                                       microbatches,\n                                                       global_outvars)\n    sliced_apply_grad, _ = slice_apply_gradient(apply_grad_jaxpr, grad_mesh,\n                                                mesh_num)\n    sliced_apply_grad, outvar_map = apply_grad_add_marker(sliced_apply_grad,\n                                                          mask,\n                                                          gensym_func,\n                                                          computation=True)\n    global_outvars = list(\n        map(lambda x: get_var_mapping(outvar_map, x), global_outvars))\n    # donate invars\n    donated_invars = (True, True, True, False, False)\n    slice_num = len(sliced_apply_grad)\n    grad_invars = list(grad_glob_in.keys())\n    all_invars = closed_jaxpr.jaxpr.invars + grad_invars\n    all_donation = donated_invars + (True,) * len(grad_glob_in)\n    jax_all_stages = jax_pipeline_stages + sliced_apply_grad\n    # forward, backward and apply gradient is serialized in a batch.\n    pattern 
= [[i, i + slice_num, i + slice_num * 2] for i in range(slice_num)]\n    donate_lists = split_donate_invars(all_donation, all_invars, jax_all_stages,\n                                       pattern)\n    pipe_donate = donate_lists[:slice_num * 2]\n    apply_donate = donate_lists[slice_num * 2:]\n    # Simulation:\n    # correct result:\n    args, _ = tree_flatten((optimizer, batch))\n    env = [dict()]\n    record_values(closed_jaxpr.jaxpr.invars, args, env)\n    correct = jaxpr_as_fun(closed_jaxpr)(\n        *get_invals_from_env(closed_jaxpr, env))\n    # Test 3: slices\n    # slices:\n    env = [dict() for _ in range(microbatches)]\n    non_split_args = tree_flatten(optimizer)[0]\n    to_split_args = tree_flatten(batch)[0]\n    # this is a rough simulator, so not actually split them but run m times instead\n    # split_args = map(lambda x: jnp.split(x, microbatches), to_split_args)\n    for b in range(microbatches):\n        args = non_split_args + to_split_args\n        record_values(closed_jaxpr.jaxpr.invars, args, env, b)\n    record_values(closed_jaxpr.jaxpr.invars, args, env)\n    env_3 = copy(env)\n    grad_num = len(acc_grad_jaxpr.out_avals)\n    grad_invars = set(acc_grad_jaxpr.jaxpr.invars[-1 * grad_num:])\n    for invar in acc_grad_jaxpr.jaxpr.invars:\n        key = repr(invar)\n        if key not in env_3[0]:\n            assert invar in grad_invars\n            env_3[0][key] = jnp.zeros_like(invar.aval)\n\n    for b in range(microbatches):\n        for i, stage in enumerate(jax_pipeline_stages):\n            get_and_set(stage.closed_jaxpr(), env_3, b)\n    # store results of apply grad into microbatches - 1\n    for i, stage in enumerate(sliced_apply_grad):\n        if stage.outvars:\n            get_and_set(stage.closed_jaxpr(), env_3, microbatches - 1)\n    outs = get_vals_from_env(global_outvars, env_3, microbatches - 1)\n    for t, c in zip(outs, correct):\n        assert jnp.allclose(t, c)\n    grad_in_to_out = 
None\n\n\ntest_compute_to_accumulate()\ntest_compute_and_apply_basic()\ntest_compute_and_apply(1)\ntest_compute_and_apply(4)"
  },
  {
    "path": "playground/pipeline/test_compile_and_profile.py",
    "content": "from flax import linen as nn, optim\nimport jax\nfrom jax._src.api import make_jaxpr\nimport jax.numpy as jnp\nimport ray\n\nfrom alpa import DeviceCluster, manual_layer_slicing, mark_pipeline\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer\n\n\nclass BertLayer_Model(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layer0 = FlaxBertLayer(config=self.config, dtype=self.dtype)\n        self.layer1 = FlaxBertLayer(config=self.config, dtype=self.dtype)\n\n    def __call__(self, x, attention_mask):\n        mark_pipeline(name='1', mark_type='start')\n        layer_outputs = self.layer0(x, attention_mask)\n        x = layer_outputs[0]\n        mark_pipeline(name='1', mark_type='end')\n        mark_pipeline(name='2', mark_type='start')\n        layer_outputs = self.layer1(x, attention_mask)\n        x = layer_outputs[0]\n        return x\n\n\nray.init(address=\"auto\")\njax.config.update('jax_platform_name', 'cpu')\nvirtual_mesh = DeviceCluster().get_virtual_physical_mesh()\n\n\ndef train_step(optimizer, batch, apply_fn):\n\n    def loss_func(params, x, y, attention_mask):\n        out = apply_fn(params, x, attention_mask)\n        loss = jnp.mean((out - y)**2)\n        mark_pipeline(name='2', mark_type='end')\n        return loss\n\n    loss_func = manual_layer_slicing(loss_func)\n    grad_param = jax.grad(loss_func)(optimizer.target, batch['x'], batch['y'],\n                                     batch['attention_mask'])\n\n    # new_optimizer = optimizer.apply_gradient(grad_param)\n    return grad_param\n\n\nInc = 1\nbatch_size = 2 * Inc\nseq_len = 64 * Inc\nhidden_size = 256 * Inc\nnum_heads = 1\n\nx = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32)\ny = jnp.ones((batch_size, seq_len, hidden_size),\n             dtype=jnp.float32) * 23  # * np.arange(hidden_size)[None, None, :]\nattention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32)\n\n# Init model and 
optimizer\nmodel = BertLayer_Model(config=BertConfig(hidden_size=hidden_size,\n                                          intermediate_size=hidden_size * 4,\n                                          num_attention_heads=num_heads))\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x, attention_mask)\noptimizer = optim.GradientDescent(1e-2).create(params)\nbatch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n\norigin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch,\n                                                           model.apply)\n\n\ndef dummy_large_trans(*args):\n\n    @manual_layer_slicing\n    def dummy_fwd(x, y, z, tgt):\n        mark_pipeline(name='1', mark_type='start')\n        out = x @ y\n        mark_pipeline(name='1', mark_type='end')\n        mark_pipeline(name='2', mark_type='start')\n        out = out @ z\n        loss = jnp.mean((out - tgt)**2)\n        mark_pipeline(name='2', mark_type='end')\n        return loss\n\n    grad = jax.grad(dummy_fwd)(*args)\n    return grad\n\n\nN = 16384\nargs = [jnp.zeros((N, N)) for _ in range(4)]\n\norigin_jaxpr = make_jaxpr(dummy_large_trans)(*args)\n\nfrom alpa.pipeline_parallel.three_d_parallel import (\n    split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks,\n    mark_missing_vars_in_backward_computation_pipeline_marks)\nfrom alpa.pipeline_parallel.stage_profiling import (\n    compile_and_profile_stage_compute_cost, create_collective_group,\n    profile_layer_communication_cost)\n\ncompute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr)\nstages = slice_closed_jaxpr_by_full_pipeline_marks(compute_jaxpr)\nstages = mark_missing_vars_in_backward_computation_pipeline_marks(stages, compute_jaxpr.jaxpr.invars,\n                                                                  compute_jaxpr.jaxpr.outvars)\n# for stage in stages:\n#     print(stage.closed_jaxpr())\n'''----------------profile cost c----------------'''\n# round = 
1\n# physical_mesh = DeviceCluster().get_physical_mesh()\n# tn = \"compute1\"\n# timers(tn).start()\n# for t in range(round):\n#     print(compile_and_profile_stage_compute_cost((stages[0], stages[3]), physical_mesh)[0])\n# timers(tn).stop()\n# print(timers(tn).elapsed())\n# tn = \"compute2\"\n# timers(tn).start()\n# for t in range(round):\n#     print(compile_and_profile_stage_compute_cost((stages[1], stages[2]), physical_mesh)[0])\n# timers(tn).stop()\n# print(timers(tn).elapsed())\n'''----------------profile cost e----------------'''\nsrc = stages[0]\ndst = stages[1]\nsrc_mesh = virtual_mesh.slice_1d(1, [[0, 1]])\nsrc_phy_mesh = src_mesh.get_physical_mesh()\ndst_mesh = virtual_mesh.slice_1d(1, [[2, 3]])\ndst_phy_mesh = dst_mesh.get_physical_mesh()\n\n\ndef all_outvar(stages):\n    ret = set()\n    for stage in stages:\n        ret.update(stage.outvars)\n    return ret\n\n\ntest_stages = (stages[0], stages[3])\ncost_c1, _, out_spec = compile_and_profile_stage_compute_cost(\n    test_stages, src_phy_mesh, {}, all_outvar(test_stages))\ntest_stages = (stages[1], stages[2])\ncost_c2, in_spec, _ = compile_and_profile_stage_compute_cost(\n    test_stages, dst_phy_mesh, {}, all_outvar(test_stages))\n\n# print(cost_c1, cost_c2)\nsrc_phy_mesh.sync_workers()\ndst_phy_mesh.sync_workers()\ncollective_group = create_collective_group(src_phy_mesh, dst_phy_mesh)\n\ncost_e = profile_layer_communication_cost(stages[0], stages[1], out_spec[0],\n                                          in_spec[0], src_mesh, dst_mesh,\n                                          collective_group)\n\nprint(cost_e)\ncollective_group.destroy()\nsrc_phy_mesh.shutdown()\ndst_phy_mesh.shutdown()\nray.shutdown()\n\n# LnkCap: Port #2, Speed 8GT/s, Width x16, ASPM not supported, Exit Latency L0s <512ns, L1 <4us\n# LnkSta: Speed 2.5GT/s, Width x8, TrErr- Train- SlotClk+ DLActive- BWMgmt- ABWMgmt-"
  },
  {
    "path": "playground/pipeline/test_distributed_compile.py",
    "content": "from flax import linen as nn, optim\nimport jax\nfrom jax._src.api import make_jaxpr\nfrom jax.core import gensym\nimport jax.numpy as jnp\nfrom alpa.mesh_executable import NormalMeshDriverExecutable, ProtoAndSharding\nfrom alpa.pipeline_parallel.apply_grad import compute_grad_to_accumulate_grad\nimport ray\n\nfrom alpa import DeviceCluster, manual_layer_slicing, mark_pipeline\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer\nfrom alpa.pipeline_parallel.stage_profiling import (compile_all,\n                                                     generate_stage_info,\n                                                     split_global_use_and_donate)\nfrom alpa.pipeline_parallel.three_d_parallel import (\n    split_compute_grad_and_apply_grad, slice_closed_jaxpr_by_full_pipeline_marks,\n    mark_missing_vars_in_backward_computation_pipeline_marks)\n\nray.init(address=\"auto\")\njax.config.update('jax_platform_name', 'cpu')\nvirtual_mesh = DeviceCluster().get_virtual_physical_mesh()\n\nN = 10\n\n\nclass BertLayer_Model(nn.Module):\n    config: BertConfig\n    dtype: jnp.dtype = jnp.float32\n\n    def setup(self):\n        self.layers = [\n            FlaxBertLayer(config=self.config, dtype=self.dtype)\n            for _ in range(N)\n        ]\n\n    def __call__(self, x, attention_mask):\n        for i in range(N):\n            mark_pipeline(name=str(i), mark_type='start')\n            layer_outputs = self.layers[i](x, attention_mask)\n            x = layer_outputs[0]\n            if i != N - 1:\n                mark_pipeline(name=str(i), mark_type='end')\n        return x\n\n\ndef train_step(optimizer, batch, apply_fn):\n\n    def loss_func(params, x, y, attention_mask):\n        out = apply_fn(params, x, attention_mask)\n        loss = jnp.mean((out - y)**2)\n        mark_pipeline(name=str(N - 1), mark_type='end')\n        return loss\n\n    loss_func = manual_layer_slicing(loss_func)\n    grad_param = jax.grad(loss_func)(optimizer.target, 
batch['x'], batch['y'],\n                                     batch['attention_mask'])\n\n    # new_optimizer = optimizer.apply_gradient(grad_param)\n    return grad_param\n\n\nbatch_size = 4\nseq_len = 64\nhidden_size = 256\nnum_heads = 1\nx = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32)\ny = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32) * 23\nattention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.float32)\n\nmodel = BertLayer_Model(config=BertConfig(hidden_size=hidden_size,\n                                          intermediate_size=hidden_size * 4,\n                                          num_attention_heads=num_heads))\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x, attention_mask)\noptimizer = optim.GradientDescent(1e-2).create(params)\nbatch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n\norigin_jaxpr = make_jaxpr(train_step, static_argnums=(2,))(optimizer, batch,\n                                                           model.apply)\ncompute_jaxpr, _, _ = split_compute_grad_and_apply_grad(origin_jaxpr)\ngensym_fn = gensym([compute_jaxpr.jaxpr])\nreduction_vector = [True] * len(compute_jaxpr.jaxpr.outvars)\nacc_grad_jaxpr, acc_grad_dict, grad_in_to_out = compute_grad_to_accumulate_grad(\n    compute_jaxpr, reduction_vector, gensym_fn)\n\nstages = slice_closed_jaxpr_by_full_pipeline_marks(acc_grad_jaxpr)\nstages = mark_missing_vars_in_backward_computation_pipeline_marks(stages,\n                                                                  acc_grad_jaxpr.jaxpr.invars,\n                                                                  acc_grad_jaxpr.jaxpr.outvars)\n\ndonated_global_invars = compute_jaxpr.jaxpr.invars[:-2]\nglobal_invars = acc_grad_jaxpr.jaxpr.invars\nglobal_outvars = acc_grad_jaxpr.jaxpr.outvars\nglobal_donation_mapping = dict()\n\nnum_layer_per_stage = 2\nstage_infos = []\nfor start in range(0, N, int(2 * N / num_layer_per_stage)):\n    stop = start + 
num_layer_per_stage\n    indices = list(range(start, stop))\n    donation_mapping, global_used, new_layers = split_global_use_and_donate(\n        stages, indices, global_donation_mapping, global_outvars)\n    stage_info = generate_stage_info(stages, indices, donation_mapping,\n                                   global_used, str(start))\n    stage_infos.append(stage_info)\n\ncompiled_outputs = compile_all(stage_infos,\n                               virtual_mesh.get_default_logical_mesh(), 16, 4)\nphysical_mesh = virtual_mesh.get_physical_mesh()\nfor compiled_output, stage_info in zip(compiled_outputs, stage_infos):\n    _, avals, out_avals, tot_donation = stage_info\n    proto, config, in_shardings, out_shardings = compiled_output\n    compiled = ProtoAndSharding(proto=proto,\n                                input_shardings=in_shardings,\n                                output_shardings=out_shardings)\n    donated_invars = (True,) * len(tot_donation) + (False,) * (\n        len(avals) - len(tot_donation))\n    executable = NormalMeshDriverExecutable(physical_mesh, compiled, config,\n                                            avals, out_avals, donated_invars)\n    executable.profile_with_dummy_inputs()\n"
  },
  {
    "path": "playground/pipeline/test_generate_schedule.py",
    "content": "\"\"\"Experimental code to generate a Gpipe clock-cycle schedule.\"\"\"\nimport numpy as np\n\n\ndef generate_gpipe_schedule(m, n):\n    num_clock = m + n - 1\n    schedules = []\n    for k in range(num_clock):\n        scheds = [None] * n\n        for d in range(max(1 + k - m, 0), min(1 + k, n)):\n            scheds[d] = (k - d, d)\n        schedules.append(scheds)\n\n    def reverse(scheds):\n        reversed = []\n        for task in scheds:\n            if not task:\n                reversed.append(None)\n            else:\n                reversed.append((m - 1 - task[0], 2 * n - 1 - task[1]))\n        return reversed\n\n    # backward schedules\n    for k in range(num_clock):\n        mapped_scheds = schedules[num_clock - k - 1]\n        schedules.append(reverse(mapped_scheds))\n    return schedules\n\n\ndef generate_1f1b_schedule(m, n):\n    # equal to gpipe\n    num_clock = (m + n - 1) * 2\n    schedules = [[None] * n for k in range(num_clock)]\n\n    num_warmup_microbatches = [min(n - i - 1, m) for i in range(n)]\n    num_microbatches_remaining = [m - i for i in num_warmup_microbatches]\n\n    next_fwd_mb_idx = [0 for _ in range(n)]\n    next_bwd_mb_idx = [0 for _ in range(n)]\n    next_available_clock = [i for i in range(n)]\n    finished_bwd_batch_indices = np.zeros(shape=[num_clock, n], dtype=np.int32)\n\n    # warm-up clocks\n    for i in range(n):\n        for j in range(num_warmup_microbatches[i]):\n            schedules[next_available_clock[i]][i] = (next_fwd_mb_idx[i], i)\n            next_available_clock[i] = next_available_clock[i] + 1\n            next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1\n\n    # run 1F1B\n    for i in reversed(range(n)):\n        # from the last device to the first\n        for j in range(num_microbatches_remaining[i]):\n            # running through all the remaining microbatches\n            # forward\n            next_clock = next_available_clock[i]\n            schedules[next_clock][i] = 
(next_fwd_mb_idx[i], i)\n            next_fwd_mb_idx[i] = next_fwd_mb_idx[i] + 1\n            finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n            next_clock = next_clock + 1\n\n            # backward\n            # first, offset the next available clock to the clock\n            # when the previous stage has just finished backward of the target mb.\n            if i + 1 < n:  # not the last device\n                # find the next possible backward clock\n                while finished_bwd_batch_indices[next_clock][i + 1] <= next_bwd_mb_idx[i]:\n                    assert finished_bwd_batch_indices[next_clock - 1][i] == next_bwd_mb_idx[i]\n                    finished_bwd_batch_indices[next_clock][i] = finished_bwd_batch_indices[next_clock - 1][i]\n                    next_clock = next_clock + 1\n\n            schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n - 1 - i)\n            finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n            next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1\n            next_available_clock[i] = next_clock + 1\n\n    # run cooldown passes\n    for i in reversed(range(n)):\n        for j in range(num_warmup_microbatches[i]):\n            assert i + 1 < n\n            next_clock = next_available_clock[i]\n            while finished_bwd_batch_indices[next_clock][i + 1] <= next_bwd_mb_idx[i]:\n                finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n                next_clock = next_clock + 1\n            schedules[next_clock][i] = (next_bwd_mb_idx[i], 2 * n- 1 - i)\n            finished_bwd_batch_indices[next_clock][i] = next_bwd_mb_idx[i]\n            next_bwd_mb_idx[i] = next_bwd_mb_idx[i] + 1\n            next_available_clock[i] = next_clock + 1\n        # update status matrix for the last worker\n        if i > 0:\n            finished_bwd_batch_indices[next_available_clock[i]:num_clock, i] = m\n\n    return schedules\n\n\n\ndef pprint_schedule(schedules):\n    num_device = 
len(schedules[0])\n    device_str = \" \".join([\"{:<8}\".format(\"d\" + str(d)) for d in range(num_device)])\n    print(\"Clock {:<2}: {}\".format(\"id\", device_str))\n    for clock, scheds in enumerate(schedules):\n        sched_str = \" \".join([\"{:<8}\".format(str(sched)) for sched in scheds])\n        print(\"Clock {:<2}: {}\".format(clock, sched_str))\n\n\nif __name__ == \"__main__\":\n    m = 4\n    n = 3\n    schedules = generate_gpipe_schedule(m, n)\n    pprint_schedule(schedules)\n    print(\"\\n\")\n    schedules = generate_1f1b_schedule(m, n)\n    pprint_schedule(schedules)\n"
  },
  {
    "path": "playground/pipeline/test_pipeline_mlp_distributed.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport os\nimport ray\nfrom flax import linen as nn\nfrom flax import optim\nfrom flax.core.frozen_dict import FrozenDict as FrozenDictFlax\nfrom jax.experimental.maps import FrozenDict as FrozenDictJax\n\nfrom alpa import parallelize, mark_pipeline\n\nMB = 1024 ** 2\nnum_gpus = 2\nassert len(jax.local_devices()) >= num_gpus\ndevices = tuple(jax.local_devices()[:num_gpus])\n\n\n# in order for ray to work we have to set this\n# so the driver program and actor program can share GPUs...\nos.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"False\"\n\n\ndef is_sequence(x):\n    try:\n        iter(x)\n    except TypeError:\n        return False\n    else:\n        return True\n\ndef assert_allclose(x, y):\n    if isinstance(x, dict) or isinstance(x, FrozenDictJax) or isinstance(x, FrozenDictFlax):\n        assert isinstance(y, dict) or isinstance(y, FrozenDictJax) or isinstance(x, FrozenDictFlax)\n        assert set(x.keys()) == set(y.keys())\n        for k in x.keys():\n            assert_allclose(x[k], y[k])\n    elif is_sequence(x) and not hasattr(x, '__array__'):\n        assert is_sequence(y) and not hasattr(y, '__array__')\n        assert len(x) == len(y)\n        for x_elt, y_elt in zip(x, y):\n            assert_allclose(x_elt, y_elt)\n    elif hasattr(x, '__array__') or np.isscalar(x):\n        assert hasattr(y, '__array__') or np.isscalar(y)\n        x = np.asarray(x)\n        y = np.asarray(y)\n        assert np.allclose(x, y)\n    elif x == y:\n        return\n    else:\n        raise TypeError((type(x), type(y)))\n\n\nclass Model(nn.Module):\n    hidden_dim: int\n    output_dim: int\n\n    @nn.compact\n    def __call__(self, x):\n        # FIXME (zhuohan): if don't require the gradient of x here, the\n        #                  backward pass of the pipeline start will not\n        #                  be generated.\n        x, = mark_pipeline(x, name='1', mark_type='start')\n        x = 
nn.Dense(features=self.hidden_dim, use_bias=False)(x)\n        x = nn.relu(x)\n        x, = mark_pipeline(x, name='1', mark_type='end')\n        x, = mark_pipeline(x, name='2', mark_type='start')\n        x = nn.Dense(features=self.output_dim, use_bias=False)(x)\n        return x\n\ndef train_step(optimizer, batch, apply_fn):\n    def loss_func(params, x, y):\n        out = apply_fn(params, x)\n        loss = jnp.mean((out - y) ** 2)\n        loss, = mark_pipeline(loss, name='2', mark_type='end')\n        return loss\n\n    grad_param, grad_x = jax.grad(loss_func, argnums = (0, 1))(optimizer.target, batch['x'], batch['y'])\n    # new_optimizer = optimizer.apply_gradient(grad_param)\n    return grad_param\n\n\nray.init(num_cpus=8, num_gpus=2)\nbatch_size = 128\nhidden_dim = 2048\ninput_dim = output_dim = hidden_dim\n\nx = jnp.ones((batch_size, input_dim))\ny = jnp.ones((batch_size, output_dim))\n\n# Init model and optimizer\nmodel = Model(hidden_dim=hidden_dim, output_dim=output_dim)\nrngkey = jax.random.PRNGKey(0)\nparams = model.init(rngkey, x)\noptimizer = optim.GradientDescent(1e-2).create(params)\n\ngradients = train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\n# strategy = \"distributed_pipeline_parallel\"\n# strategy = \"pipeline_parallel\"\nstrategy = \"3d_parallel\"\n# import cloudpickle as pickle\n# m = pickle.dumps(train_step)\n# new_train_step = pickle.loads(m)\n# print(\"OK\")\n# new_gradients = new_train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\nassert_allclose(x, y)\npipelined_train_step = parallelize(donate_argnums=(), devices=devices, strategy=strategy)(train_step)\ngradients_with_pipeline = pipelined_train_step(optimizer, {\"x\": x, \"y\": y}, model.apply)\nassert_allclose(gradients, gradients_with_pipeline)\n"
  },
  {
    "path": "playground/pipeline/test_ray_jax_array.py",
    "content": "# check gpu devices\nimport os\n\nimport jax.numpy as jnp\nimport ray\n\n\nos.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"False\"\nray.init(num_gpus=2, num_cpus=4)\n\n\n@ray.remote(num_gpus=1, num_cpus=2)\nclass Runner:\n    def __init__(self, name):\n        print(\"ray.get_gpu_ids(): {}\".format(ray.get_gpu_ids()))\n        print(\"CUDA_VISIBLE_DEVICES: {}\".format(os.environ[\"CUDA_VISIBLE_DEVICES\"]))\n        self.name = name\n        self.a = None\n        self.b = None\n\n    def compute(self):\n        print(type(self.a))\n        print(type(self.b))\n        c = jnp.matmul(self.a, self.b)\n        print(type(c))\n        return c\n\n    def set(self, refs):\n        arrays = ray.get(refs)\n        print(arrays)\n        # a = ray.get(a_ref)\n        # print(a)\n        # print(type(a))\n        self.a = jnp.asarray(arrays[0])\n        # b = ray.get(b_ref)\n        # print(b)\n        # print(type(b))\n        self.b = jnp.asarray(arrays[1])\n\n\nworkers = []\nworkers.append(Runner.remote(name=\"0\"))\nworkers.append(Runner.remote(name=\"1\"))\n\na = jnp.ones([3, 4])\nb = jnp.ones([4, 5])\na_ref = ray.put(a)\nb_ref = ray.put(b)\nworker = workers[0]\nworker.set.remote([a_ref, b_ref])\nc_ref = worker.compute.remote()\nc_result = ray.get(c_ref)\n\nworker = workers[1]\nworker.set.remote([a_ref, b_ref])\nc_ref = worker.compute.remote()\nc_result = ray.get(c_ref)\nprint(c_result)\n"
  },
  {
    "path": "playground/xla_builder/test_multi_host.py",
    "content": "import numpy as np\nimport ray\nfrom jax.lib import xla_client\n\nfrom alpa import DeviceCluster, XlaPassContext, parallelize, global_config\n\nops = xla_client.ops\n\n\ndef parameter(builder, num, shape, dtype):\n    shape = xla_client.Shape.array_shape(np.dtype(dtype), shape)\n    name = \"\"\n    replicated = []\n    return ops.Parameter(builder, num,\n                         shape.with_major_to_minor_layout_if_absent(), name,\n                         replicated)\n\n\ndef all_reduce(builder, operand, reduce_op, replica_groups):\n    replica_groups_protos = xla_client.make_replica_groups(replica_groups)\n    if reduce_op == 'add':\n        rc = xla_client.XlaBuilder(\"reduce_\" + reduce_op)\n        x = parameter(rc, 0, (), np.float32)\n        y = parameter(rc, 1, (), np.float32)\n        z = ops.Add(x, y)\n        rc = rc.build(z)\n    else:\n        raise NotImplementedError\n\n    return ops.AllReduce(operand, rc, replica_groups_protos,\n            None, None)\n\n\ndef test_multi_host_all_reduce():\n    device_cluster = DeviceCluster()\n\n    print(\"Device mesh\")\n    device_mesh = device_cluster.get_physical_mesh()\n\n    def get_hlo_module_proto():\n        backend = xla_client._gpu_backend_factory()\n        c = xla_client.XlaBuilder(\"shard\")\n        x = parameter(c, 0, (5,), np.float32)\n        z = all_reduce(c, x, 'add', (tuple(range(device_mesh.num_devices)),))\n        c = c.build(ops.Tuple(c, [z]))\n\n        global_device_ids = np.arange(device_mesh.num_devices)\n\n        num_replicas = len(global_device_ids)\n        num_partitions = 1\n        device_assignment = global_device_ids.reshape((num_replicas, num_partitions))\n        device_assignment = xla_client.DeviceAssignment.create(device_assignment)\n        use_spmd_partitioning = False\n\n        compile_options = xla_client.CompileOptions()\n        build_options = compile_options.executable_build_options\n        build_options.num_replicas = num_replicas\n        
build_options.num_partitions = num_partitions\n        build_options.use_spmd_partitioning = use_spmd_partitioning\n        build_options.device_assignment = device_assignment\n\n        with XlaPassContext({\n            \"build_option::pass_through_device_assignment\": True\n        }):\n            compiled_computation = backend.compile(c, compile_options)\n        hlo_module = compiled_computation.hlo_modules()[0]\n        return hlo_module\n\n    # Prepare inputs. shape: (num_hosts, num_args, num_devices)\n    dtype = np.float32\n    host_inputs = [   \n        [[np.ones(5, dtype=dtype), np.ones(5, dtype=dtype)]],\n        [[np.ones(5, dtype=dtype), np.ones(5, dtype=dtype)]],\n    ]\n\n    # Compile and run\n    hlo_module = get_hlo_module_proto()\n    device_mesh.launch_distributed_xla_service()\n    device_mesh.compile_hlo_module(hlo_module, None, None)\n    device_mesh.execute(host_inputs)\n    device_mesh.sync_workers()\n\n\ndef test_multi_host_auto_sharding():\n    global_config.shard_parallel_strategy = \"auto_sharding\"\n\n    device_cluster = DeviceCluster()\n    physical_mesh = device_cluster.get_physical_mesh()\n    num_devices = len(physical_mesh.host_ids) * physical_mesh.num_devices_per_host\n    logical_mesh = physical_mesh.get_logical_mesh([1, num_devices], [1, 1], [1, 1])\n\n    @parallelize(devices=logical_mesh)\n    def add_one(x):\n        x = x + 1\n        return x\n\n    a = np.ones((1000, 1000))\n    out = add_one(a)\n\n    print(\"Output\", out)\n\n\nif __name__ == \"__main__\":\n    ray.init(address=\"auto\")\n    test_multi_host_auto_sharding()\n"
  },
  {
    "path": "playground/xla_builder/test_xla_builder.py",
    "content": "from functools import partial\n\nimport numpy as np\nimport jax\nimport jax.numpy as jnp\nfrom jax.lib import xla_client, xla_bridge\n\nops = xla_client.ops\n\nMB = 1 << 20\n\ndef test_sin_cos():\n    def f(x):\n        return jax.numpy.sin(jax.numpy.cos(x.T))\n\n    c = jax.xla_computation(f)(np.ones((10,8)))\n\n    gpu_backend = xla_bridge.get_backend(\"gpu\")\n    compiled_computation = gpu_backend.compile(c)\n\n    print(c.as_hlo_text())\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n    host_input = np.ones((10,8), dtype=np.float32)\n    device_input = gpu_backend.buffer_from_pyval(host_input)\n    device_out = compiled_computation.execute([device_input,])\n\n\ndef parameter(builder, num, shape, dtype):\n    shape = xla_client.Shape.array_shape(np.dtype(dtype), shape)\n    name = \"\"\n    replicated = []\n    return ops.Parameter(builder, num,\n                         shape.with_major_to_minor_layout_if_absent(), name,\n                         replicated)\n\ndef test_alias():\n    c = xla_client.XlaBuilder(\"test\")\n    a = parameter(c, 0, (8 * MB//4,), np.float32)\n    b = parameter(c, 1, (8 * MB//4,), np.float32)\n    d = parameter(c, 2, (8 * MB//4,), np.float32)\n    e = parameter(c, 3, (8 * MB//4,), np.float32)\n\n    backend = xla_bridge.get_backend(\"gpu\")\n\n    #z = ops.Add(a, b)\n    z = ops.Constant(c, 0.1)\n\n    #c.setup_alias((0,), 0, ())\n\n    c = c.build(ops.Tuple(c, [z]))\n    compiled_c = backend.compile(c)\n    real_mem = compiled_c.total_allocation_size()\n\n    print(compiled_c.hlo_modules()[0].to_string())\n    print(f\"{real_mem / MB:.2f} MB\")\n\n    #a = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32))\n    #b = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32))\n    #d = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32))\n    #e = backend.buffer_from_pyval(np.ones((8 * MB // 4), dtype=np.float32))\n\n    #for i in range(10):\n    #    ans, = 
compiled_c.execute([a, b, d, e])\n\n\ndef test_shard():\n    c = xla_client.XlaBuilder(\"shard\")\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    sharding.tile_assignment_dimensions = [1]\n    sharding.tile_assignment_devices = [0]\n    c.set_sharding(sharding)\n    x = ops.Parameter(c, 0, xla_client.shape_from_pyval(np.ones((10, 8), dtype=np.float32)))\n    c.clear_sharding()\n    y = ops.Parameter(c, 1, xla_client.shape_from_pyval(np.ones((10, 8), dtype=np.float32)))\n\n    backend = xla_bridge.get_backend(\"gpu\")\n\n    z = ops.Add(x, y)\n    z = ops.Add(z, y)\n\n    c = c.build(z)\n    #print(c.as_hlo_text())\n\n    compiled_c = backend.compile(c)\n\n    print(compiled_c.hlo_modules()[0].to_string())\n\n    x = backend.buffer_from_pyval(np.ones((10, 8), dtype=np.float32))\n    y = backend.buffer_from_pyval(np.ones((10, 8), dtype=np.float32))\n    ans, = compiled_c.execute([x, y])\n\n\ndef parameter(builder, num, shape, dtype):\n    shape = xla_client.Shape.array_shape(np.dtype(dtype), shape)\n    name = \"\"\n    replicated = []\n    return ops.Parameter(builder, num,\n                         shape.with_major_to_minor_layout_if_absent(), name,\n                         replicated)\n\n\ndef all_reduce(builder, operand, reduce_op, replica_groups):\n    replica_groups_protos = xla_client.make_replica_groups(replica_groups)\n    if reduce_op == 'add':\n        rc = xla_client.XlaBuilder(\"reduce_\" + reduce_op)\n        x = parameter(rc, 0, (), np.float32)\n        y = parameter(rc, 1, (), np.float32)\n        z = ops.Add(x, y)\n        rc = rc.build(z)\n    else:\n        raise NotImplementedError\n\n    return ops.AllReduce(operand, rc, replica_groups_protos,\n            None, None)\n\n\ndef test_manual_construct_replica():\n    c = xla_client.XlaBuilder(\"shard\")\n    x = parameter(c, 0, (2, 2), np.float32)\n    y = ops.Constant(c, np.float32(1))\n    z = ops.Broadcast(y, (2, 2))\n    z = ops.Add(x, z)\n    z = 
all_reduce(c, z, 'add', ((0, 1, 2, 3,),))\n\n    c = c.build(ops.Tuple(c, [z]))\n    print(c.as_hlo_text())\n\n    num_replicas = 4\n    num_partitions = 1\n    device_assignment = xla_client.DeviceAssignment.create([[0], [1], [2], [3]])\n    use_spmd_partitioning = False\n\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = use_spmd_partitioning\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    compiled_computation = backend.compile(c, compile_options)\n\n    host_input = np.ones((2,2), dtype=np.float32)\n    device_inputs = [[\n        backend.buffer_from_pyval(host_input, backend.devices()[i])\n        for i in range(4)\n    ]]\n\n    device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs)\n    print(device_outs)\n\n\ndef test_manual_construct_spmd_shard():\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Set input sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [2, 1]\n    sharding.tile_assignment_devices = [0, 1]\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (2, 2), np.float32)\n    c.clear_sharding()\n\n    # Build computational graph\n    y = ops.Constant(c, np.float32(1))\n    z = ops.Broadcast(y, (2, 2))\n    z = ops.Add(x, z)\n\n    # Set output sharding\n    sharding2 = xla_client.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = [sharding]\n    c.set_sharding(sharding2)\n    out = ops.Tuple(c, [z])\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 1\n    num_partitions = 2\n    use_spmd_partitioning = False\n    device_assignment = 
xla_client.DeviceAssignment.create([[0, 1]])\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n    # Run\n    host_input = np.ones((2, 2), dtype=np.float32)\n    device_inputs = [[\n        backend.buffer_from_pyval(host_input[[i],:], backend.devices()[i])\n        for i in range(2)\n    ]]\n    device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs)\n    print(device_outs)\n\n\ndef test_manual_construct_spmd_one_device():\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Build a computational graph on device 0\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [1, 1]\n    sharding.tile_assignment_devices = [0,]\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (2, 2), np.float32)\n\n    z = ops.Add(x, x)\n    z = ops.Add(z, z)\n    z = ops.Add(z, z)\n    c.clear_sharding()\n\n    # Build a computational graph on device 1\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [1, 1]\n    sharding.tile_assignment_devices = [1,]\n    c.set_sharding(sharding)\n    z = ops.Add(z, z)\n    z = ops.Add(z, z)\n    out = z\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 1\n    num_partitions = 2\n    use_spmd_partitioning = False\n    device_assignment = xla_client.DeviceAssignment.create([[0, 1]])\n    compile_options = 
xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n    # Run\n    host_input = np.ones((2, 2), dtype=np.float32)\n    device_inputs = [[\n        backend.buffer_from_pyval(host_input, backend.devices()[0]),\n        backend.buffer_from_pyval(host_input, backend.devices()[1]),\n    ]]\n    device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs)\n    print(device_outs)\n\n\ndef test_reshard_multi_allgather():\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Set input sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [8, 2]\n    sharding.tile_assignment_devices = list(range(16))\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (32, 32), np.float32)\n    c.clear_sharding()\n\n    # Build computational graph\n    y = ops.Constant(c, np.float32(1))\n    z = ops.Broadcast(y, (32, 32))\n    z = ops.Add(x, z)\n\n    # Set output sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.REPLICATED\n    #sharding.tile_assignment_dimensions = [2, 2]\n    ##sharding.replicate_on_last_tile_dim = True\n    #sharding.tile_assignment_devices = [0, 1, 2, 3]\n\n    sharding2 = xla_client.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = [sharding]\n    c.set_sharding(sharding2)\n    out = ops.Tuple(c, [z])\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 
1\n    num_partitions = 16\n    use_spmd_partitioning = False\n    device_assignment = xla_client.DeviceAssignment.create([list(range(num_partitions))])\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    import alpa\n    with alpa.XlaPassContext({\n        \"build_option::bypass_device_assignment_check\": True,\n    }):\n        compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n\ndef test_reshard_all_to_all():\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Set input sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [4, 1]\n    sharding.tile_assignment_devices = list(range(4))\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (32, 32), np.float32)\n    c.clear_sharding()\n\n    # Build computational graph\n    if False:\n        z = ops.Reshape(x, (2, 16, 32))\n        sharding = xla_client.OpSharding()\n        sharding.type = sharding.type.OTHER\n        sharding.tile_assignment_dimensions = [2, 1, 2]\n        sharding.tile_assignment_devices = list(range(4))\n    else:\n        z = x\n        sharding = xla_client.OpSharding()\n        sharding.type = sharding.type.OTHER\n        sharding.tile_assignment_dimensions = [2, 2]\n        sharding.tile_assignment_devicesi = list(range(4))\n\n    sharding2 = xla_client.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = [sharding]\n    c.set_sharding(sharding2)\n    out = ops.Tuple(c, [z])\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    
print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 1\n    num_partitions = 4\n    use_spmd_partitioning = False\n    device_assignment = xla_client.DeviceAssignment.create([list(range(num_partitions))])\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    import alpa\n    with alpa.XlaPassContext({\n        \"build_option::bypass_device_assignment_check\": True,\n    }):\n        compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n\ndef test_reshard_change_mesh_shape():\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Set input sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [1, 2, 2]\n    sharding.tile_assignment_devices = [0, 1, 2, 3]\n    sharding.replicate_on_last_tile_dim = True\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (32, 32), np.float32)\n    c.clear_sharding()\n\n    # Build computational graph\n    z = x\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [4, 1]\n    sharding.tile_assignment_devices = [0, 1, 2, 3]\n\n    sharding2 = xla_client.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = [sharding]\n    c.set_sharding(sharding2)\n    out = ops.Tuple(c, [z])\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 1\n    num_partitions = 4\n    use_spmd_partitioning = False\n    device_assignment = 
xla_client.DeviceAssignment.create([list(range(num_partitions))])\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    import alpa\n    with alpa.XlaPassContext({\n        \"build_option::bypass_device_assignment_check\": True,\n    }):\n        compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    print(compiled_computation.hlo_modules()[0].to_string())\n\n\ndef test_skip_hlo_passes():\n    from alpa import XlaPassContext\n\n    c = xla_client.XlaBuilder(\"shard\")\n\n    # Set input sharding\n    sharding = xla_client.OpSharding()\n    sharding.type = sharding.type.OTHER\n    sharding.tile_assignment_dimensions = [2, 1]\n    sharding.tile_assignment_devices = [0, 1]\n    c.set_sharding(sharding)\n    x = parameter(c, 0, (2, 2), np.float32)\n    c.clear_sharding()\n\n    # Build computational graph\n    y = ops.Constant(c, np.float32(1))\n    z = ops.Broadcast(y, (2, 2))\n    z = ops.Add(x, z)\n\n    # Set output sharding\n    sharding2 = xla_client.OpSharding()\n    sharding2.type = sharding.type.TUPLE\n    sharding2.tuple_shardings = [sharding]\n    c.set_sharding(sharding2)\n    out = ops.Tuple(c, [z])\n    c.clear_sharding()\n\n    # Build HLO\n    c = c.build(out)\n    print(c.as_hlo_text())\n    print(\"=\" * 20)\n\n    # Compile\n    num_replicas = 1\n    num_partitions = 2\n    use_spmd_partitioning = False\n    device_assignment = xla_client.DeviceAssignment.create([[0, 1]])\n    compile_options = xla_client.CompileOptions()\n    build_options = compile_options.executable_build_options\n    build_options.num_replicas = num_replicas\n    build_options.num_partitions = num_partitions\n    
build_options.use_spmd_partitioning = True\n    build_options.device_assignment = device_assignment\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    with XlaPassContext({\"build_option::skip_backend_codegen\": True}):\n        compiled_computation = backend.compile(c, compile_options)\n\n    # Print spmd partitioned HLO\n    hlo_module = compiled_computation.hlo_modules()[0]\n    c = xla_client.XlaComputation(hlo_module.as_serialized_hlo_module_proto())\n\n    with XlaPassContext({\"build_option::skip_hlo_passes\": True}):\n        compiled_computation = backend.compile(c, compile_options)\n\n    # Run\n    host_input = np.ones((2, 2), dtype=np.float32)\n    device_inputs = [[\n        backend.buffer_from_pyval(host_input[[i],:], backend.devices()[i])\n        for i in range(2)\n    ]]\n    device_outs = compiled_computation.execute_sharded_on_local_devices(device_inputs)\n    print(device_outs)\n\n\ndef test_create_zero_buffers():\n    shapes = ((2, 2), (3, 3))\n    dtypes = (jnp.float32, jnp.float32)\n\n    def compile_get_zero_buffers(backend, num_devices):\n        c = xla_client.XlaBuilder(\"get_zero_buffers\")\n        sharding = xla_client.OpSharding()\n        sharding.type = sharding.type.REPLICATED\n        c.set_sharding(sharding)\n        ret = []\n        for shape, dtype in zip(shapes, dtypes):\n            zero = ops.Constant(c, dtype(0))\n            zero = ops.Broadcast(zero, shape)\n            ret.append(zero)\n        c.clear_sharding()\n        c = c.build(ops.Tuple(c, ret))\n\n        compile_options = xla_bridge.get_compile_options(\n            num_replicas=1,\n            num_partitions=num_devices,\n            device_assignment=np.arange(num_devices).reshape((1, -1)),\n            use_spmd_partitioning=True,\n        )\n        compiled_computation = backend.compile(c, compile_options)\n        return compiled_computation\n\n    backend = xla_bridge.get_backend(\"gpu\")\n    num_devices = 8\n    get_zero_buffers = 
compile_get_zero_buffers(backend, num_devices)\n\n    device_outs = get_zero_buffers.execute_sharded_on_local_devices([])\n\n    print(device_outs)\n\n\nif __name__ == \"__main__\":\n    #test_sin_cos()\n    #test_alias()\n    #test_shard()\n\n    #test_manual_construct_replica()\n    #test_manual_construct_spmd_shard()\n    #test_manual_construct_spmd_one_device()\n\n    #test_reshard_multi_allgather()\n    #test_reshard_all_to_all()\n    test_reshard_change_mesh_shape()\n\n    #test_skip_hlo_passes()\n\n    #test_create_zero_buffers()\n\n"
  },
  {
    "path": "setup.py",
    "content": "import glob\nimport os\nimport re\nimport shutil\nimport subprocess\nimport sys\n\nfrom setuptools import setup, find_packages\n\nIS_WINDOWS = sys.platform == \"win32\"\nROOT_DIR = os.path.dirname(__file__)\nHAS_CUDA = os.system(\"nvidia-smi > /dev/null 2>&1\") == 0\n\n\ndef get_cuda_version(cuda_home):\n    \"\"\"Locate the CUDA version.\"\"\"\n    version_file = os.path.join(cuda_home, \"version.txt\")\n    try:\n        if os.path.isfile(version_file):\n            with open(version_file, \"r\") as f_version:\n                version_str = f_version.readline().replace(\"\\n\", \"\").replace(\n                    \"\\r\", \"\")\n                return version_str.split(\" \")[2][:4]\n        else:\n            version_str = subprocess.check_output(\n                [os.path.join(cuda_home, \"bin\", \"nvcc\"), \"--version\"])\n            version_str = str(version_str).replace(\"\\n\", \"\").replace(\"\\r\", \"\")\n            idx = version_str.find(\"release\")\n            return version_str[idx + len(\"release \"):idx + len(\"release \") + 4]\n    except RuntimeError:\n        raise RuntimeError(\"Cannot read cuda version file\")\n\n\ndef locate_cuda():\n    \"\"\"Locate the CUDA environment on the system.\"\"\"\n    # Guess #1\n    cuda_home = os.environ.get(\"CUDA_HOME\") or os.environ.get(\"CUDA_PATH\")\n    if cuda_home is None:\n        # Guess #2\n        try:\n            which = \"where\" if IS_WINDOWS else \"which\"\n            nvcc = subprocess.check_output([which,\n                                            \"nvcc\"]).decode().rstrip(\"\\r\\n\")\n            cuda_home = os.path.dirname(os.path.dirname(nvcc))\n        except subprocess.CalledProcessError:\n            # Guess #3\n            if IS_WINDOWS:\n                cuda_homes = glob.glob(\n                    \"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*\")\n                if len(cuda_homes) == 0:\n                    cuda_home = \"\"\n                else:\n     
               cuda_home = cuda_homes[0]\n            else:\n                cuda_home = \"/usr/local/cuda\"\n            if not os.path.exists(cuda_home):\n                cuda_home = None\n    version = get_cuda_version(cuda_home)\n    cudaconfig = {\n        \"home\":\n            cuda_home,\n        \"include\":\n            os.path.join(cuda_home, \"include\"),\n        \"lib64\":\n            os.path.join(cuda_home,\n                         os.path.join(\"lib\", \"x64\") if IS_WINDOWS else \"lib64\"),\n    }\n    if not all([os.path.exists(v) for v in cudaconfig.values()]):\n        raise EnvironmentError(\n            \"The CUDA  path could not be located in $PATH, $CUDA_HOME or $CUDA_PATH. \"\n            \"Either add it to your path, or set $CUDA_HOME or $CUDA_PATH.\")\n\n    return cudaconfig, version\n\n\ndef get_cuda_version_str(no_dot=False):\n    \"\"\"Return the cuda version in the format of [x.x].\"\"\"\n    ver = locate_cuda()[1]\n    if no_dot:\n        ver = ver.replace(\".\", \"\")\n    return ver\n\n\ninstall_require_list = [\n    \"tqdm\",\n    \"ray\",\n    \"jax==0.3.22\",\n    \"chex==0.1.5\",\n    \"flax==0.6.2\",\n    \"pulp>=2.6.0\",\n    \"numpy>=1.20\",\n    \"numba\",\n]\n\ndev_require_list = [\"yapf==0.32.0\", \"pylint==2.14.0\", \"cmake\", \"pybind11\"]\n\nif HAS_CUDA:\n    dev_require_list += [\n        f\"cupy-cuda{get_cuda_version_str(no_dot=True)}\",\n    ]\n\ndoc_require_list = [\n    \"sphinx\", \"sphinx-rtd-theme\", \"sphinx-gallery\", \"matplotlib\"\n]\n\n\ndef get_alpa_version():\n    with open(os.path.join(ROOT_DIR, \"alpa\", \"version.py\")) as fp:\n        version_match = re.search(r\"^__version__ = ['\\\"]([^'\\\"]*)['\\\"]\",\n                                  fp.read(), re.M)\n        if version_match:\n            return version_match.group(1)\n    raise RuntimeError(\"Unable to find version string.\")\n\n\nif __name__ == \"__main__\":\n    import setuptools\n    from setuptools.command.install import install\n\n    
class BinaryDistribution(setuptools.Distribution):\n\n        def has_ext_modules(self):\n            return False\n\n    class InstallPlatlib(install):\n\n        def finalize_options(self):\n            install.finalize_options(self)\n            if self.distribution.has_ext_modules():\n                self.install_lib = self.install_platlib\n\n    with open(\"README.md\", encoding=\"utf-8\") as f:\n        long_description = f.read()\n\n    setup(\n        name=\"alpa\",\n        version=get_alpa_version(),\n        author=\"Alpa Developers\",\n        author_email=\"\",\n        description=\n        \"Alpa automatically parallelizes large tensor computation graphs and \"\n        \"runs them on a distributed cluster.\",\n        long_description=long_description,\n        long_description_content_type=\"text/markdown\",\n        url=\"https://github.com/alpa-projects/alpa\",\n        classifiers=[\n            'Programming Language :: Python :: 3',\n            'Topic :: Scientific/Engineering :: Artificial Intelligence'\n        ],\n        keywords=(\"alpa distributed parallel machine-learning model-parallelism\"\n                  \"gpt-3 deep-learning language-model python\"),\n        packages=find_packages(\n            exclude=[\"benchmark\", \"examples\", \"playground\", \"tests\"]),\n        python_requires='>=3.7',\n        cmdclass={\"install\": InstallPlatlib},\n        install_requires=install_require_list,\n        extras_require={\n            'dev': dev_require_list,\n            'doc': doc_require_list + dev_require_list,\n        },\n    )\n"
  },
  {
    "path": "tests/README.md",
    "content": "# Unit test\n\n## Requirement\nA machine with at least 4 GPUs.\n\n## Run all test cases\n\n1. Start a Ray cluster\n```\nray start --head\n```\n\n2. Run all tests\n```\npython3 run_all.py\n```\n\n## Run specific files\n\n- For debugging:\n```\npython3 shard_parallel/test_basic.py\n```\n\n- To run files the same way CI does:\n```\n# Run one file\npython3 run_all.py --run-pattern shard_parallel/test_basic.py\n\n# Run a folder\npython3 run_all.py --run-pattern shard_parallel\n```\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/killall_python.sh",
    "content": "kill -9 $(ps aux | grep 'python3' | grep -v 'grep' | awk '{print $2}')\n"
  },
  {
    "path": "tests/pipeline_parallel/test_bert.py",
    "content": "import unittest\nimport os\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport ray\n\nfrom alpa import init, parallelize, PipeshardParallel\nfrom alpa.model.model_util import TrainState\nfrom alpa.model.bert_model import BertConfig\nfrom alpa.parallel_method import LocalPipelineParallel\nfrom alpa.pipeline_parallel.layer_construction import manual_layer_construction\nfrom alpa.testing import BertLayerModel, assert_allclose\n\n\nclass PipelineBERTTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n    def train_2_layer_bert(self, method):\n\n        def train_step(state, batch):\n\n            def loss_func(params, x, y, attention_mask):\n                out = state.apply_fn(params, x, attention_mask)\n                loss = jnp.mean((out - y)**2)\n                return loss\n\n            loss_func = manual_layer_construction(loss_func)\n            grads = jax.grad(loss_func)(state.params, batch[\"x\"], batch[\"y\"],\n                                        batch[\"attention_mask\"])\n            return grads\n\n        batch_size = 16\n        seq_len = 8\n        hidden_size = 128\n        num_heads = 8\n        dtype = jnp.float32\n\n        rngkey = jax.random.PRNGKey(0)\n        x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size),\n                              dtype=dtype)\n        y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size),\n                              dtype=dtype)\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=dtype)\n\n        # Init model and optimizer\n        model = BertLayerModel(config=BertConfig(hidden_size=hidden_size,\n                                                 intermediate_size=hidden_size *\n                                                 4,\n                                                 num_attention_heads=num_heads,\n                                                 
num_hidden_layers=2))\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x, attention_mask)\n        tx = optax.sgd(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params,\n                                  tx=tx,\n                                  dynamic_scale=None)\n\n        # Train step\n        batch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n        gradients = train_step(state, batch)\n        p_train_step = parallelize(train_step, donate_argnums=(), method=method)\n        gradients_with_pipeline = p_train_step(state, batch)\n\n        # Check results\n        assert_allclose(gradients, gradients_with_pipeline)\n\n    def test_2_layer_bert_local_pipeline_parallel(self):\n        self.train_2_layer_bert(LocalPipelineParallel())\n\n    def test_2_layer_bert_pipeshard_parallel(self):\n        init(cluster=\"ray\")\n        self.train_2_layer_bert(PipeshardParallel())\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineBERTTest(\"test_2_layer_bert_local_pipeline_parallel\"))\n    suite.addTest(PipelineBERTTest(\"test_2_layer_bert_pipeshard_parallel\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_cross_mesh_resharding.py",
    "content": "\"\"\"Test cross-mesh resharding.\"\"\"\nimport unittest\nfrom alpa.pipeline_parallel.runtime_emitter import PipelineInstEmitter\n\nimport jax\nfrom jax import xla\nfrom jax.core import Var\nfrom jax._src.abstract_arrays import ShapedArray\nfrom jax.interpreters.pxla import (Chunked, NoSharding, Replicated, ShardedAxis,\n                                   ShardingSpec, spec_to_indices)\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import init\nfrom alpa.device_mesh import (DistributedArray, create_remote_array_refs,\n                              get_global_virtual_physical_mesh)\nfrom alpa.mesh_executable import next_mesh_executable_uuid\nfrom alpa.global_env import global_config\nfrom alpa.pipeline_parallel.cross_mesh_resharding import (\n    CollectiveGroup, ReshardingTaskSpec, CrossMeshCommunicator,\n    SymbolicReshardingTask, SymbolicBroadcastReshardingTask)\nfrom alpa.pipeline_parallel.pipeshard_executable import (\n    AllocateZeroWorkerExecutableConfig, PipelineInstruction,\n    PipeshardMeshWorkerExecutable)\nfrom alpa.pipeline_parallel.resharding_tensor import VirtualDistributedArray\nfrom alpa.testing import assert_allclose\nfrom alpa.util import get_shard_shape\n\n\ndef test_resharding(var,\n                    src_mesh,\n                    src_sharding_spec,\n                    dst_mesh,\n                    dst_sharding_spec,\n                    use_local_allgather,\n                    resharding_mode,\n                    src_loads=None,\n                    dst_loads=None):\n    global_config.use_local_allgather = use_local_allgather\n    global_config.resharding_mode = resharding_mode\n\n    # Resharding task spec and send/recv strategy\n    src_loads = src_loads or {src: 0 for src in src_mesh.device_strs}\n    dst_loads = dst_loads or {dst: 0 for dst in dst_mesh.device_strs}\n    if resharding_mode == \"send_recv\":\n        rewrite_dst_sharding_spec = CrossMeshCommunicator._rewrite_allgather_spec(\n            
dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)\n    else:\n        rewrite_dst_sharding_spec = dst_sharding_spec\n    src_array = VirtualDistributedArray(device_mesh=src_mesh,\n                                        aval=var.aval,\n                                        sharding_spec=src_sharding_spec)\n    dst_array = VirtualDistributedArray(device_mesh=dst_mesh,\n                                        aval=var.aval,\n                                        sharding_spec=rewrite_dst_sharding_spec)\n    task_spec = ReshardingTaskSpec(src_array, dst_array, dst_sharding_spec)\n    if resharding_mode == \"send_recv\":\n        strategy = CrossMeshCommunicator._generate_send_recv_resharding_strategy_by_loads(\n            task_spec, src_loads, dst_loads)\n    else:\n        strategy = CrossMeshCommunicator._generate_broadcast_resharding_strategy_by_loads(\n            task_spec, src_loads, dst_loads)\n    task_spec.set_resharding_strategy(strategy)\n\n    # Resharding task. Compile send/recv from strategy and allgather.\n    collective_group = CollectiveGroup(task_spec.get_participant_device_strs(),\n                                       src_mesh, dst_mesh)\n    if global_config.eagerly_create_communicators:\n        collective_group.instantiate_now()\n    else:\n        collective_group.instantiate()\n    if resharding_mode == \"send_recv\":\n        task = SymbolicReshardingTask(task_spec, collective_group, src_mesh,\n                                      dst_mesh)\n    else:\n        task = SymbolicBroadcastReshardingTask(task_spec, collective_group,\n                                               src_mesh, dst_mesh)\n\n    # Compile pipeline instructions\n    instruction_lists = {worker: [] for worker in src_mesh.workers}\n    for worker in dst_mesh.workers:\n        instruction_lists[worker] = []\n    executable_config_lists = {worker: [] for worker in dst_mesh.workers}\n    src_uuid = 21474\n    dst_uuid = 21475\n    # allocate the buffer\n    
exec_uuid = next_mesh_executable_uuid()\n    config = AllocateZeroWorkerExecutableConfig(\n        exec_uuid, [get_shard_shape(var.aval, rewrite_dst_sharding_spec)],\n        [var.aval.dtype])\n    output_uuids = [dst_uuid]\n    for worker in dst_mesh.workers:\n        executable_config_lists[worker].append(config)\n        in_uuids = []\n        out_uuids = output_uuids\n        instruction_lists[worker].append(\n            PipelineInstruction.run(config.exec_uuid,\n                                    in_uuids,\n                                    out_uuids, {\n                                        \"sync_before\": False,\n                                        \"sync_after\": False\n                                    },\n                                    info=\"allocate zero for recv\"))\n    # Create resharding task\n    if resharding_mode == \"send_recv\":\n        PipelineInstEmitter._compile_resharding_task(src_uuid, task, dst_uuid,\n                                                     instruction_lists)\n    else:\n        PipelineInstEmitter._compile_broadcast_resharding_task(\n            src_mesh, src_uuid, task, dst_uuid, instruction_lists)\n\n    exec_uuids = {}\n\n    # Compile Pipeline Executable\n    for worker in src_mesh.workers:\n        exec_uuid = next_mesh_executable_uuid()\n        worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable,\n                                     instruction_lists[worker], [src_uuid], [],\n                                     [], [], [],\n                                     [False] * src_mesh.num_devices_per_host)\n        exec_uuids[worker] = exec_uuid\n    for worker in dst_mesh.workers:\n        exec_uuid = next_mesh_executable_uuid()\n        worker.put_executable.remote(exec_uuid, PipeshardMeshWorkerExecutable,\n                                     instruction_lists[worker], [], [dst_uuid],\n                                     executable_config_lists[worker], [], [],\n                    
                 [False] * dst_mesh.num_devices_per_host)\n        exec_uuids[worker] = exec_uuid\n\n    # Prepare array and shard args\n    test_array = np.arange(np.prod(var.aval.shape),\n                           dtype=var.aval.dtype).reshape(var.aval.shape)\n    indices = spec_to_indices(var.aval.shape, src_sharding_spec)\n    test_array = xla.canonicalize_dtype(test_array)\n    input_refs = src_mesh.shard_args_to_bufs([indices], (False,), (False,),\n                                             None, [test_array])\n    input_refs = np.array(input_refs)\n    input_uuids = [ref.uuid for ref in input_refs]\n    output_refs, output_uuids = create_remote_array_refs(dst_mesh)\n\n    # Run executables\n    # for _ in range(3):\n    # timers(\"overall_resharding_time\").start()\n    for worker in src_mesh.workers:\n        worker.run_executable.remote(exec_uuids[worker],\n                                     input_uuids, [],\n                                     sync_for_timer=True,\n                                     collect_trace=False)\n    for worker in dst_mesh.workers:\n        worker.run_executable.remote(exec_uuids[worker], [],\n                                     output_uuids,\n                                     sync_for_timer=True,\n                                     collect_trace=False)\n    output_array = DistributedArray(dst_mesh, var.aval, dst_sharding_spec,\n                                    output_refs[0])\n\n    # dst_mesh.sync_workers()\n    # timers(\"overall_resharding_time\").stop()\n    # timers(\"overall_resharding_time\").log()\n    # timers(\"overall_resharding_time\").reset()\n\n    # Check correctness\n    assert_allclose(test_array, output_array)\n\n    # Delete executables\n    for worker in src_mesh.workers:\n        worker.delete_executable.remote(exec_uuids[worker])\n    for worker in dst_mesh.workers:\n        worker.delete_executable.remote(exec_uuids[worker])\n\n\nclass ReshardingTest(unittest.TestCase):\n\n    def 
setUp(self):\n        init(cluster=\"ray\")\n\n    def run_resharding_task(self,\n                            src_mesh_shape,\n                            dst_mesh_shape,\n                            src_sharding_spec,\n                            dst_sharding_spec,\n                            tensor_shape,\n                            use_local_allgather=True,\n                            resharding_mode=\"send_recv\",\n                            tensor_dtype=None):\n        virtual_mesh = get_global_virtual_physical_mesh()\n        src_num_host = src_mesh_shape[0]\n        dst_num_host = dst_mesh_shape[0]\n        src_mesh = virtual_mesh.slice_2d(range(src_num_host),\n                                         [range(src_mesh_shape[1])] *\n                                         src_num_host).get_physical_mesh()\n        if (src_mesh_shape[1] + dst_mesh_shape[1] <=\n                virtual_mesh.num_devices_per_host):\n            dst_host_indices = range(dst_num_host)\n            dst_device_indices = [\n                range(src_mesh_shape[1], src_mesh_shape[1] + dst_mesh_shape[1])\n            ] * dst_num_host\n        else:\n            dst_host_indices = range(src_num_host, src_num_host + dst_num_host)\n            dst_device_indices = [range(dst_mesh_shape[1])] * dst_num_host\n        dst_mesh = virtual_mesh.slice_2d(\n            dst_host_indices, dst_device_indices).get_physical_mesh()\n\n        tensor_dtype = tensor_dtype or jnp.int32\n        var = Var(0, \"\", ShapedArray(tensor_shape, tensor_dtype))\n        test_resharding(var, src_mesh, src_sharding_spec, dst_mesh,\n                        dst_sharding_spec, use_local_allgather, resharding_mode)\n        src_mesh.shutdown()\n        dst_mesh.shutdown()\n\n    def _test_4gpu_send_recv(self, nccl_mode):\n        global_config.nccl_mode = nccl_mode\n        src_shape = (1, 2)\n        dst_shape = (1, 2)\n        tensor_shape = (4, 8, 16)\n        src_spec = ShardingSpec(\n            [NoSharding(), 
NoSharding(),\n             NoSharding()], [Replicated(2)])\n        dst_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape, False)\n        src_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape, False)\n        src_spec = ShardingSpec(\n            [NoSharding(), Chunked([2]),\n             NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape, False)\n\n    def _test_4gpu_allgather(self, nccl_mode):\n        global_config.nccl_mode = nccl_mode\n        src_shape = (1, 2)\n        dst_shape = (1, 2)\n        tensor_shape = (4, 8, 16)\n        src_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(2)])\n        dst_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(2)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        src_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        src_spec = ShardingSpec(\n            
[NoSharding(), Chunked([2]),\n             NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n        # test allgather at the second dim\n        tensor_shape = (3, 8, 2)\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n\n    def _test_8gpu_2_dim_allgather(self, nccl_mode):\n        global_config.nccl_mode = nccl_mode\n        src_shape = (1, 4)\n        dst_shape = (1, 4)\n        tensor_shape = (6, 8, 16)\n        src_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(4)])\n        dst_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(4)])\n        self.run_resharding_task(src_shape, dst_shape, src_spec, dst_spec,\n                                 tensor_shape)\n\n    def _test_4gpu_broadcast(self, nccl_mode):\n        global_config.nccl_mode = nccl_mode\n        src_shape = (1, 2)\n        dst_shape = (1, 2)\n        tensor_shape = (4, 8, 16)\n        src_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(2)])\n        dst_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape,\n                                 dst_shape,\n                                 src_spec,\n                                 dst_spec,\n                                 tensor_shape,\n                                 resharding_mode=\"broadcast\")\n        src_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape,\n                                 dst_shape,\n                                 src_spec,\n                                 dst_spec,\n                                 
tensor_shape,\n                                 resharding_mode=\"broadcast\")\n        src_spec = ShardingSpec(\n            [NoSharding(), Chunked([2]),\n             NoSharding()], [ShardedAxis(0)])\n        self.run_resharding_task(src_shape,\n                                 dst_shape,\n                                 src_spec,\n                                 dst_spec,\n                                 tensor_shape,\n                                 resharding_mode=\"broadcast\")\n\n    @unittest.skipIf(jax.device_count('gpu') < 8, \"no enough device\")\n    def _test_8gpu_broadcast(self, nccl_mode):\n        global_config.nccl_mode = nccl_mode\n        src_shape = (1, 4)\n        dst_shape = (1, 4)\n        tensor_shape = (2, 64, 64)\n\n        src_spec = ShardingSpec([Chunked(\n            [2]), Chunked([2]), NoSharding()],\n                                [ShardedAxis(0), ShardedAxis(1)])\n        dst_spec = ShardingSpec(\n            [NoSharding(), NoSharding(),\n             NoSharding()], [Replicated(4)])\n        self.run_resharding_task(src_shape,\n                                 dst_shape,\n                                 src_spec,\n                                 dst_spec,\n                                 tensor_shape,\n                                 resharding_mode=\"broadcast\")\n\n        tensor_shape = (64, 64, 64)\n        src_spec = ShardingSpec([Chunked(\n            [2]), Chunked([2]), NoSharding()],\n                                [ShardedAxis(0), ShardedAxis(1)])\n        dst_spec = ShardingSpec([Chunked(\n            [2]), NoSharding(), Chunked([2])],\n                                [ShardedAxis(0), ShardedAxis(1)])\n        self.run_resharding_task(src_shape,\n                                 dst_shape,\n                                 src_spec,\n                                 dst_spec,\n                                 tensor_shape,\n                                 resharding_mode=\"broadcast\")\n\n    def 
test_4gpu_send_recv(self):\n        self._test_4gpu_send_recv(\"cupy\")\n        self._test_4gpu_send_recv(\"xla_extension\")\n\n    def test_4gpu_allgather(self):\n        self._test_4gpu_allgather(\"cupy\")\n        self._test_4gpu_allgather(\"xla_extension\")\n\n    @unittest.skipIf(jax.device_count('gpu') < 8, \"no enough device\")\n    def test_8gpu_2_dim_allgather(self):\n        self._test_8gpu_2_dim_allgather(\"cupy\")\n\n    def test_4gpu_broadcast(self):\n        self._test_4gpu_broadcast(\"cupy\")\n        self._test_4gpu_broadcast(\"xla_extension\")\n\n    @unittest.skipIf(jax.device_count('gpu') < 8, \"no enough device\")\n    def test_8gpu_broadcast(self):\n        self._test_8gpu_broadcast(\"cupy\")\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(ReshardingTest(\"test_4gpu_send_recv\"))\n    suite.addTest(ReshardingTest(\"test_4gpu_allgather\"))\n    suite.addTest(ReshardingTest(\"test_8gpu_2_dim_allgather\"))\n    suite.addTest(ReshardingTest(\"test_4gpu_broadcast\"))\n    suite.addTest(ReshardingTest(\"test_8gpu_broadcast\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_dynamic_programming.py",
    "content": "\"\"\"Test dynamic programming.\"\"\"\n\nimport numpy as np\nimport unittest\n\nimport alpa\nfrom alpa.pipeline_parallel.stage_construction import (training_dp as\n                                                       stage_construction_dp,\n                                                       get_submesh_choices)\nfrom alpa.testing import assert_allclose\n\n\nclass DynamicProgrammingTest(unittest.TestCase):\n    \"\"\"Test dynamic programming.\"\"\"\n\n    def test_stage_construction(self):\n        \"\"\"Test stage construction.\"\"\"\n        num_layers = 8\n        num_hosts = 1\n        num_devices_per_host = 8\n        num_devices = num_hosts * num_devices_per_host\n        num_micro_batches = 16\n        num_autosharding_configs = 1\n        for i in range(1, num_devices + 1):\n            if num_devices % i == 0:\n                num_autosharding_configs += 1\n        submesh_choices = get_submesh_choices(num_hosts, num_devices_per_host,\n                                              \"all\")\n        num_submesh_choices = len(submesh_choices)\n        np.random.seed(42)\n        compute_cost = np.random.rand(num_layers, num_layers,\n                                      num_submesh_choices,\n                                      num_autosharding_configs)\n        max_n_succ_stages = np.full(\n            (num_layers, num_layers, num_submesh_choices,\n             num_autosharding_configs), 4096)\n        alpa.util._DISABLE_NUMBA = False\n        numba_cost, _ = stage_construction_dp(num_layers, num_devices,\n                                              num_micro_batches,\n                                              submesh_choices,\n                                              num_autosharding_configs,\n                                              compute_cost, max_n_succ_stages)\n        alpa.util._DISABLE_NUMBA = True\n        no_numba_cost, _ = stage_construction_dp(\n            num_layers, num_devices, num_micro_batches, 
submesh_choices,\n            num_autosharding_configs, compute_cost, max_n_succ_stages)\n        assert_allclose(numba_cost, no_numba_cost)\n        # Note(zhuohan): The profiling here suggests that the numba jitted\n        #  version is ~250x faster than the non-jitted version. Therefore,\n        #  we highly recommend using the numba version, but for smaller\n        #  problem sizes, the non-jitted version is also acceptable.\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(unittest.makeSuite(DynamicProgrammingTest))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_global_norm.py",
    "content": "import unittest\n\nimport jax\nfrom jax import numpy as jnp, lax\nfrom jax._src.tree_util import tree_map\nfrom optax import global_norm\n\nfrom alpa import grad\nfrom alpa.testing import PipelineBasicTest\n\n\nclass GlobalNormTest(PipelineBasicTest):\n\n    def test_global_norm(self):\n        hlos = self.run_n_layer_bert(num_layers=2,\n                                     manual_pipeline_layer=False,\n                                     clip_by_global_norm=True)\n        for x in hlos[-2:]:\n            assert \"CrossMeshAllReduce\" in x\n\n    @unittest.skip(\"No data to test efficiently.\")\n    def test_dynamic_scale(self):\n        hlos = self.run_n_layer_bert(num_layers=2,\n                                     manual_pipeline_layer=False,\n                                     use_dynamic_scale=True)\n\n    @unittest.skip(\"No data to test efficiently.\")\n    def test_global_norm_dynamic_scale(self):\n        hlos = self.run_n_layer_bert(num_layers=2,\n                                     manual_pipeline_layer=False,\n                                     clip_by_global_norm=True,\n                                     use_dynamic_scale=True)\n\n    def test_glob_norm_and_all_le(self):\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"],\n                                     batch[\"attention_mask\"])\n                loss = jnp.mean((out - batch[\"y\"])**2)\n                return loss\n\n            grads = grad(loss_func)(state.params)\n            glob_norm = global_norm(grads)\n            new_grads = tree_map(lambda g: g / glob_norm, grads)\n            new_state = state.apply_gradients(grads=new_grads)\n\n            ls_1 = jnp.array(True)\n            for g in jax.tree_util.tree_leaves(grads):\n                ls_1 &= jnp.all(lax.le(g, 1.))\n            return new_state, (new_grads, ls_1)\n\n        hlos = self.run_n_layer_bert(num_layers=2, 
inject_train_step=train_step)\n        for x in hlos[-2:]:\n            assert 'backend_config=\"SUM;' in x\n            assert 'backend_config=\"AND;' in x\n            assert x.count(\"CrossMeshAllReduce\") == 2\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(GlobalNormTest(\"test_global_norm\"))\n    suite.addTest(GlobalNormTest(\"test_dynamic_scale\"))\n    suite.addTest(GlobalNormTest(\"test_global_norm_dynamic_scale\"))\n    suite.addTest(GlobalNormTest(\"test_glob_norm_and_all_le\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_inference_auto.py",
    "content": "import unittest\nfrom alpa import init, PipeshardParallel, AutoStageOption\nfrom tests.pipeline_parallel.test_inference_only import PipelineInferenceTest\n\n\nclass PipelineInferenceAutoTest(PipelineInferenceTest):\n\n    def setUp(self):\n        init(cluster=\"ray\", num_nodes=1, num_devices_per_node=4)\n\n    def test_mlp(self):\n        stage_option = AutoStageOption(\n            submesh_physical_shape_space=\"manual\",\n            manually_specified_submeshes=((1, 2),),\n            submesh_logical_shape_space=\"model_parallel_only\")\n        method = PipeshardParallel(num_micro_batches=1,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\",\n                                   stage_option=stage_option)\n        self.run_mlp_inference(True, method)\n\n    def test_bert(self):\n        stage_option = AutoStageOption(\n            submesh_physical_shape_space=\"manual\",\n            manually_specified_submeshes=((1, 2),),\n            submesh_logical_shape_space=\"model_parallel_only\")\n        method = PipeshardParallel(num_micro_batches=1,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\",\n                                   stage_option=stage_option)\n        self.run_bert_layer_collection_inference(True, method)\n\n    def test_mlp_1d(self):\n        stage_option = AutoStageOption(\n            submesh_physical_shape_space=\"manual\",\n            manually_specified_submeshes=((1, 2),),\n            submesh_logical_shape_space=\"model_parallel_only\",\n            layer_profile_mode=\"individual\")\n        method = PipeshardParallel(num_micro_batches=1,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\",\n                                   stage_option=stage_option)\n        self.run_mlp_inference(True, 
method)\n\n    def test_bert_1d(self):\n        stage_option = AutoStageOption(\n            submesh_physical_shape_space=\"manual\",\n            manually_specified_submeshes=((1, 2),),\n            submesh_logical_shape_space=\"model_parallel_only\",\n            layer_profile_mode=\"individual\")\n        method = PipeshardParallel(num_micro_batches=1,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\",\n                                   stage_option=stage_option)\n        self.run_bert_layer_collection_inference(True, method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineInferenceAutoTest(\"test_mlp\"))\n    suite.addTest(PipelineInferenceAutoTest(\"test_bert\"))\n    suite.addTest(PipelineInferenceAutoTest(\"test_mlp_1d\"))\n    suite.addTest(PipelineInferenceAutoTest(\"test_bert_1d\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_inference_only.py",
    "content": "import unittest\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import (init, shutdown, parallelize, PipeshardParallel,\n                  mark_pipeline_boundary)\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayerCollection\nfrom alpa.testing import (MLPModel, create_train_state, mlp_inference_step,\n                          bert_layer_collection_inference_step, assert_allclose)\n\n\nclass PipelineInferenceTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    # pylint: disable=no-self-use\n    def tearDown(self):\n        shutdown()\n\n    def run_mlp_inference(self, manual_pipeline_layer, parallel_method):\n        # Init model and optimizer\n        batch_size = 64\n        hidden_size = 16\n\n        model = MLPModel(hidden_size=hidden_size,\n                         num_layers=4,\n                         add_manual_pipeline_marker=manual_pipeline_layer)\n        rngkey = jax.random.PRNGKey(0)\n        x = jax.random.normal(rngkey, (batch_size, hidden_size))\n        y = jax.random.normal(rngkey, (batch_size, hidden_size))\n        batch = {'x': x, 'y': y}\n        state = create_train_state(rngkey, model, [x])\n\n        # Compile\n        serial_inference_step = mlp_inference_step\n\n        parallel_inference_step = parallelize(mlp_inference_step,\n                                              method=parallel_method,\n                                              donate_argnums=())\n        executable = parallel_inference_step.get_executable(state, batch)\n\n        # Run correctnesss test\n        serial_out = serial_inference_step(state, batch)\n        parallel_out = parallel_inference_step(state, batch)\n        assert_allclose(serial_out, parallel_out, 1e-3, 1e-3)\n\n    def run_bert_layer_collection_inference(self, manual_pipeline_layer,\n                                            parallel_method):\n        # Init model and optimizer\n        batch_size = 16\n        
seq_len = 256\n        hidden_size = 512\n        num_heads = 512 // 64\n        n_layers = 4\n\n        model = FlaxBertLayerCollection(\n            config=BertConfig(hidden_size=hidden_size,\n                              intermediate_size=hidden_size * 4,\n                              num_attention_heads=num_heads,\n                              num_hidden_layers=n_layers,\n                              add_manual_pipeline_markers=manual_pipeline_layer,\n                              pipeline_mp_size=n_layers))\n        rngkey = jax.random.PRNGKey(0)\n        x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\n        y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size))\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8)\n        batch = {\"x\": x, \"y\": y, \"attention_mask\": attention_mask}\n        state = create_train_state(rngkey, model, [x, attention_mask])\n\n        # Compile\n        serial_inference_step = bert_layer_collection_inference_step\n        parallel_inference_step = parallelize(\n            bert_layer_collection_inference_step,\n            method=parallel_method,\n            donate_argnums=())\n        executable = parallel_inference_step.get_executable(state, batch)\n\n        # Run correctnesss test\n        serial_out = serial_inference_step(state, batch)\n        parallel_out = parallel_inference_step(state, batch)\n        assert_allclose(serial_out, parallel_out, 1e-3, 1e-3)\n\n    def test_mlp(self):\n        method = PipeshardParallel(num_micro_batches=4,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\")\n        self.run_mlp_inference(True, method)\n\n    def test_bert(self):\n        method = PipeshardParallel(num_micro_batches=4,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\")\n        
self.run_bert_layer_collection_inference(True, method)\n\n    def test_output(self):\n        method = PipeshardParallel(num_micro_batches=2,\n                                   pipeline_schedule=\"inference\",\n                                   layer_option=\"manual\")\n\n        @parallelize(method=method, batch_argnums=(0,))\n        def func(x):\n            a = jnp.ones_like(x) + x\n            mark_pipeline_boundary()\n            b = jnp.ones_like(x) * 2 + x\n            return a, b, 3\n\n        x = np.ones(32, dtype=np.float32)\n        a, b, c = func(x)\n\n        assert_allclose(a, np.ones(32) * 2)\n        assert_allclose(b, np.ones(32) * (2 + 1))\n        assert_allclose(c, 3)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineInferenceTest(\"test_mlp\"))\n    suite.addTest(PipelineInferenceTest(\"test_bert\"))\n    suite.addTest(PipelineInferenceTest(\"test_output\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_layer_construction.py",
    "content": "import unittest\n\nimport jax\nfrom alpa.testing import PipelineBasicTest\n\n\nclass LayerConstructionTest(PipelineBasicTest):\n\n    def test_mlp_layer_construction(self):\n        self.run_mlp(manual_pipeline_layer=False)\n\n    def test_2_layer_bert_layer_construction(self):\n        self.run_n_layer_bert(num_layers=2, manual_pipeline_layer=False)\n\n    @unittest.skipIf(jax.device_count('gpu') < 8, \"no enough device\")\n    def test_8_layer_bert_layer_construction(self):\n        self.run_n_layer_bert(num_layers=8, manual_pipeline_layer=False)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(LayerConstructionTest('test_mlp_layer_construction'))\n    suite.addTest(LayerConstructionTest('test_2_layer_bert_layer_construction'))\n    suite.addTest(LayerConstructionTest('test_8_layer_bert_layer_construction'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_manual_sharding.py",
    "content": "\"\"\"\nTest the manual sharding spec.\n\"\"\"\nimport itertools\nimport unittest\n\nimport jax\nfrom jax.experimental.pjit import PartitionSpec\nfrom jax.tree_util import tree_map\nimport jax.numpy as jnp\n\nimport alpa\nfrom alpa import (AutoShardingOption, ManualShardingOption, ManualStageOption,\n                  PipeshardParallel, mark_pipeline_boundary, parallelize)\nfrom alpa.testing import HloParser\n\n\nclass PipeshardManualShardingTest(unittest.TestCase):\n\n    def setUp(self):\n        alpa.init()\n        # use (1 * 4) mesh\n        alpa.set_global_virtual_physical_mesh(\n            alpa.get_global_cluster().get_virtual_physical_mesh([0], 4))\n\n    def tearDown(self):\n        alpa.shutdown()\n\n    def _get_fn_manual_sharding_with(self, fn, num_microbatches, stage_option,\n                                     ms_option, *args):\n        method = PipeshardParallel(\n            num_micro_batches=num_microbatches,\n            stage_option=stage_option,\n            manual_sharding_option=ms_option,\n            default_auto_sharding_option=AutoShardingOption(False))\n        parallelized = parallelize(fn, method=method)\n        return parallelized.get_executable(*args).get_hlo_text()\n\n    @staticmethod\n    def _is_superset_with_x_more(seq1, seq2, x):\n        set1 = set(seq1)\n        set2 = set(seq2)\n        if set1.issuperset(set2) and len(set1) - len(set2) == x:\n            return True\n        return False\n\n    def test_set_input_output(self):\n\n        def fn(params, batch):\n            x, tgt = batch\n\n            def loss_fn(params):\n                w0, b0, w1, b1, w2, b2, w3, b3 = params\n                y = jax.nn.relu(x @ w0 + b0)\n                z = jax.nn.relu(y @ w1 + b1)\n                mark_pipeline_boundary()\n                u = jax.nn.relu(z @ w2 + b2)\n                v = jax.nn.softmax(u @ w3 + b3)\n                return jnp.mean((v - tgt)**2)\n\n            grads = alpa.grad(loss_fn)(params)\n      
      new_params = tree_map(lambda p, g: p - g, params, grads)\n            return new_params\n\n        # data\n        batch_size = 64\n        hiddens = [6, 8, 10, 12, 14]\n        params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])),\n                                    jnp.ones((hiddens[i + 1],)))\n                                   for i in range(len(hiddens) - 1)])\n        params = tuple(params)\n        x = jnp.ones((batch_size, hiddens[0]))\n        tgt = jnp.ones((batch_size, hiddens[-1]))\n        batch = (x, tgt)\n\n        # partitions\n        mp_start = PartitionSpec(None, \"model\")\n        mp_end = PartitionSpec(\"model\", None)\n        bias_partitioned = PartitionSpec(\"model\")\n        replicated = None\n        dp = PartitionSpec(\"data\", None)\n\n        param_axis_resources = (mp_start, bias_partitioned, mp_end,\n                                replicated) + (replicated, replicated,\n                                               replicated, replicated)\n        batch_axis_resources = (replicated, dp)\n        in_axis_resources = (param_axis_resources, batch_axis_resources)\n\n        # options\n        s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2,\n                                     [{}] * 2)\n        submesh_axis_names = ((\"dummy\", \"model\"), (\"dummy\", \"data\"))\n        ms_option = ManualShardingOption(None, submesh_axis_names,\n                                         in_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option,\n                                                 params, batch)\n        l0_fwd, l1_fwd, l1_bwd, l0_bwd, l0_apl, l1_apl = text\n        # layer 0\n        l0_param_shape = (\"f32[6,4]\", \"f32[4]\", \"f32[4,10]\", \"f32[10]\")\n        l0_batch_shape = (\"f32[32,6]\",)\n        l0_fwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l0_fwd))\n        assert sorted(l0_fwd_param) == sorted(l0_param_shape + 
l0_batch_shape)\n        l0_bwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l0_bwd))\n        l0_bwd_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l0_bwd))\n        # the donated accumulated gradient are at first\n        assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape)\n        assert sorted(l0_bwd_root) == sorted(l0_param_shape)\n        l0_apl_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l0_apl))\n        l0_apl_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l0_apl))\n        assert sorted(l0_apl_param) == sorted(l0_param_shape + l0_param_shape)\n        assert sorted(l0_apl_root) == sorted(l0_param_shape)\n\n        # layer 1\n        l1_param_shape = (\"f32[10,12]\", \"f32[12]\", \"f32[12,14]\", \"f32[14]\")\n        l1_batch_shape = (\"f32[16,14]\",)\n        l1_fwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l1_fwd))\n        assert self._is_superset_with_x_more(l1_fwd_param,\n                                             l1_param_shape + l1_batch_shape, 1)\n        l1_bwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l1_bwd))\n        l1_bwd_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l1_bwd))\n        # the donated accumulated gradient are at first\n        assert sorted(l1_bwd_param[:4]) == sorted(l1_param_shape)\n        assert self._is_superset_with_x_more(l1_bwd_root, l1_param_shape, 1)\n        l1_apl_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l1_apl))\n        l1_apl_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l1_apl))\n        assert sorted(l1_apl_param) == sorted(l1_param_shape + l1_param_shape)\n        assert sorted(l1_apl_root) == sorted(l1_param_shape)\n\n    def test_set_intermediate(self):\n\n        def fn(params, batch):\n            x, tgt = batch\n\n            def 
loss_fn(params):\n                w0, b0, w1, b1, w2, b2, w3, b3 = params\n                y = jax.nn.relu(x @ w0 + b0)\n                z = jax.nn.relu(y @ w1 + b1)\n                mark_pipeline_boundary()\n                u = jax.nn.relu(z @ w2 + b2)\n                v = jax.nn.softmax(u @ w3 + b3)\n                return jnp.mean((v - tgt)**2)\n\n            grads = alpa.grad(loss_fn)(params)\n            new_params = tree_map(lambda p, g: p - g, params, grads)\n            return new_params\n\n        # data\n        batch_size = 64\n        hiddens = [6, 8, 10, 12, 14]\n        params = itertools.chain(*[(jnp.ones((hiddens[i], hiddens[i + 1])),\n                                    jnp.ones((hiddens[i + 1],)))\n                                   for i in range(len(hiddens) - 1)])\n        params = tuple(params)\n        x = jnp.ones((batch_size, hiddens[0]))\n        tgt = jnp.ones((batch_size, hiddens[-1]))\n        batch = (x, tgt)\n\n        # partitions\n        mp_start = PartitionSpec(None, \"model\")\n        mp_end = PartitionSpec(\"model\", None)\n        bias_partitioned = PartitionSpec(\"model\")\n        replicated = None\n        dp = PartitionSpec(\"data\", None)\n\n        param_axis_resources = (mp_start, bias_partitioned, mp_end,\n                                replicated) + (replicated, replicated,\n                                               replicated, replicated)\n        # We don't set target sharded here. 
Otherwise it gives hint for the spmd\n        # partitioner.\n        batch_axis_resources = (replicated, replicated)\n        in_axis_resources = (param_axis_resources, batch_axis_resources)\n        s_option = ManualStageOption([[0], [1]], [(1, 2)] * 2, [(1, 2)] * 2,\n                                     [{}] * 2)\n        submesh_axis_names = ((\"dummy\", \"model\"), (\"dummy\", \"data\"))\n        pipeline_intermediate_axes = ((\"data\", 0),)\n        ms_option = ManualShardingOption(\n            None,\n            submesh_axis_names,\n            in_axis_resources,\n            pipeline_intermediate_axes=pipeline_intermediate_axes)\n        text = self._get_fn_manual_sharding_with(fn, 2, s_option, ms_option,\n                                                 params, batch)\n        # Layer 1. It should have the correct intermediate shape.\n        l0_fwd, l1_fwd, l1_bwd, l0_bwd, _, _ = text\n        l1_param_shape = (\"f32[10,12]\", \"f32[12]\", \"f32[12,14]\", \"f32[14]\")\n        intermediate_sharded = (\"f32[16,10]\",)\n        l1_fwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l1_fwd))\n        assert self._is_superset_with_x_more(\n            l1_fwd_param, intermediate_sharded + l1_param_shape, 1)\n\n        l1_bwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l1_bwd))\n        l1_bwd_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l1_bwd))\n        # the donated accumulated gradient are at first\n        assert sorted(l1_bwd_param[:4]) == sorted(l1_param_shape)\n        assert sorted(l1_bwd_root) == sorted(intermediate_sharded +\n                                             l1_param_shape)\n\n        # Layer 0. 
It should not have any data parallelization.\n        l0_param_shape = (\"f32[6,4]\", \"f32[4]\", \"f32[4,10]\", \"f32[10]\")\n        l0_batch_shape = (\"f32[32,6]\",)\n        intermediate_replicated = (\"f32[32,10]\")\n        l0_fwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l0_fwd))\n        l0_fwd_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l0_fwd))\n        assert sorted(l0_fwd_param) == sorted(l0_param_shape + l0_batch_shape)\n        l0_bwd_param = HloParser.parse_param_shapes(\n            HloParser.get_param_line(l0_bwd))\n        l0_bwd_root = HloParser.parse_root_shapes(\n            HloParser.get_root_line(l0_bwd))\n        # the donated accumulated gradients come first\n        assert sorted(l0_bwd_param[:4]) == sorted(l0_param_shape)\n        assert sorted(l0_bwd_root) == sorted(l0_param_shape)\n\n        assert intermediate_replicated in l0_bwd_param\n        assert intermediate_replicated in l0_fwd_root\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipeshardManualShardingTest(\"test_set_input_output\"))\n    suite.addTest(PipeshardManualShardingTest(\"test_set_intermediate\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_mlp.py",
    "content": "import unittest\nimport os\n\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport ray\n\nfrom alpa import init, parallelize, PipeshardParallel\nfrom alpa.model.model_util import TrainState\nfrom alpa.parallel_method import LocalPipelineParallel\nfrom alpa.pipeline_parallel.layer_construction import manual_layer_construction\nfrom alpa.testing import MLPModel, assert_allclose\n\n\nclass PipelineMLPTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n    def train_2_layer_mlp(self, method):\n\n        def train_step(state, batch):\n\n            @manual_layer_construction\n            def loss_func(params, x, y):\n                out = state.apply_fn(params, x)\n                # test constant handling\n                out = out + jnp.array(range(batch_size)).reshape((-1, 1))\n                loss = jnp.mean((out - y)**2)\n                return loss\n\n            # Note, we can only use jax.grad in this testcase.\n            # TODO: Fix https://github.com/alpa-projects/alpa/issues/560\n            grads = jax.grad(loss_func)(state.params, batch[\"x\"], batch[\"y\"])\n            return grads\n\n        batch_size = 64\n        hidden_size = 1024\n\n        x = jnp.ones((batch_size, hidden_size))\n        y = jnp.ones((batch_size, hidden_size))\n\n        # Init model and optimizer\n        model = MLPModel(num_layers=4,\n                         hidden_size=hidden_size,\n                         add_manual_pipeline_marker=True)\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.sgd(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params,\n                                  tx=tx,\n                                  dynamic_scale=None)\n\n        # Train step\n        batch = {\"x\": x, \"y\": y}\n        gradients = train_step(state, batch)\n        p_train_step 
= parallelize(train_step, donate_argnums=(), method=method)\n        gradients_with_pipeline = p_train_step(state, batch)\n\n        # Check results\n        assert_allclose(gradients, gradients_with_pipeline)\n\n        # Check debug utilities\n        if isinstance(method, PipeshardParallel):\n            executable = p_train_step.get_last_executable()\n            executable.dump_debug_info(\"tmp\")\n\n    def test_2_layer_mlp_local_pipeline_parallel(self):\n        self.train_2_layer_mlp(LocalPipelineParallel())\n\n    def test_2_layer_mlp_pipeshard_parallel(self):\n        init(cluster=\"ray\")\n        self.train_2_layer_mlp(PipeshardParallel(layer_option=\"manual\"))\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineMLPTest(\"test_2_layer_mlp_local_pipeline_parallel\"))\n    suite.addTest(PipelineMLPTest(\"test_2_layer_mlp_pipeshard_parallel\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_multi_graph.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport numpy as np\nimport unittest\n\nfrom alpa import init, parallelize, global_config, PipeshardParallel\nfrom alpa.testing import assert_allclose, get_mlp_train_state_and_step\n\n\nclass MultipleGraphRuntimeTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def run_2_mlp(self, use_value_and_grad=False, stage_option=\"uniform\"):\n\n        def test_one_mlp(method, batch_size=64, hidden_size=16):\n            state, batch, train_step = get_mlp_train_state_and_step(\n                batch_size=batch_size,\n                hidden_size=hidden_size,\n                add_manual_pipeline_marker=True)\n\n            # Compile\n            serial_train_step = train_step\n            parallel_train_step = parallelize(train_step, method=method)\n            executable = parallel_train_step.get_executable(state, batch)\n\n            # Run and check\n            expected_new_state, expected_val = serial_train_step(state, batch)\n            actual_new_state, actual_val = parallel_train_step(state, batch)\n\n            assert_allclose(expected_new_state.params, actual_new_state.params,\n                            1e-3, 1e-3)\n            assert_allclose(expected_val, actual_val, 1e-3, 1e-3)\n\n            return executable\n\n        method = PipeshardParallel(num_micro_batches=2,\n                                   stage_option=stage_option,\n                                   layer_option=\"manual\")\n        executable = test_one_mlp(method)\n        executable_2 = test_one_mlp(method)\n\n        assert executable != executable_2\n\n    def test_2_mlp(self):\n        self.run_2_mlp()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(MultipleGraphRuntimeTest('test_2_mlp'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_old_dp_vs_new_dp.py",
    "content": "import unittest\nimport numpy as np\n\nfrom alpa.pipeline_parallel.stage_construction import (get_submesh_choices,\n                                                       training_dp as dp,\n                                                       training_dp_2 as dp_2)\n\n\ndef default_num_auto_sharding_configs(num_devices):\n    num_autosharding_configs = 0\n    for i in range(1, num_devices + 1):\n        if num_devices % i == 0:\n            num_autosharding_configs += 1\n    return num_autosharding_configs\n\n\ndef generate_stage_construction_test_case(num_devices,\n                                          submesh_choices,\n                                          num_layers,\n                                          num_autosharding_configs,\n                                          compute_cost_factor=0.0,\n                                          device_memory_size_factor=1.0,\n                                          unlimited_memory=False):\n    \"\"\"\n    Generate a test case for stage construction.\n    Args:\n        num_devices: number of total devices.\n        submesh_choices: a list of submesh choices.\n        num_layers: number of layers.\n        num_autosharding_configs: number of auto sharding configs.\n        compute_cost_factor: a factor to control the distributed compute cost.\n            Take values in [-inf, inf].\n        device_memory_size_factor: a factor to control the device memory size.\n            Take values in [0, inf].\n        unlimited_memory: ignore memory cost.\n    \"\"\"\n    num_submesh_choices = len(submesh_choices)\n    compute_cost = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        np.inf)\n    max_n_succ_stages = np.full(\n        (num_layers, num_layers, num_submesh_choices, num_autosharding_configs),\n        -1)\n    layer_base_cost = np.random.rand(num_layers)\n    memory_base_cost = np.random.rand(num_layers)\n    total_memory = 
memory_base_cost.sum()\n    for start in range(num_layers):\n        for end in range(start, num_layers):\n            for s, submesh in enumerate(submesh_choices):\n                submesh_size = np.prod(submesh)\n                for l in range(num_autosharding_configs):\n                    autosharding_factor = np.random.rand() + 1\n                    compute_cost[start, end, s,\n                                 l] = (layer_base_cost[start:end + 1].sum() *\n                                       autosharding_factor *\n                                       submesh_size**compute_cost_factor)\n                    if unlimited_memory:\n                        max_n_succ_stages[start, end, s, l] = 4096\n                    else:\n                        model_percentage = (\n                            memory_base_cost[start:end + 1].sum() /\n                            total_memory)\n                        device_percentage = submesh_size / num_devices\n                        max_n_succ_stages[start, end, s,\n                                          l] = (device_memory_size_factor *\n                                                num_layers * device_percentage /\n                                                model_percentage /\n                                                autosharding_factor)\n\n    return compute_cost, max_n_succ_stages\n\n\nclass OldNewDPTest(unittest.TestCase):\n    \"\"\"Test the equivalence of old DP and new DP.\"\"\"\n\n    def test_dp(self):\n        num_runs = 2\n        np.random.seed(0)\n\n        for num_layers in [4, 8]:\n            for num_hosts in [1, 4]:\n                for num_devices_per_host in [1, 4]:\n                    submesh_choices = get_submesh_choices(\n                        num_hosts, num_devices_per_host, \"all\")\n                    for num_micro_batches in [1, 16, 512]:\n                        for i in range(num_runs):\n                            compute_cost_factor = np.random.rand() * 4 - 2\n        
                    device_memory_size_factor = np.random.rand() * 4\n                            num_devices = num_hosts * num_devices_per_host\n                            num_autosharding_configs = np.random.randint(1, 5)\n                            (compute_cost, max_n_succ_stages\n                            ) = generate_stage_construction_test_case(\n                                num_devices, submesh_choices, num_layers,\n                                num_autosharding_configs, compute_cost_factor,\n                                device_memory_size_factor)\n\n                            res_old = dp(num_layers, num_devices,\n                                         num_micro_batches, submesh_choices,\n                                         num_autosharding_configs, compute_cost,\n                                         max_n_succ_stages)\n\n                            res_new = dp_2(num_devices, num_micro_batches,\n                                           submesh_choices, compute_cost,\n                                           max_n_succ_stages)\n                            assert res_new == res_old\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(unittest.makeSuite(OldNewDPTest))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_pipeline_marker.py",
    "content": "import unittest\n\nimport numpy as np\nimport jax\nfrom jax.lib import xla_client as xc, xla_bridge as xb\nimport jax.numpy as jnp\n\nfrom alpa.pipeline_parallel.primitive_def import xla_custom_call, pipeline_p\nfrom alpa.testing import assert_allclose\n\nops = xc.ops\n\n\nclass PipelineMarkerTest(unittest.TestCase):\n\n    def setUp(self):\n        np.random.seed(1337)\n\n    def test_xla_graph(self):\n        c = xc.XlaBuilder(\"xla_graph_with_marker\")\n\n        parameter_shape = xc.Shape.array_shape(np.dtype(np.float32), (10, 8),\n                                               (0, 1))\n        x = ops.Parameter(c, 0, parameter_shape)\n        y = ops.Parameter(c, 1, parameter_shape)\n\n        backend = xb.get_backend(\"gpu\")\n\n        a = ops.Add(x, y)\n        b = ops.Mul(x, y)\n\n        output_tuple = xla_custom_call(c, \"pipeline_marker\", \"1$start\", a, b)\n        a = ops.GetTupleElement(output_tuple, 0)\n        b = ops.GetTupleElement(output_tuple, 1)\n\n        z = ops.Add(a, b)\n        output_tuple = xla_custom_call(c, \"pipeline_marker\", \"1$end\", z)\n        z = ops.GetTupleElement(output_tuple, 0)\n\n        c = c.build(z)\n        compiled_c = backend.compile(c)\n\n        x_np = np.random.rand(10, 8).astype(np.float32)\n        y_np = np.random.rand(10, 8).astype(np.float32)\n\n        x = backend.buffer_from_pyval(x_np)\n        y = backend.buffer_from_pyval(y_np)\n        z, = compiled_c.execute([x, y])\n\n        a_np = x_np + y_np\n        b_np = x_np * y_np\n        z_np = a_np + b_np\n\n        assert_allclose(z, z_np)\n\n    def test_jax_graph(self):\n        x_np = np.random.rand(10, 8).astype(np.float32)\n        y_np = np.random.rand(10, 8).astype(np.float32)\n        a_np = x_np + y_np\n        b_np = x_np * y_np\n        z_np = a_np + b_np\n\n        def f(x, y):\n            a = x + y\n            b = x * y\n            a, b = pipeline_p.bind(a, b, name=\"1\", mark_type=\"start\")\n            z = a + b\n      
      z, = pipeline_p.bind(z, name=\"1\", mark_type=\"end\")\n            return z\n\n        z_without_jit = f(x_np, y_np)\n        f = jax.jit(f)\n        z_with_jit = f(x_np, y_np)\n        assert_allclose(z_with_jit, z_np)\n        assert_allclose(z_without_jit, z_np)\n\n    def test_transpose(self):\n\n        def f(x):\n            x, = pipeline_p.bind(x, name=\"1\", mark_type=\"start\")\n            x = jnp.transpose(x, axes=(1, 0))\n            return x\n\n        x = np.random.rand(2, 4)\n        no_jit_result = f(x)\n        jit_result = jax.jit(f)(x)\n        assert_allclose(no_jit_result, jit_result)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineMarkerTest(\"test_xla_graph\"))\n    suite.addTest(PipelineMarkerTest(\"test_jax_graph\"))\n    suite.addTest(PipelineMarkerTest(\"test_transpose\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_reduce_scatter.py",
    "content": "import unittest\n\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.testing import PipelineBasicTest\nfrom alpa.util import count_communication_primitives\n\n\nclass PipelineReduceScatterTest(PipelineBasicTest):\n\n    def test_mlp_grad_acc_friendly(self):\n        as_option = AutoShardingOption(force_data_parallel=True,\n                                       prefer_reduce_scatter=True)\n        hlo_text = self.run_mlp(as_option=as_option)\n\n        # Check number of communication primitives\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[0],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == 0\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[1],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == 0\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[2],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_reduce == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[3],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_reduce == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[4],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_gather == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[5],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == 
n_all_gather == 1\n\n    def test_bert_grad_acc_friendly(self):\n        as_option = AutoShardingOption(force_data_parallel=True,\n                                       prefer_reduce_scatter=True)\n        hlo_text = self.run_n_layer_bert(num_layers=2, as_option=as_option)\n\n        # Check numbers of communication primitives\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[0],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == 0\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[1],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == 0\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[2],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_reduce == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[3],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_reduce == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[4],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_gather == 1\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_text[5],\n                                           ignore_scalar_all_reduce=True))\n        assert n_total == n_all_gather == 1\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineReduceScatterTest('test_mlp_grad_acc_friendly'))\n    
suite.addTest(PipelineReduceScatterTest('test_bert_grad_acc_friendly'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_remat.py",
    "content": "import unittest\n\nimport jax\nfrom alpa.testing import PipelineBasicTest\n\n\nclass PipelineRematTest(PipelineBasicTest):\n\n    def test_mlp_remat(self):\n        self.run_mlp(use_remat=True)\n\n    def test_2_layer_bert_remat(self):\n        self.run_n_layer_bert(num_layers=2, use_remat=True)\n\n    def test_2_layer_bert_auto_layer_slicing_remat(self):\n        self.run_n_layer_bert(num_layers=2,\n                              manual_pipeline_layer=False,\n                              use_remat=True)\n\n    @unittest.skipIf(jax.local_device_count(\"gpu\") < 8, \"no enough device\")\n    def test_8_layer_bert_auto_layer_slicing_remat(self):\n        self.run_n_layer_bert(num_layers=8,\n                              manual_pipeline_layer=False,\n                              use_remat=True)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineRematTest('test_mlp_remat'))\n    suite.addTest(PipelineRematTest('test_2_layer_bert_remat'))\n    suite.addTest(\n        PipelineRematTest('test_2_layer_bert_auto_layer_slicing_remat'))\n    suite.addTest(\n        PipelineRematTest('test_8_layer_bert_auto_layer_slicing_remat'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_scatter_gather.py",
    "content": "import unittest\n\nfrom alpa.device_mesh import (get_global_cluster,\n                              set_global_virtual_physical_mesh)\nfrom alpa.pipeline_parallel.stage_construction import ManualStageOption\nfrom alpa.testing import PipelineBasicTest\n\n\nclass ScatterGatherTest(PipelineBasicTest):\n\n    def test_2_layer_bert(self):\n        virtual_mesh = get_global_cluster().get_virtual_physical_mesh([0], 4)\n        set_global_virtual_physical_mesh(virtual_mesh)\n\n        stage_option = ManualStageOption(\n            forward_stage_layer_ids=[[0], [1]],\n            submesh_physical_shapes=[(1, 2), (1, 2)],\n            submesh_logical_shapes=[(1, 2), (2, 1)],\n            submesh_autosharding_option_dicts=[\n                dict(force_batch_dim_to_mesh_dim=0), {}\n            ])\n\n        self.run_n_layer_bert(num_layers=2,\n                              batch_size=4,\n                              seq_len=4,\n                              hidden_size=4,\n                              num_heads=1,\n                              stage_option=stage_option)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(ScatterGatherTest('test_2_layer_bert'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_schedules.py",
    "content": "import unittest\n\nfrom alpa.pipeline_parallel.schedules import (gen_linear_pipeline_dependency,\n                                              GpipeSchedule, PipeDreamFlush)\n\n\nclass PipelineScheduleTest(unittest.TestCase):\n\n    def run_schedule_basics(self, schedule_type, num_stage, num_mesh,\n                            num_batch):\n        deps = gen_linear_pipeline_dependency(num_stage)\n        meshes = [None] * num_mesh\n        num_fwd_stage = num_stage // 2\n        apply_grad_placement = {num_stage + i: i for i in range(num_fwd_stage)}\n        if schedule_type == \"gpipe\":\n            schedule_cls = GpipeSchedule\n        elif schedule_type == \"1f1b\":\n            schedule_cls = PipeDreamFlush\n        else:\n            print(\"unrecognized type of schedule.\")\n            return\n\n        s = schedule_cls(dependency=deps,\n                         meshes=meshes,\n                         apply_grad_placement=apply_grad_placement,\n                         num_batch=num_batch)\n\n        # check num_clock\n        assert s.num_clock == (num_mesh + num_batch - 1) * 2 + 1, (\n            \"clock number wrong.\")\n\n        # check no stage is on > 1 meshes\n        for i in range(num_stage):\n            mesh_indices = s.stage_placement(i)\n            assert len(mesh_indices) == 1, (\n                \"we only support each stage placed on one mesh.\")\n\n        # check no mesh owns > 3 stages (forward, backward, apply_grad)\n        for i in range(num_mesh):\n            stage_indices = s.mesh_placement(i)\n            assert len(stage_indices) == 3, (\n                \"One mesh at most owns three stages: forward, backward,\"\n                \" and apply_grad stages.\")\n            stage_indices_list = list(stage_indices)\n            stage_indices_list.sort()\n            f, b, a = stage_indices_list[0], stage_indices_list[\n                1], stage_indices_list[2]\n            assert f == 2 * num_mesh - 1 - b\n            
assert a == num_stage + f\n\n    def run_1f1b(self, num_stage, num_mesh, num_batch):\n        deps = gen_linear_pipeline_dependency(num_stage)\n        meshes = [None] * num_mesh\n        num_fwd_stage = num_stage // 2\n        apply_grad_placement = {num_stage + i: i for i in range(num_fwd_stage)}\n        s = PipeDreamFlush(dependency=deps,\n                           meshes=meshes,\n                           apply_grad_placement=apply_grad_placement,\n                           num_batch=num_batch)\n\n        # test the in-flight microbatches <= num_mesh\n        in_flight = [0 for _ in range(num_mesh)]\n        max_in_flight = [0 for _ in range(num_mesh)]\n        for sched in s.schedules:\n            for mesh_idx, task in enumerate(sched):\n                if task:\n                    batch_idx, stage_idx = task\n                    if stage_idx < num_stage / 2:\n                        in_flight[mesh_idx] += 1\n                    if stage_idx < num_stage and stage_idx >= num_stage / 2:\n                        in_flight[mesh_idx] -= 1\n                    if in_flight[mesh_idx] > max_in_flight[mesh_idx]:\n                        max_in_flight[mesh_idx] = in_flight[mesh_idx]\n\n        for i in range(num_mesh):\n            assert max_in_flight[i] <= num_mesh - i, (\n                \"max number of in-flight is incorrect.\")\n\n    def test_schedules(self):\n        schedule_types = [\"gpipe\", \"1f1b\"]\n        num_stages = [4, 6, 8, 12, 16, 32, 64]\n        num_batches = [1, 2, 4, 8, 16, 32, 64, 128]\n        for schedule_type in schedule_types:\n            for num_stage in num_stages:\n                for num_batch in num_batches:\n                    num_mesh = num_stage // 2\n                    #print(\n                    #    \"Testing case: type {}, num_stage {}, num_mesh {}, num_batch {}.\"\n                    #    .format(schedule_type, num_stage, num_mesh, num_batch))\n                    self.run_schedule_basics(schedule_type, num_stage, 
num_mesh,\n                                             num_batch)\n                    if schedule_type == \"1f1b\":\n                        self.run_1f1b(num_stage, num_mesh, num_batch)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(PipelineScheduleTest(\"test_schedules\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_set_input_shard.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport unittest\n\nfrom alpa import init, parallelize, AutoShardingOption, PipeshardParallel\nfrom alpa.testing import MLPModel\n\n\nclass SetInputShardSpecTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def run_set_input_shard_spec(self):\n        hidden_size = 64\n\n        rngkey = jax.random.PRNGKey(0)\n\n        # Make a MLP model with 2 pipeline stages.\n        model = MLPModel(num_layers=4,\n                         hidden_size=hidden_size,\n                         add_manual_pipeline_marker=True)\n        data = jax.core.ShapedArray((1, hidden_size), jnp.float32)\n        params = jax.eval_shape(model.init, rngkey, data)\n        params = jax.tree_map(\n            lambda x: jax.ShapeDtypeStruct(x.shape, jnp.float32), params)\n\n        def infer_fn(params, batch):\n            return model.apply(params, batch[\"x\"])\n\n        method = PipeshardParallel(\n            num_micro_batches=1,\n            pipeline_schedule=\"inference\",\n            layer_option=\"manual\",\n            default_auto_sharding_option=AutoShardingOption(\n                force_batch_dim_to_mesh_dim=None,\n                allow_all_to_all=False,\n                allow_all_gather=False,\n            ))\n\n        # Compile with batch size 1\n        executable_1 = parallelize(\n            infer_fn, batch_argnums=(1,), method=method).get_executable(\n                params,\n                {\"x\": jax.core.ShapedArray((1, hidden_size), jnp.float32)})\n\n        # Make another parallel method with the same input shard spec.\n        method_with_input_shard = PipeshardParallel(\n            num_micro_batches=1,\n            pipeline_schedule=\"inference\",\n            layer_option=\"manual\",\n            default_auto_sharding_option=AutoShardingOption(\n                force_batch_dim_to_mesh_dim=None,\n                allow_all_to_all=False,\n                allow_all_gather=False,\n    
        ),\n            stage_input_shardings=executable_1.stage_input_shard_specs)\n\n        # Compile with a different batch size\n        executable_2 = parallelize(\n            infer_fn, batch_argnums=(1,), method=method).get_executable(\n                params,\n                {\"x\": jax.core.ShapedArray((8, hidden_size), jnp.float32)})\n\n        # Compile with a different batch size but the same input shard specs\n        executable_3 = parallelize(\n            infer_fn, batch_argnums=(1,),\n            method=method_with_input_shard).get_executable(\n                params,\n                {\"x\": jax.core.ShapedArray((8, hidden_size), jnp.float32)})\n\n        assert executable_2.stage_input_shard_specs != executable_3.stage_input_shard_specs\n        assert executable_1.stage_input_shard_specs == executable_3.stage_input_shard_specs\n\n    def test_set_input_shard_spec(self):\n        self.run_set_input_shard_spec()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(SetInputShardSpecTest('test_set_input_shard_spec'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_stage_construction.py",
    "content": "import unittest\n\nfrom alpa.pipeline_parallel.stage_construction import AutoStageOption\nfrom alpa.testing import PipelineBasicTest\n\n\ndef auto_stage():\n    return AutoStageOption(submesh_physical_shape_space=\"small_power_of_two\",\n                           submesh_logical_shape_space=\"same_as_physical\")\n\n\nclass StageConstructionTest(PipelineBasicTest):\n\n    def test_mlp_stage_construction(self):\n        self.run_mlp(stage_option=auto_stage())\n\n    def test_mlp_layer_and_stage(self):\n        self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage())\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(StageConstructionTest('test_mlp_stage_construction'))\n    suite.addTest(StageConstructionTest('test_mlp_layer_and_stage'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_stage_construction_slow.py",
    "content": "import unittest\n\nfrom alpa.pipeline_parallel.stage_construction import AutoStageOption\nfrom alpa.testing import PipelineBasicTest\n\n\ndef auto_stage():\n    return AutoStageOption(submesh_physical_shape_space=\"small_power_of_two\",\n                           submesh_logical_shape_space=\"same_as_physical\")\n\n\nclass StageConstructionSlowTest(PipelineBasicTest):\n\n    def test_mlp_stage_construction(self):\n        self.run_mlp(stage_option=auto_stage())\n\n    def test_mlp_layer_and_stage(self):\n        self.run_mlp(manual_pipeline_layer=False, stage_option=auto_stage())\n\n    def test_2_layer_bert_stage_construction(self):\n        self.run_n_layer_bert(num_layers=2, stage_option=auto_stage())\n\n    def test_2_layer_bert_layer_and_stage(self):\n        self.run_n_layer_bert(num_layers=2,\n                              manual_pipeline_layer=False,\n                              stage_option=auto_stage())\n\n    def test_8_layer_bert_stage_construction(self):\n        self.run_n_layer_bert(num_layers=8, stage_option=auto_stage())\n\n    def test_8_layer_bert_layer_and_stage(self):\n        self.run_n_layer_bert(num_layers=8,\n                              manual_pipeline_layer=False,\n                              stage_option=auto_stage())\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(StageConstructionSlowTest('test_mlp_stage_construction'))\n    suite.addTest(StageConstructionSlowTest('test_mlp_layer_and_stage'))\n    suite.addTest(\n        StageConstructionSlowTest('test_2_layer_bert_stage_construction'))\n    suite.addTest(\n        StageConstructionSlowTest('test_2_layer_bert_layer_and_stage'))\n    suite.addTest(\n        StageConstructionSlowTest('test_8_layer_bert_stage_construction'))\n    suite.addTest(\n        StageConstructionSlowTest('test_8_layer_bert_layer_and_stage'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_stage_construction_util.py",
    "content": "import unittest\nfrom typing import Sequence\n\nfrom jax._src.api import make_jaxpr\nfrom jax.core import ClosedJaxpr, Var, gensym\nimport jax.numpy as jnp\n\nfrom alpa import init, grad, parallelize, PipeshardParallel\nfrom alpa.device_mesh import get_global_virtual_physical_mesh\nfrom alpa.pipeline_parallel.stage_construction import (\n    AutoStageOption, get_one_submesh_autosharding_config_choices)\nfrom alpa.pipeline_parallel.compile_executable import (\n    split_and_process_layers, slice_apply_grad_for_stage_construction)\nfrom alpa.pipeline_parallel.layer_construction import ManualLayerOption\nfrom alpa.pipeline_parallel.stage_profiling import (\n    generate_stage_info, distributed_profile_on_mesh,\n    get_merged_stages_memory_stats)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.testing import (get_bert_layer_train_state_and_step,\n                          get_mlp_train_state_and_step)\nfrom alpa.util import GradFuncTransformContext\n\n\ndef _aval_key(a):\n    return (a.shape, repr(a.dtype))\n\n\ndef _assert_avals_allmatch(aval_seq_a, aval_seq_b):\n    assert len(aval_seq_a) == len(\n        aval_seq_b), f\"{len(aval_seq_a)} != {len(aval_seq_b)}\"\n    aval_seq_a = sorted(aval_seq_a, key=_aval_key)\n    aval_seq_b = sorted(aval_seq_b, key=_aval_key)\n    for a, b in zip(aval_seq_a, aval_seq_b):\n        assert a.shape == b.shape and a.dtype == b.dtype\n\n\nclass StageConstructUtilTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\", num_nodes=1, num_devices_per_node=1)\n\n    def create_bert_layers(self, num_layers, num_microbatch):\n        batch_size = 16\n        state, batch, _ = get_bert_layer_train_state_and_step(\n            batch_size=batch_size,\n            seq_len=256,\n            num_layers=num_layers,\n            hidden_size=512,\n            num_heads=512 // 64,\n            clip_by_global_norm=False,\n            use_dynamic_scale=False,\n            
add_manual_pipeline_marker=True,\n        )\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"],\n                                     batch[\"attention_mask\"])\n                loss = jnp.mean((out - batch[\"y\"])**2)\n                return loss\n\n            grads = grad(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        microbatch_size = batch_size // num_microbatch\n        micro_batch = {k: v[:microbatch_size] for k, v in batch.items()}\n        return train_step, state, batch, micro_batch\n\n    def create_mlp(self, num_microbatch, add_marker=True):\n        batch_size = 16\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=batch_size,\n            hidden_size=512,\n            num_layers=4,\n            use_bias=False,\n            add_manual_pipeline_marker=add_marker)\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = grad(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        microbatch_size = batch_size // num_microbatch\n        micro_batch = {k: v[:microbatch_size] for k, v in batch.items()}\n        return train_step, state, batch, micro_batch\n\n    def get_train_step_jaxpr(self,\n                             train_step,\n                             state,\n                             batch,\n                             micro_batch,\n                             use_remat=False):\n        # Compile\n        with GradFuncTransformContext(ManualLayerOption(use_remat).transform):\n            closed_jaxpr, output_tree = make_jaxpr(train_step,\n                                                   return_shape=True)(\n  
                                                     state, micro_batch)\n            full_batch_closed_jaxpr, full_batch_output_tree = make_jaxpr(\n                train_step, return_shape=True)(state, batch)\n\n        num_params = len(closed_jaxpr.jaxpr.invars) - len(batch)\n        donated_invars = [True] * num_params + [False] * len(batch)\n        return closed_jaxpr, full_batch_closed_jaxpr, donated_invars\n\n    def pre_process_jaxpr(self, closed_jaxpr: ClosedJaxpr,\n                          full_batch_closed_jaxpr: ClosedJaxpr,\n                          num_microbatch: int, donated_invars: Sequence[bool]):\n        inference_mode = False\n        gensym_func = gensym([closed_jaxpr.jaxpr])\n        global_invars = closed_jaxpr.jaxpr.invars\n\n        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,\n         microbatch_bound, reduction_vector, post_microbatch_bound,\n         accumulator_mapping, acc_grad_invars,\n         acc_grad_outvars) = (split_and_process_layers(closed_jaxpr,\n                                                       full_batch_closed_jaxpr,\n                                                       num_microbatch,\n                                                       inference_mode,\n                                                       gensym_func))\n\n        (jax_apply_layers,\n         apply_grad_global_info) = slice_apply_grad_for_stage_construction(\n             jax_pipeline_layers, apply_grad_jaxpr, microbatch_bound,\n             global_invars, global_outvars, donated_invars, accumulator_mapping,\n             gensym_func, inference_mode)\n\n        return (closed_jaxpr, global_outvars, jax_pipeline_layers,\n                apply_grad_jaxpr, microbatch_bound, reduction_vector,\n                post_microbatch_bound, accumulator_mapping, acc_grad_invars,\n                acc_grad_outvars, jax_apply_layers, apply_grad_global_info)\n\n    def generate_profile_result(self, jax_pipeline_layers, 
accumulator_mapping,\n                                acc_grad_invars, acc_grad_outvars,\n                                jax_apply_layers, apply_grad_global_info,\n                                num_micro_batches, start_index, end_index):\n        virtual_mesh = get_global_virtual_physical_mesh()\n        submesh = (1, 1)\n        virtual_submesh = virtual_mesh.slice_2d(tuple(range(\n            submesh[0])), (tuple(range(submesh[1])),) * submesh[0])\n        auto_sharding_config = get_one_submesh_autosharding_config_choices(\n            virtual_submesh, \"same_as_physical\", batch_size=None)[0]\n\n        assert len(jax_pipeline_layers) % 2 == 0\n        num_layers = len(jax_pipeline_layers) // 2\n        indices = list(range(2 * num_layers))\n\n        forward_layer_indices = indices[start_index:end_index + 1]\n        backward_layer_indices = indices[2 * num_layers - end_index -\n                                         1:2 * num_layers - start_index]\n        selected_apply_grad_layers = [\n            jax_apply_layers[idx]\n            for idx in forward_layer_indices\n            if jax_apply_layers[idx] is not None\n        ]\n\n        stage_config = generate_stage_info(\n            jax_pipeline_layers,\n            [forward_layer_indices, backward_layer_indices],\n            accumulator_mapping, acc_grad_invars, acc_grad_outvars,\n            \"test_stage\", selected_apply_grad_layers, apply_grad_global_info)\n\n        stage_index = 0\n        stage = (stage_index, stage_config, auto_sharding_config)\n\n        profile_results = {}\n        default_as_option = AutoShardingOption(prefer_reduce_scatter=True)\n        auto_stage_option = AutoStageOption()\n\n        profile_results = distributed_profile_on_mesh(\n            [stage], [virtual_submesh], num_micro_batches, default_as_option,\n            auto_stage_option, profile_results)\n\n        return profile_results[stage_index]\n\n    def check_1d_2d_results_the_same(self, train_step, state, 
batch,\n                                     micro_batch, num_layers, num_microbatch):\n        (closed_jaxpr, full_batch_closed_jaxpr,\n         donated_invars) = self.get_train_step_jaxpr(train_step, state, batch,\n                                                     micro_batch)\n        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,\n         microbatch_bound, reduction_vector, post_microbatch_bound,\n         accumulator_mapping, acc_grad_invars, acc_grad_outvars,\n         jax_apply_layers, apply_grad_global_info) = self.pre_process_jaxpr(\n             closed_jaxpr, full_batch_closed_jaxpr, num_microbatch,\n             donated_invars)\n        # 2D\n        profile_results_2d = self.generate_profile_result(\n            jax_pipeline_layers, accumulator_mapping, acc_grad_invars,\n            acc_grad_outvars, jax_apply_layers, apply_grad_global_info,\n            num_microbatch, 0, num_layers - 1)\n\n        # 1D\n        profile_results_1d = []\n        for layer_idx in range(num_layers):\n            result = self.generate_profile_result(\n                jax_pipeline_layers, accumulator_mapping, acc_grad_invars,\n                acc_grad_outvars, jax_apply_layers, apply_grad_global_info,\n                num_microbatch, layer_idx, layer_idx)\n            profile_results_1d.append(result)\n\n        # Compare\n        (available_memory_2d, peak_memory_2d, initial_size_2d,\n         intermediate_size_2d,\n         max_stage_2d) = get_merged_stages_memory_stats([profile_results_2d])\n        (available_memory_1d, peak_memory_1d, initial_size_1d,\n         intermediate_size_1d,\n         max_stage_1d) = get_merged_stages_memory_stats(profile_results_1d)\n\n        assert available_memory_1d == available_memory_2d, (\n            f\"available_memory_1d: {available_memory_1d}, \"\n            f\"available_memory_2d: {available_memory_2d}\")\n        assert initial_size_1d == initial_size_2d, (\n            f\"initial_size_1d: 
{initial_size_1d}, \"\n            f\"initial_size_2d: {initial_size_2d}\")\n        assert intermediate_size_1d == intermediate_size_2d, (\n            f\"intermediate_size_1d: {intermediate_size_1d}, \"\n            f\"intermediate_size_2d: {intermediate_size_2d}\")\n        # Note: peak_memory_1d is not equal to peak_memory_2d because\n        # the greedy memory register allocation algorithm in XLA is not\n        # optimal, and may behave different in 1D and 2D cases.\n\n    def test_mlp_1d_2d_the_same(self):\n        num_microbatch = 2\n        num_layers = 2\n        (train_step, state, batch,\n         micro_batch) = self.create_mlp(num_microbatch)\n        self.check_1d_2d_results_the_same(train_step, state, batch, micro_batch,\n                                          num_layers, num_microbatch)\n\n    def test_bert_1d_2d_the_same(self):\n        num_microbatch = 2\n        num_layers = 3\n        (train_step, state, batch,\n         micro_batch) = self.create_bert_layers(num_layers, num_microbatch)\n        self.check_1d_2d_results_the_same(train_step, state, batch, micro_batch,\n                                          num_layers, num_microbatch)\n\n    def check_2d_real_the_same(self):\n        num_microbatch = 2\n        num_layers = 1\n        (train_step, state, batch,\n         micro_batch) = self.create_mlp(num_microbatch, add_marker=False)\n        (closed_jaxpr, full_batch_closed_jaxpr,\n         donated_invars) = self.get_train_step_jaxpr(train_step, state, batch,\n                                                     micro_batch)\n\n        (closed_jaxpr, global_outvars, jax_pipeline_layers, apply_grad_jaxpr,\n         microbatch_bound, reduction_vector, post_microbatch_bound,\n         accumulator_mapping, acc_grad_invars, acc_grad_outvars,\n         jax_apply_layers, apply_grad_global_info) = self.pre_process_jaxpr(\n             closed_jaxpr, full_batch_closed_jaxpr, num_microbatch,\n             donated_invars)\n        # 2D\n        
profile_results_2d = self.generate_profile_result(\n            jax_pipeline_layers, accumulator_mapping, acc_grad_invars,\n            acc_grad_outvars, jax_apply_layers, apply_grad_global_info,\n            num_microbatch, 0, num_layers - 1)\n        (available_memory_2d, peak_memory_2d, initial_size_2d,\n         intermediate_size_2d,\n         max_stage_2d) = get_merged_stages_memory_stats([profile_results_2d])\n\n        # Real\n        pipeshard_method = PipeshardParallel(\n            num_micro_batches=num_microbatch,\n            layer_option=\"manual\",\n            stage_option=\"uniform\",\n        )\n        parallelized_train_step = parallelize(\n            train_step,\n            donate_argnums=(0,),\n            method=pipeshard_method,\n        )\n        parallelized_train_step(state, batch)\n        peak_memory = (parallelized_train_step.get_executable(\n            state, batch).mesh_group.get_max_memory_allocated())\n        print(f\"2D peak_memory: {peak_memory_2d}\")\n        print(f\"Real peak_memory: {peak_memory}\")\n        # Note: real peak_memory is not equal to peak_memory_2d because\n        # of the same reason as above. In addition, our old profiling\n        # method is also not accurate compared to the real peak memory.\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(StageConstructUtilTest(\"test_mlp_1d_2d_the_same\"))\n    suite.addTest(StageConstructUtilTest(\"test_bert_1d_2d_the_same\"))\n    # suite.addTest(StageConstructUtilTest(\"check_2d_real_the_same\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/pipeline_parallel/test_tied_embedding.py",
    "content": "import unittest\nimport os\n\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nfrom alpa import (init, parallelize, mark_pipeline_boundary, grad,\n                  PipeshardParallel)\nfrom alpa.model.model_util import TrainState\nfrom alpa.testing import assert_allclose\n\n\nclass PipelineTiedEmbeddingTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def train_tied_embedding(self, method):\n        vocab_size = 256\n        hidden_size = 16\n        batch_size = 8\n        seq_len = 8\n\n        class Model(nn.Module):\n            \"\"\"Tied input and output embedding.\"\"\"\n\n            def setup(self):\n                self.embed = nn.Embed(vocab_size, hidden_size)\n\n            def __call__(self, x):\n                x = self.embed(x)\n                mark_pipeline_boundary()\n                embed = self.embed.variables[\"params\"][\"embedding\"]\n                x = x @ embed.T\n                return x\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                y_ = jax.nn.one_hot(batch[\"y\"], out.shape[-1])\n                loss = -jnp.sum(y_ * jax.nn.log_softmax(out, axis=-1),\n                                axis=-1).sum()\n                return loss\n\n            grads = grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        x = jnp.ones((batch_size, seq_len), jnp.int32)\n        y = jnp.ones((batch_size, seq_len), jnp.int32)\n\n        # Init model and optimizer\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.adam(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params,\n                                  tx=tx,\n                                  dynamic_scale=None)\n\n        # Run 
and check results\n        p_train_step = parallelize(train_step, method=method)\n        batch = {\"x\": x, \"y\": y}\n        expected_new_state = train_step(state, batch)\n        actual_new_state = p_train_step(state, batch)\n        assert_allclose(actual_new_state.params, expected_new_state.params)\n\n    def test_tied_embedding_pipeshard_parallel(self):\n        method = PipeshardParallel(num_micro_batches=2, layer_option=\"manual\")\n        self.train_tied_embedding(method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(\n        PipelineTiedEmbeddingTest(\"test_tied_embedding_pipeshard_parallel\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/run_all.py",
    "content": "\"\"\"Run all test cases.\nRun each file in a separate process to avoid GPU memory conflicts.\n\nUsages:\n# Run all files\npython3 run_all.py\n\n# Run files whose names contain \"pipeline\"\npython3 run_all.py --run-pattern pipeline\n\n# Run files whose names contain \"shard_parallel\"\npython3 run_all.py --run-pattern shard_parallel\n\n# Run files whose names do not contain \"torch\"\npython3 run_all.py --skip-pattern torch\n\"\"\"\n\nimport argparse\nimport glob\nimport multiprocessing\nimport os\nimport numpy as np\nimport time\nfrom typing import Sequence\nimport unittest\n\nslow_testcases = set([\n    \"pipeline_parallel/test_stage_construction_slow.py\",\n    \"torch_frontend/test_zhen.py\",\n])\n\n\ndef run_unittest_files(files, args):\n    \"\"\"Run unit test files one by one in separates processes.\"\"\"\n    os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = str(\n        args.xla_client_mem_fraction)\n    # Must import alpa after setting the global env\n    from alpa.util import run_with_timeout\n\n    for filename in files:\n        if args.run_pattern is not None and args.run_pattern not in filename:\n            continue\n        if args.skip_pattern is not None and args.skip_pattern in filename:\n            continue\n        if not args.enable_slow_tests and filename in slow_testcases:\n            continue\n        if args.run_tpu ^ (\"tpu\" in filename):\n            continue\n\n        def func():\n            ret = unittest.main(module=None, argv=[\"\", \"-vb\"] + [filename])\n\n        p = multiprocessing.Process(target=func)\n\n        def run_one_file():\n            p.start()\n            p.join()\n\n        try:\n            run_with_timeout(run_one_file, timeout=args.time_limit_per_file)\n            if p.exitcode != 0:\n                return False\n        except TimeoutError:\n            p.terminate()\n            time.sleep(5)\n            print(f\"\\nTimeout after {args.time_limit_per_file} seconds \"\n                  
f\"when running {filename}\")\n            return False\n\n    return True\n\n\nif __name__ == \"__main__\":\n    arg_parser = argparse.ArgumentParser()\n    arg_parser.add_argument(\n        \"--run-pattern\",\n        type=str,\n        default=None,\n        help=\"Run files whose names contain the provided string\")\n    arg_parser.add_argument(\n        \"--skip-pattern\",\n        type=str,\n        default=None,\n        help=\"Do not run files whose names contain the provided string\")\n    arg_parser.add_argument(\n        \"--enable-slow-tests\",\n        action=\"store_true\",\n        help=\"Run test cases including profiling, which takes a long time\")\n    arg_parser.add_argument(\n        \"--xla-client-mem-fraction\",\n        type=float,\n        default=0.25,\n        help=\"The fraction of GPU memory used to run unit tests\")\n    arg_parser.add_argument(\n        \"--time-limit-per-file\",\n        type=int,\n        default=1000,\n        help=\"The time limit for running one file in seconds.\")\n    arg_parser.add_argument(\"--order\",\n                            type=str,\n                            default=\"sorted\",\n                            choices=[\"sorted\", \"random\", \"reverse_sorted\"])\n    arg_parser.add_argument(\"--run-tpu\",\n                            action=\"store_true\",\n                            help=\"Whether to run tests for tpus.\")\n    args = arg_parser.parse_args()\n\n    files = glob.glob(\"**/test_*.py\", recursive=True)\n    if args.order == \"sorted\":\n        files.sort()\n    elif args.order == \"random\":\n        files = [files[i] for i in np.random.permutation(len(files))]\n    elif args.order == \"reverse_sorted\":\n        files.sort()\n        files = reversed(files)\n\n    tic = time.time()\n    success = run_unittest_files(files, args)\n\n    if success:\n        print(f\"Success. Time elapsed: {time.time() - tic:.2f}s\")\n    else:\n        print(f\"Fail. 
Time elapsed: {time.time() - tic:.2f}s\")\n\n    exit(0 if success else -1)\n"
  },
  {
    "path": "tests/runtime/test_create_state.py",
    "content": "\"\"\"Test distributed weight initialization.\"\"\"\nimport unittest\n\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport jax\nfrom jax.tree_util import tree_flatten\nfrom jax._src.api import make_jaxpr\nimport jax.numpy as jnp\nimport optax\n\nimport alpa\nfrom alpa import (init, shutdown, parallelize, ShardParallel, PipeshardParallel,\n                  CreateStateParallel)\n\n\nclass CreateStateTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def run_test(self, method):\n        use_bias = True\n        batch_size = 8\n        input_dim = output_dim = hidden_dim = 32\n\n        grad_fn = (jax.grad if isinstance(method, ShardParallel) and\n                   method.num_micro_batches is None else alpa.grad)\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                if isinstance(method, PipeshardParallel):\n                    alpa.mark_pipeline_boundary()\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)\n                return x\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = grad_fn(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        def create_state():\n            model = Model()\n            rngkey = jax.random.PRNGKey(0)\n            params = model.init(rngkey, jnp.ones((1, input_dim)))\n            tx = optax.adam(learning_rate=1e-2)\n            return 
TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        batch = {\n            \"x\": jnp.ones((batch_size, input_dim)),\n            \"y\": jnp.ones((batch_size, output_dim)),\n        }\n\n        train_step = parallelize(train_step, method=method)\n        create_state = parallelize(create_state,\n                                   method=CreateStateParallel(\n                                       train_step, batch))\n\n        state = create_state()\n        state = train_step(state, batch)\n\n        if isinstance(method, ShardParallel):\n            actual = tree_flatten(create_state.get_last_executable().\n                                  get_output_placement_specs())[0]\n            expected = tree_flatten(\n                train_step.get_last_executable().get_input_placement_specs()\n                [0])[0]\n            assert actual == expected\n        elif isinstance(method, PipeshardParallel):\n            # The assertion is already in CreateStateExecutable::launch_on_driver\n            # Here, we just call the function to test whether it is runnable.\n            train_step.get_last_executable().get_output_placement_specs()\n\n    def test_shard_parallel(self):\n        method = ShardParallel(num_micro_batches=None)\n        self.run_test(method)\n\n    def test_shard_parallel_grad_acc(self):\n        method = ShardParallel(num_micro_batches=2)\n        self.run_test(method)\n\n    def test_pipeshard_parallel(self):\n        method = PipeshardParallel(num_micro_batches=2, layer_option=\"manual\")\n        self.run_test(method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(CreateStateTest(\"test_shard_parallel\"))\n    suite.addTest(CreateStateTest(\"test_shard_parallel_grad_acc\"))\n    suite.addTest(CreateStateTest(\"test_pipeshard_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_cross_mesh_communicator.py",
    "content": "import unittest\n\nimport ray\nfrom alpa import init\nfrom alpa.device_mesh import (\n    create_and_record_cross_mesh_collective_communicators, get_global_cluster)\nfrom alpa.pipeline_parallel.stage_construction import get_sliced_virtual_submeshes\nfrom alpa.util import mesh_ids_hash\n\n\nclass CrossMeshCollectiveCommunicatorTest(unittest.TestCase):\n\n    def setUp(self) -> None:\n        init(\"ray\")\n\n    def test_create_and_set(self):\n        virtual_mesh = get_global_cluster().get_virtual_physical_mesh(\n            host_ids=[0], num_devices_per_host=4)\n        submesh_shapes = [(1, 2)] * 2\n        sliced_virtual_meshes = get_sliced_virtual_submeshes(\n            virtual_mesh, submesh_shapes)\n        virtual_mesh.get_physical_mesh_group(sliced_virtual_meshes)\n        mesh_group = virtual_mesh.launched_physical_mesh_group\n        meshes = mesh_group.meshes\n        key = mesh_ids_hash([0, 1])\n        ray.get(\n            create_and_record_cross_mesh_collective_communicators(meshes, key))\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(CrossMeshCollectiveCommunicatorTest(\"test_create_and_set\"))\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_data_loader.py",
    "content": "\"\"\"Test distributed mesh data loader.\"\"\"\nimport os\nimport unittest\n\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nfrom jax.interpreters import pxla\nimport numpy as np\n\nfrom alpa import init, MeshDriverDataLoader\nfrom alpa.parallel_plan import PlacementSpec\nfrom alpa.device_mesh import get_global_physical_mesh\nfrom alpa.testing import assert_allclose\nfrom alpa.testing import data_loader_input_iter_func as input_iter_func\n\n\nclass DataLoaderTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n        self.physical_mesh = get_global_physical_mesh(create_if_not_exist=True)\n\n    def run_test(self, sharding_specs):\n        batch_size = 64\n        num_samples = 256\n        feature_dim = 32\n        avals = [\n            jax.core.ShapedArray((batch_size, feature_dim), jnp.float32),\n            jax.core.ShapedArray((batch_size,), jnp.int32)\n        ]\n        placement_specs = [\n            PlacementSpec(aval, (self.physical_mesh.mesh_id,), (sharding_spec,))\n            for aval, sharding_spec in zip(avals, sharding_specs)\n        ]\n        prefetch_size = 2\n\n        data_loader = MeshDriverDataLoader(batch_size, num_samples,\n                                           input_iter_func, placement_specs,\n                                           prefetch_size)\n        expected_data_loader = input_iter_func(0, num_samples, batch_size)\n\n        actual_x = []\n        actual_y = []\n        expected_x = []\n        expected_y = []\n        for actual_batch, expected_batch in zip(data_loader,\n                                                expected_data_loader):\n            actual_x.append(np.array(actual_batch[0]))\n            actual_y.append(np.array(actual_batch[1]))\n            expected_x.append(np.array(expected_batch[0]))\n            expected_y.append(np.array(expected_batch[1]))\n\n        actual_x = np.concatenate(actual_x)\n        actual_y = 
np.concatenate(actual_y)\n        expected_x = np.concatenate(expected_x)\n        expected_y = np.concatenate(expected_y)\n\n        # Check that actual_x is a permutation of expected_x.\n        for i in range(feature_dim):\n            assert np.sum(actual_x[:, i]) == np.sum(expected_x[:, i])\n        # Check that actual_y is a permutation of expected_y.\n        assert np.sum(actual_y) == np.sum(expected_y)\n\n    def test_data_parallel(self):\n        num_devices = self.physical_mesh.num_devices\n\n        sharding_specs = [\n            pxla.ShardingSpec((pxla.Chunked((num_devices,)), pxla.NoSharding()),\n                              (pxla.ShardedAxis(0),)),\n            pxla.ShardingSpec((pxla.Chunked((num_devices,)),),\n                              (pxla.ShardedAxis(0),))\n        ]\n        self.run_test(sharding_specs)\n\n    def test_model_parallel(self):\n        num_devices = self.physical_mesh.num_devices\n\n        sharding_specs = [\n            pxla.ShardingSpec((pxla.NoSharding(), pxla.Chunked((num_devices,))),\n                              (pxla.ShardedAxis(0),)),\n            pxla.ShardingSpec((pxla.NoSharding(),),\n                              (pxla.Replicated(num_devices),))\n        ]\n        self.run_test(sharding_specs)\n\n    def test_data_model_parallel(self):\n        dp = 2\n        mp = self.physical_mesh.num_devices // dp\n        sharding_specs = [\n            pxla.ShardingSpec((pxla.Chunked((dp,)), pxla.Chunked((mp,))),\n                              (pxla.ShardedAxis(0), pxla.ShardedAxis(1))),\n            pxla.ShardingSpec((pxla.Chunked((dp,)),), (\n                pxla.ShardedAxis(0),\n                pxla.Replicated(mp),\n            ))\n        ]\n        self.run_test(sharding_specs)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(DataLoaderTest(\"test_data_parallel\"))\n    suite.addTest(DataLoaderTest(\"test_model_parallel\"))\n    suite.addTest(DataLoaderTest(\"test_data_model_parallel\"))\n\n   
 return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_debug_info.py",
    "content": "\"\"\"Test the debug information dummping.\"\"\"\nimport os\nimport unittest\n\nfrom alpa import (init, parallelize, ShardParallel, PipeshardParallel,\n                  AutoLayerOption, global_config)\nfrom alpa.pipeline_parallel.stage_construction import get_last_dp_result\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.testing import assert_allclose, get_mlp_train_state_and_step\n\n\nclass DebugInfoTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n    def test_1_debug_shard_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=4)\n\n        # Print auto-sharding intermidiate results\n        os.environ[\"ALPA_DEBUG_PRINT_AS_STRATEGY\"] = \"1\"\n\n        p_train_step = parallelize(train_step,\n                                   method=ShardParallel(num_micro_batches=2))\n        actual_output = p_train_step(state, batch)\n        executable = p_train_step.get_last_executable()\n        executable.sync()\n\n        # Dump final HLO and other debug info\n        executable.dump_debug_info(\"alpa_debug_info\")\n\n    def test_2_debug_pipeline_parallel(self):\n        init(cluster=\"ray\")\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=6)\n\n        # Print auto-sharding intermidiate results\n        global_config.pipeline_distributed_compile = False\n        os.environ[\"ALPA_DEBUG_PRINT_AS_STRATEGY\"] = \"1\"\n\n        layer_num = min(get_global_cluster().num_devices, 2)\n        p_train_step = parallelize(\n            train_step,\n            method=PipeshardParallel(\n           
     num_micro_batches=2,\n                layer_option=AutoLayerOption(layer_num=layer_num)))\n        actual_output = p_train_step(state, batch)\n        executable = p_train_step.get_last_executable()\n        executable.sync()\n\n        # Dump final HLO and other debug info\n        executable.dump_debug_info(\"alpa_debug_info\")\n\n        # Print auto-stage dynamic programming results if use auto stage partition\n        print(get_last_dp_result())\n\n\ndef suite():\n    s = unittest.TestSuite()\n    s.addTest(DebugInfoTest(\"test_1_debug_shard_parallel\"))\n    s.addTest(DebugInfoTest(\"test_2_debug_pipeline_parallel\"))\n    return s\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_device_mesh.py",
    "content": "\"\"\"Test distributed mulit-host device mesh.\"\"\"\n\nimport os\nimport unittest\n\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nfrom jax.interpreters import pxla\nimport numpy as np\nimport ray\n\nfrom alpa import init, shutdown, parallelize, DistributedArray\nfrom alpa.device_mesh import get_global_physical_mesh\nfrom alpa.testing import assert_allclose\n\n\nclass DeviceMeshTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def test_add_one(self):\n\n        @parallelize\n        def add_one(x):\n            return x + 1\n\n        @parallelize\n        def multiply_two(x):\n            return x * 2\n\n        # Run computation\n        a = jnp.ones((512, 512))\n        out = add_one(a)\n        out = multiply_two(out)\n\n        # Check results\n        assert_allclose(np.array(out), (np.ones_like(a) + 1) * 2)\n\n    def test_distributed_array(self):\n        physical_mesh = get_global_physical_mesh(create_if_not_exist=True)\n        logical_mesh = physical_mesh.get_logical_mesh()\n\n        array = jnp.arange(64).reshape([8, 8])\n        sharding_spec = logical_mesh.make_tile_spec(array, [0, 1], [0, 1])\n        indices = sharding_spec.indices(array.shape).flatten()\n        dis_a = physical_mesh.shard_args_to_arrays([array.aval], [indices],\n                                                   [sharding_spec], [array])[0]\n\n        assert_allclose(array, dis_a)\n\n    def test_preshard_args(self):\n\n        @parallelize\n        def add_one(x):\n            return x + 1\n\n        a = jnp.ones((64, 64))\n        a, = add_one.preshard_dynamic_args(a)\n        assert isinstance(a, DistributedArray)\n\n\nclass DeviceMesh_ResourceAwareness(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\", num_nodes=1, num_devices_per_node=2)\n\n    def tearDown(self):\n        shutdown()\n\n    
@unittest.skipIf(jax.local_device_count(\"gpu\") < 4, \"no enough device\")\n    def test_resource_check(self):\n        cluster_devices = ray.cluster_resources().get(\"GPU\", 0)\n        available_devices = ray.available_resources().get(\"GPU\", 0)\n        print(cluster_devices, available_devices, ray.cluster_resources(),\n              ray.available_resources())\n        assert available_devices + 2 == cluster_devices\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(DeviceMeshTest(\"test_add_one\"))\n    suite.addTest(DeviceMeshTest(\"test_distributed_array\"))\n    suite.addTest(DeviceMeshTest(\"test_preshard_args\"))\n    suite.addTest(DeviceMeshTest(\"test_preshard_args\"))\n    suite.addTest(DeviceMesh_ResourceAwareness(\"test_resource_check\"))\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_dist_save_load.py",
    "content": "\"\"\"Test distributed save and load.\"\"\"\n\nimport subprocess\nimport tempfile\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nfrom alpa import (init, shutdown, parallelize, DistributedArray,\n                  PipeshardParallel, save_checkpoint, restore_checkpoint)\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.testing import (get_mlp_train_state_and_step,\n                          get_bert_layer_train_state_and_step, assert_allclose)\n\n\nclass DistSaveLoadTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def check_dist_array_eq(self, x, y):\n        if isinstance(x, DistributedArray):\n            x = np.array(\n                x.device_mesh.get_remote_buffers(x.remote_ref, batching=True))\n        if isinstance(y, DistributedArray):\n            y = np.array(\n                y.device_mesh.get_remote_buffers(y.remote_ref, batching=True))\n        assert_allclose(x, y)\n\n    def _get_efs_mount_point(self):\n        # Hacky function to get the EFS mount point\n        for line in subprocess.check_output(\"df -h\",\n                                            shell=True).decode().split('\\n'):\n            cols = line.split(' ')\n            if \"efs\" in cols[0]:\n                return cols[-1] + \"/\"\n        return None\n\n    def _get_save_prefix(self):\n        device_cluster = get_global_cluster()\n        if len(device_cluster.host_info) > 1:\n            # Get EFS mount point for the multi-host test\n            save_prefix = self._get_efs_mount_point()\n            if save_prefix is None:\n                self.skipTest(\"The multi-host test requires a mounted EFS! 
\")\n        else:\n            # Use tmp dir for the single-host test\n            save_prefix = \"/tmp/\"\n        return save_prefix\n\n    def test_distributed_array_save_load(self):\n        device_cluster = get_global_cluster()\n        save_prefix = self._get_save_prefix()\n\n        # Launch a device mesh contains four devices\n        if device_cluster.num_devices < 4:\n            self.skipTest(\n                \"This unit test requires a cluster with at least 4 devices! \")\n        host_num = min(len(device_cluster.host_info), 4)\n        device_per_host = 4 // host_num\n        physical_mesh = device_cluster.get_physical_mesh(\n            list(range(host_num)), device_per_host)\n        logical_mesh = physical_mesh.get_logical_mesh([2, 2])\n\n        global_input_shape = (4, 2)\n        num = np.prod(np.array(global_input_shape))\n\n        # Build DistributedArray to be saved\n        # [[0,1],          [[0],  [[1],\n        #  [2,3],  shard    [2]]   [3]]\n        #  [4,5],  ====>   [[4],  [[5],\n        #  [6,7]]           [6]]   [7]]\n        global_input_data1 = jnp.arange(num).reshape(global_input_shape)\n        input_sharding_spec = logical_mesh.make_tile_spec(\n            global_input_data1, [0, 1], [0, 1])\n        input_indices = input_sharding_spec.indices(\n            global_input_data1.shape).flatten()\n        (dist_input_data1,) = physical_mesh.shard_args_to_arrays(\n            (jax.ShapedArray(global_input_data1.shape, jnp.int32),),\n            (input_indices,), (input_sharding_spec,), (global_input_data1,))\n\n        # Check the DistributedArray's remote buffers\n        desired_buffers1 = np.array([[[0], [2]], [[1], [3]], [[4], [6]],\n                                     [[5], [7]]])\n        self.check_dist_array_eq(desired_buffers1, dist_input_data1)\n\n        # cached save/load\n        with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir:\n            with tempfile.TemporaryDirectory(prefix=\"/tmp/\") as 
cache_dir:\n                # Save the DistributedArray (one replica only)\n                dist_input_data1.save(ckpt_dir, cache_dir)\n\n                # Sync all the move workers\n                physical_mesh.sync_move_workers()\n\n                # Load previously saved DistributedArray with a different shardingSpec\n                # [[0,1],          [[0,1],  [[0,1],\n                #  [2,3],  shard    [2,3]]   [2,3]]\n                #  [4,5],  ====>   [[4,5],  [[4,5],\n                #  [6,7]]           [6,7]]   [6,7]]\n                load_sharding_spec = logical_mesh.make_tile_spec(\n                    global_input_data1, [0, 1], [0])\n                dist_load_data1 = DistributedArray.load(\n                    ckpt_dir,\n                    jax.ShapedArray(global_input_data1.shape, jnp.int32),\n                    physical_mesh, load_sharding_spec)\n\n                # Check the DistributedArray's remote buffers\n                desired_buffers2 = np.array([[[0, 1], [2, 3]], [[0, 1], [2, 3]],\n                                             [[4, 5], [6, 7]], [[4, 5], [6,\n                                                                         7]]])\n                self.check_dist_array_eq(desired_buffers2, dist_load_data1)\n\n        # Cleanup\n        physical_mesh.shutdown()\n\n    def test_jax_mlp_save_dist_load(self):\n        save_prefix = self._get_save_prefix()\n\n        # Init model\n        jax_state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=64,\n            hidden_size=16,\n            num_layers=4,\n            add_manual_pipeline_marker=True)\n\n        with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir:\n            # save normal jax model using tensorstore for distributed loading\n            save_checkpoint(ckpt_dir, jax_state, 1)\n\n            # Compile\n            method = PipeshardParallel(num_micro_batches=2,\n                                       layer_option=\"manual\")\n          
  serial_train_step = train_step\n            parallel_train_step = parallelize(train_step, method=method)\n            executable = parallel_train_step.get_executable(jax_state, batch)\n\n            # Restore checkpoint\n            state_ps, _ = executable.get_input_placement_specs()\n            load_state = restore_checkpoint(ckpt_dir, 1, state_ps)\n\n            # Run after load\n            serial_state = serial_train_step(jax_state, batch)[0]\n            load_state = parallel_train_step(load_state, batch)[0]\n\n            # Check results\n            assert_allclose(serial_state.params, load_state.params, 1e-3, 1e-3)\n\n    def test_distributed_mlp_uncached_save_load(self):\n        save_prefix = self._get_save_prefix()\n\n        # Init model\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=128,\n            hidden_size=16,\n            num_layers=4,\n            add_manual_pipeline_marker=True)\n\n        # Compile\n        method = PipeshardParallel(num_micro_batches=1, layer_option=\"manual\")\n        serial_train_step = train_step\n        parallel_train_step = parallelize(train_step, method=method)\n        executable = parallel_train_step.get_executable(state, batch)\n\n        # Run before save\n        serial_state = state\n        parallel_state = state\n        serial_state = serial_train_step(serial_state, batch)[0]\n        parallel_state = parallel_train_step(parallel_state, batch)[0]\n        assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3)\n\n        # uncached save/load\n        with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir:\n            # Save checkpoint\n            save_checkpoint(ckpt_dir, parallel_state, 1)\n\n            # Restore checkpoint\n            state_ps, _ = executable.get_input_placement_specs()\n            load_state = restore_checkpoint(ckpt_dir, 1, state_ps)\n\n            # Run after load\n            serial_state = 
serial_train_step(serial_state, batch)[0]\n            load_state = parallel_train_step(load_state, batch)[0]\n\n            # Check results\n            assert_allclose(serial_state.params, load_state.params, 1e-3, 1e-3)\n\n    def test_distributed_bert_cached_save_load(self):\n        save_prefix = self._get_save_prefix()\n\n        # Init model\n        state, batch, train_step = get_bert_layer_train_state_and_step(\n            batch_size=16,\n            seq_len=8,\n            num_layers=4,\n            hidden_size=128,\n            num_heads=8,\n            clip_by_global_norm=False,\n            use_dynamic_scale=False,\n            add_manual_pipeline_marker=True)\n\n        # Compile\n        method = PipeshardParallel(num_micro_batches=2, layer_option=\"manual\")\n        serial_train_step = train_step\n        parallel_train_step = parallelize(train_step, method=method)\n        executable = parallel_train_step.get_executable(state, batch)\n\n        # Run before save\n        serial_state = state\n        parallel_state = state\n        serial_state = serial_train_step(serial_state, batch)[0]\n        parallel_state = parallel_train_step(parallel_state, batch)[0]\n        assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3)\n\n        # cached save/load\n        with tempfile.TemporaryDirectory(prefix=save_prefix) as ckpt_dir:\n            with tempfile.TemporaryDirectory(prefix=\"/tmp/\") as cache_dir:\n                # Save checkpoint\n                save_checkpoint(ckpt_dir, parallel_state, 1, cache_dir)\n\n                # Sync all the move workers\n                executable.sync_move_workers()\n\n                # Restore checkpoint\n                state_ps, _ = executable.get_input_placement_specs()\n                load_state = restore_checkpoint(ckpt_dir, 1, state_ps)\n\n                # Run after load\n                serial_state = serial_train_step(serial_state, batch)[0]\n                load_state = 
parallel_train_step(load_state, batch)[0]\n\n                # Check results\n                assert_allclose(serial_state.params, load_state.params, 1e-3,\n                                1e-3)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(DistSaveLoadTest(\"test_distributed_array_save_load\"))\n    suite.addTest(DistSaveLoadTest(\"test_jax_mlp_save_dist_load\"))\n    suite.addTest(DistSaveLoadTest(\"test_distributed_mlp_uncached_save_load\"))\n    suite.addTest(DistSaveLoadTest(\"test_distributed_bert_cached_save_load\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_follow_parallel.py",
    "content": "\"\"\"Test following another parallel strategy.\"\"\"\nimport unittest\n\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport jax\nimport jax.numpy as jnp\nimport optax\n\nimport alpa\nfrom alpa import init, shutdown, parallelize, ShardParallel, PipeshardParallel\n\n\nclass FollowParallelTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def run_test(self, method):\n        use_bias = True\n        batch_size = 32\n        input_dim = output_dim = hidden_dim = 8\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)\n                return x\n\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = grad_fn(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        def eval_step(params, batch):\n            out = state.apply_fn(params, batch[\"x\"])\n            return jnp.mean((out - batch[\"y\"])**2)\n\n        def create_state():\n            model = Model()\n            rngkey = jax.random.PRNGKey(0)\n            params = model.init(rngkey, jnp.ones((1, input_dim)))\n            tx = optax.adam(learning_rate=1e-2)\n            return TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        train_batch = {\n            \"x\": jnp.ones((batch_size, input_dim)),\n            \"y\": jnp.ones((batch_size, output_dim)),\n        }\n        eval_batch = {\n            
\"x\": jnp.ones((batch_size * 2, input_dim)),\n            \"y\": jnp.ones((batch_size * 2, output_dim)),\n        }\n\n        grad_fn = jax.grad if method.num_micro_batches is None else alpa.grad\n        num_micro_batches = method.num_micro_batches\n\n        state = create_state()\n\n        train_step = parallelize(train_step, method=method)\n        eval_step = parallelize(eval_step,\n                                method=alpa.FollowParallel(\n                                    train_step,\n                                    num_micro_batches=num_micro_batches))\n\n        state = train_step(state, train_batch)\n        out = eval_step(state.params, eval_batch)\n\n        actual = jax.tree_flatten(\n            eval_step.get_last_executable().get_input_placement_specs()[0])[0]\n        expected = jax.tree_flatten(\n            train_step.get_last_executable().get_input_placement_specs()\n            [0].params)[0]\n        assert actual == expected\n\n    def test_shard_parallel(self):\n        method = ShardParallel(num_micro_batches=None)\n        self.run_test(method)\n\n    def test_shard_parallel_grad_acc(self):\n        method = ShardParallel(num_micro_batches=2)\n        self.run_test(method)\n\n    def test_pipeshard_parallel(self):\n        method = PipeshardParallel(\n            num_micro_batches=2, layer_option=alpa.AutoLayerOption(layer_num=2))\n        self.run_test(method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(FollowParallelTest(\"test_shard_parallel\"))\n    suite.addTest(FollowParallelTest(\"test_shard_parallel_grad_acc\"))\n    suite.addTest(FollowParallelTest(\"test_pipeshard_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_install.py",
    "content": "import unittest\n\nfrom alpa.test_install import suite\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_memory_leak.py",
    "content": "\"\"\"Test whether there is any memory leak for distributed arrays and remote buffers.\"\"\"\nimport unittest\n\nimport ray\n\nfrom alpa import (init, shutdown, parallelize, global_config, ShardParallel,\n                  PipeshardParallel)\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.test_install import get_mlp_train_state_and_step\n\n\nclass MemoryLeakTest(unittest.TestCase):\n\n    def setUp(self):\n        init()\n        global_config.delete_remote_arrays_threshold = 0\n\n    def tearDown(self):\n        shutdown()\n\n    def test_shard_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128)\n        train_step = parallelize(train_step,\n                                 method=ShardParallel(num_micro_batches=2))\n\n        for i in range(2):\n            state, loss = train_step(state, batch)\n            del loss\n        del state\n\n        # Assert all buffers are freed\n        executable = train_step.get_last_executable()\n        for w in executable.physical_mesh.workers:\n            # One loss array cannot be deleted due to python's GC behavior\n            assert len(ray.get(w.get_live_buffer_uuids.remote())) <= 1\n\n    def test_pipeline_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=128, hidden_size=128, add_manual_pipeline_marker=True)\n\n        layer_num = min(get_global_cluster().num_devices, 2)\n        train_step = parallelize(\n            train_step,\n            method=PipeshardParallel(num_micro_batches=2,\n                                     layer_option=\"manual\"))\n\n        for i in range(2):\n            state, loss = train_step(state, batch)\n            del loss\n        del state\n\n        # Assert all buffers are freed\n        executable = train_step.get_last_executable()\n        for mesh in 
executable.mesh_group:\n            for w in mesh.workers:\n                assert len(ray.get(w.get_live_buffer_uuids.remote())) == 0\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(MemoryLeakTest(\"test_shard_parallel\"))\n    suite.addTest(MemoryLeakTest(\"test_pipeline_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_parallel_plan.py",
    "content": "\"\"\"Some basic tests to test installation.\"\"\"\nimport os\nimport pickle\nimport unittest\n\nfrom alpa import (init, shutdown, parallelize, ShardParallel, PipeshardParallel,\n                  AutoLayerOption, plan_to_method, AutoShardingOption,\n                  AutoStageOption)\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.testing import assert_allclose, get_mlp_train_state_and_step\n\n\nclass ParallelPlanTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def tearDown(self):\n        shutdown()\n\n    def test_shard_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=4)\n\n        method = ShardParallel(\n            num_micro_batches=2,\n            auto_sharding_option=AutoShardingOption(force_data_parallel=True))\n        p_train_step = parallelize(train_step, method=method)\n\n        executable1 = p_train_step.get_executable(state, batch)\n        plan = executable1.get_parallel_plan()\n\n        with open(\"tmp_plan.pkl\", \"wb\") as fout:\n            pickle.dump(plan, fout)\n        with open(\"tmp_plan.pkl\", \"rb\") as fin:\n            plan = pickle.load(fin)\n\n        p_train_step = parallelize(train_step, method=plan_to_method(plan))\n        executable2 = p_train_step.get_executable(state, batch)\n\n        assert (executable1.auto_sharding_objective ==\n                executable2.auto_sharding_objective)\n\n    def test_pipeshard_parallel(self):\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=128,\n                                                                hidden_size=128,\n                                                                num_layers=4)\n\n        method = PipeshardParallel(num_micro_batches=2,\n                                   
layer_option=AutoLayerOption(layer_num=2),\n                                   stage_option=\"uniform\")\n        p_train_step = parallelize(train_step, method=method)\n\n        executable1 = p_train_step.get_executable(state, batch)\n        plan = executable1.get_parallel_plan()\n\n        with open(\"tmp_plan.pkl\", \"wb\") as fout:\n            pickle.dump(plan, fout)\n        with open(\"tmp_plan.pkl\", \"rb\") as fin:\n            plan = pickle.load(fin)\n\n        p_train_step = parallelize(train_step, method=plan_to_method(plan))\n        executable2 = p_train_step.get_executable(state, batch)\n\n        assert (executable1.get_input_placement_specs() ==\n                executable2.get_input_placement_specs())\n\n\ndef suite():\n    s = unittest.TestSuite()\n    s.addTest(ParallelPlanTest(\"test_shard_parallel\"))\n    s.addTest(ParallelPlanTest(\"test_pipeshard_parallel\"))\n    return s\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_random_seed.py",
    "content": "\"\"\"Test random seed.\"\"\"\nimport unittest\nimport os\n\nimport jax\nfrom jax._src.tree_util import tree_flatten, tree_unflatten\nimport jax.numpy as jnp\nimport numpy as np\n\nfrom alpa import (init, grad, parallelize, ShardParallel, set_seed, shutdown,\n                  AutoShardingOption)\nfrom alpa.parallel_method import PipeshardParallel\nfrom alpa.pipeline_parallel.layer_construction import ManualLayerOption\nfrom alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary\nfrom alpa.testing import assert_allclose\n\n\nclass RandomSeedTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n\n    def test_random_generation(self):\n\n        @parallelize(method=ShardParallel())\n        def func():\n            rngkey = jax.random.PRNGKey(0)\n            x = jax.random.normal(rngkey, (16, 4))\n            y = jax.random.normal(rngkey, (16, 4))\n            z = jnp.hstack((x, y))\n            z = (10000 * z).astype(jnp.int32)\n            return z.flatten()\n\n        a = func()\n        s = set(np.array(a))\n\n        # Check all random numbers are unique\n        assert len(a) == len(s)\n\n    def test_set_seed(self):\n\n        @parallelize(method=ShardParallel())\n        def func():\n            rngkey = jax.random.PRNGKey(0)\n            return jax.random.normal(rngkey, (16, 4))\n\n        @parallelize(method=ShardParallel())\n        def func2():\n            rngkey = jax.random.PRNGKey(0)\n            return jax.random.normal(rngkey, (16, 4))\n\n        set_seed(10)\n        a = func()\n        b = func()\n        set_seed(10)\n        c = func()\n        set_seed(10)\n        d = func2()\n\n        assert_allclose(a, c)\n        assert_allclose(c, d)\n\n        allclose = True\n        try:\n            assert_allclose(a, b)\n        except AssertionError:\n            allclose = False\n        assert not allclose\n\n    @unittest.skip(\n        \"The support of 
remat + random seed is broken after a rebase.\")\n    def test_remat_rng(self):\n        init(cluster=\"ray\")\n\n        batch_size = 64\n        hidden_size = 8\n        num_micro_batches = 1\n        rngkey = jax.random.PRNGKey(0)\n        x = jax.random.normal(rngkey, (batch_size, hidden_size))\n        params = {\n            \"x1\": jax.random.normal(rngkey, (hidden_size, hidden_size)),\n            \"x2\": jax.random.normal(rngkey, (hidden_size, hidden_size)),\n        }\n\n        # Run an inference-only forward pass to get rngs\n        def gen_rns(params, x, key):\n            # NOTE: We mimic the real forward pass to make sure\n            # the sharding specs are the same. Otherwise, the results of rng\n            # do not match.\n            y = x @ params[\"x1\"]\n            rns = jax.random.normal(key, y.shape)\n            y = jax.lax.select(rns > 0, y, jnp.zeros_like(y))\n            mark_pipeline_boundary()\n            y = y @ params[\"x2\"]\n            return rns\n\n        set_seed(10)\n        method = PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            pipeline_schedule=\"inference\",\n            layer_option=\"manual\",\n            default_auto_sharding_option=AutoShardingOption(\n                force_data_parallel=True))\n        p_gen_rns = parallelize(gen_rns, method=method)\n        external_rns = np.array(p_gen_rns(params, x, rngkey))\n\n        # Run train step with remat and rng\n        def train_step(params, x, key, use_external_rns, external_rns):\n\n            def loss_func(params):\n                y = x @ params[\"x1\"]\n                if use_external_rns:\n                    rns = external_rns\n                else:\n                    rns = jax.random.normal(key, y.shape)\n                y = jax.lax.select(rns > 0, y, jnp.zeros_like(y))\n                mark_pipeline_boundary()\n                y = y @ params[\"x2\"]\n                return jnp.mean(y), rns\n\n            grads, rns = 
grad(loss_func, has_aux=True)(params)\n            # A workaround to make apply_grad non-empty, otherwise it hits a bug\n            # (https://github.com/alpa-projects/alpa/issues/560).\n            grads = jax.tree_map(lambda x: x + 1, grads)\n            return grads, rns\n\n        set_seed(10)\n        method = PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            layer_option=ManualLayerOption(remat_layer=True),\n            default_auto_sharding_option=AutoShardingOption(\n                force_data_parallel=True))\n        p_train_step = parallelize(train_step,\n                                   method=method,\n                                   static_argnums=(3,))\n\n        grads_actual, rns_actual = p_train_step(params, x, rngkey, False,\n                                                external_rns)\n        grads_expected, rns_expected = train_step(params, x, rngkey, True,\n                                                  external_rns)\n\n        assert_allclose(external_rns, rns_actual)\n        assert_allclose(external_rns, rns_expected)\n        assert_allclose(grads_actual, grads_expected)\n        shutdown()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(RandomSeedTest(\"test_random_generation\"))\n    suite.addTest(RandomSeedTest(\"test_set_seed\"))\n    suite.addTest(RandomSeedTest(\"test_remat_rng\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_save_load.py",
    "content": "import unittest\nimport time\nfrom tempfile import TemporaryFile\n\nimport ray\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pickle\nimport flax\n\nfrom alpa import init, parallelize, PipeshardParallel, util\nfrom alpa.testing import get_mlp_train_state_and_step, assert_allclose\n\n\nclass SaveLoadTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    def test_mlp_state_load(self):\n        # Init model\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=128, hidden_size=128, add_manual_pipeline_marker=True)\n\n        # Compile\n        method = PipeshardParallel(num_micro_batches=2, layer_option=\"manual\")\n        serial_train_step = train_step\n        parallel_train_step = parallelize(train_step, method=method)\n        executable = parallel_train_step.get_executable(state, batch)\n\n        serial_state = state\n        parallel_state = state\n        serial_state = serial_train_step(serial_state, batch)[0]\n        parallel_state = parallel_train_step(parallel_state, batch)[0]\n        assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3)\n\n        # Save model to a temporary file\n        outfile = TemporaryFile()\n        parallel_state_dict = flax.serialization.to_state_dict(parallel_state)\n        pickle.dump(util.map_to_nparray(parallel_state_dict), outfile)\n\n        # Load model from the temporary file\n        outfile.seek(0)\n        loaded_state_dict = pickle.load(outfile)\n        loaded_state = flax.serialization.from_state_dict(\n            state, loaded_state_dict)\n        outfile.close()\n\n        # Compare the loaded state with the original state\n        assert_allclose(loaded_state.params, serial_state.params, 1e-3, 1e-3)\n        assert_allclose(loaded_state.params, parallel_state.params, 1e-3, 1e-3)\n\n        # Take a step with the loaded state on both serial and parallel version\n        serial_state = 
serial_train_step(serial_state, batch)[0]\n        parallel_state = parallel_train_step(parallel_state, batch)[0]\n        serial_loaded_state = serial_train_step(loaded_state, batch)[0]\n        parallel_loaded_state = parallel_train_step(loaded_state, batch)[0]\n\n        # All the states should be the same\n        assert_allclose(serial_state.params, parallel_state.params, 1e-3, 1e-3)\n        assert_allclose(serial_state.params, serial_loaded_state.params, 1e-3,\n                        1e-3)\n        assert_allclose(serial_state.params, parallel_loaded_state.params, 1e-3,\n                        1e-3)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(SaveLoadTest('test_mlp_state_load'))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_tracing.py",
    "content": "\"\"\"Test activity tracing.\"\"\"\nimport unittest\n\nfrom alpa import (init, shutdown, parallelize, global_config, PipeshardParallel)\nfrom alpa.global_env import global_config\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.test_install import get_mlp_train_state_and_step\n\n\nclass TracingTest(unittest.TestCase):\n\n    def setUp(self):\n        global_config.collect_trace = True\n        init()\n\n    def tearDown(self):\n        shutdown()\n\n    def test_trace_pipeshard_execuable(self):\n        state, batch, train_step = get_mlp_train_state_and_step(\n            batch_size=128, hidden_size=128, add_manual_pipeline_marker=True)\n\n        layer_num = min(get_global_cluster().num_devices, 2)\n        train_step = parallelize(\n            train_step,\n            method=PipeshardParallel(num_micro_batches=2,\n                                     layer_option=\"manual\"))\n\n        for i in range(2):\n            state, _ = train_step(state, batch)\n\n        executable = train_step.get_last_executable()\n        stage_exec_info = executable.get_stage_execution_info()\n\n        assert len(stage_exec_info) == 6  # 6 stages\n        assert len(stage_exec_info[0]) == 4  # 4 invocations\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(TracingTest(\"test_trace_pipeshard_execuable\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/runtime/test_xla_nccl.py",
    "content": "\"\"\"Test cross-mesh resharding.\"\"\"\nimport unittest\n\nimport numpy as np\nimport ray\n\nfrom alpa import init\nfrom alpa.device_mesh import get_global_virtual_physical_mesh, next_array_uuids\nfrom alpa.global_env import global_config\n\n\nclass XLANCCLTest(unittest.TestCase):\n\n    def setUp(self):\n        init(cluster=\"ray\")\n\n    @unittest.skip(\"manually calling allgather is deprecated\")\n    def test_xla_nccl_allgather(self):\n        backup_nccl_mode = global_config.nccl_mode\n        global_config.nccl_mode = \"xla_extension\"\n\n        mesh_shape = (1, 4)\n        size = (4, 4)\n        virtual_mesh = get_global_virtual_physical_mesh()\n        mesh = virtual_mesh.slice_2d(range(mesh_shape[0]),\n                                     [range(mesh_shape[1])] *\n                                     mesh_shape[0]).get_physical_mesh()\n        worker = mesh.workers[0]\n        device_ids = np.arange(mesh.num_devices_per_host)\n\n        # Put buffers\n        ary_uuid = next_array_uuids(1)[0]\n        shard_len = size[0] // mesh.num_devices_per_host\n        shards = []\n        for i in range(mesh.num_devices_per_host):\n            data = np.zeros(size, dtype=int)\n            data[i * shard_len:(i + 1) * shard_len, :] = i\n            shards.append(data)\n        ray.get(worker.put_buffers.remote(ary_uuid, shards, 1, 0))\n\n        # Put allgather task\n        output_slice = [slice(0, size[0], None), slice(0, size[1], None)]\n        tensor_slices = []\n        for i in range(mesh.num_devices_per_host):\n            tensor_slices.append([\n                slice(i * shard_len, (i + 1) * shard_len, None),\n                slice(0, size[1], None)\n            ])\n        ray.get(\n            worker.put_resharding_allgather_task.remote(\n                0, (ReshardingAllGatherSpec(device_ids, tensor_slices,\n                                            output_slice),)))\n\n        # Run allgather task\n        
ray.get(worker.run_allgather_task.remote(0, ary_uuid))\n        refs = ray.get(worker.get_buffers.remote(ary_uuid))\n        for i in range(4):\n            for j in range(4):\n                assert refs[i][j * shard_len, 0] == j\n\n        global_config.nccl_mode = backup_nccl_mode\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(XLANCCLTest(\"test_xla_nccl_allgather\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/serve/test_controller.py",
    "content": "\"\"\"Test alpa.serve controller.\"\"\"\nimport unittest\n\nimport numpy as np\nimport ray\nimport requests\nfrom tokenizers import Tokenizer\n\nfrom alpa.api import parallelize\nfrom alpa.serve.controller import run_controller\n\n\nclass EchoModel:\n\n    async def handle_request(self, request):\n        obj = await request.json()\n        return obj\n\n\nclass AddOneModel:\n\n    def __init__(self):\n\n        def func(x):\n            return x + 1\n\n        self.add_one = parallelize(func)\n\n    async def handle_request(self, request):\n        obj = await request.json()\n        x = np.array(obj[\"x\"])\n        y = self.add_one(x)\n        return await y.to_np_async()\n\n\nclass TokenizerModel:\n\n    def __init__(self):\n        self.tokenizer = Tokenizer.from_pretrained(\"bert-base-uncased\")\n\n    async def handle_request(self, request):\n        obj = await request.json()\n        x = obj[\"input\"]\n        y = self.tokenizer.encode(x)\n        return y.ids\n\n\nclass ControllerTest(unittest.TestCase):\n\n    def setUp(self):\n        ray.init(address=\"auto\", namespace=\"alpa_serve\")\n\n    def tearDown(self):\n        ray.shutdown()\n\n    def test_query(self):\n        controller = run_controller(\"localhost\")\n\n        info = ray.get(controller.get_info.remote())\n        host, port, root_path = info[\"host\"], info[\"port\"], info[\"root_path\"]\n\n        controller.register_model.remote(\"echo\", EchoModel)\n        controller.register_model.remote(\"add_one\", AddOneModel)\n        controller.register_model.remote(\"tokenizer\", TokenizerModel)\n        group_id = 0\n        controller.launch_mesh_group_manager.remote(group_id, [1, 4])\n        a = controller.create_replica.remote(\"echo\", group_id)\n        b = controller.create_replica.remote(\"add_one\", group_id)\n        c = controller.create_replica.remote(\"tokenizer\", group_id)\n\n        ray.get([a, b, c])\n        url = f\"http://{host}:{port}{root_path}\"\n\n    
    json = {\n            \"model\": \"echo\",\n            \"task\": \"completions\",\n            \"input\": \"Paris is the capital city of\",\n        }\n        resp = requests.post(url=url, json=json)\n        assert resp.json() == json\n\n        resp = requests.post(url=url,\n                             json={\n                                 \"model\": \"add_one\",\n                                 \"x\": list(range(16)),\n                             })\n        assert resp.text == str(list(range(1, 17)))\n\n        src = \"Paris is the capital city of\"\n        resp = requests.post(url=url, json={\"model\": \"tokenizer\", \"input\": src})\n        tokenizer = Tokenizer.from_pretrained(\"bert-base-uncased\")\n        assert resp.text == str(tokenizer.encode(src).ids)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(ControllerTest(\"test_query\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_basic.py",
    "content": "\"\"\"Test auto sharding with simple computational graphs.\"\"\"\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nfrom jax.interpreters import pxla\nfrom jax.interpreters.pxla import Chunked, ShardedAxis, NoSharding, Replicated\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport optax\n\nfrom alpa import parallelize, ShardParallel\nfrom alpa.util import count_communication_primitives\nfrom alpa.testing import assert_allclose\n\nfrom tests.shard_parallel.test_mlp import assert_close\n\nMB = 1024**2\n\n\nclass AutoShardingBasicTest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.devices = jax.local_devices()[:4]\n        self.method = ShardParallel(devices=self.devices)\n\n    def test_donate_buffer(self):\n\n        @parallelize(donate_argnums=(0,), method=self.method)\n        def add_one(x):\n            x = x + 1\n            return x\n\n        a = jnp.ones((128, 128))\n        b = add_one(a)\n\n        # Assert b is sharded\n        assert (b.sharding_spec == pxla.ShardingSpec(\n            sharding=(NoSharding(), Chunked([4])),\n            mesh_mapping=(ShardedAxis(0),)) or b.sharding_spec\n                == pxla.ShardingSpec(sharding=(Chunked([4]), NoSharding()),\n                                     mesh_mapping=(ShardedAxis(0),)))\n\n    def test_dot_reshape_transpose(self):\n        dim_0 = 64\n        dim_1 = 1024\n\n        def func(a, b):\n            a = jnp.transpose(a, [0, 2, 1])\n            a = jnp.reshape(a, (dim_0, dim_1))\n            b = jnp.reshape(b, (dim_1, dim_0))\n            out = a @ b\n            out = -out\n            return out\n\n        p_func = parallelize(func)\n\n        a = jnp.ones((dim_0, dim_1 // 4, 4))\n        b = jnp.ones((dim_1, dim_0 // 4, 4))\n\n        # Check correctness\n        expected = func(a, b)\n        actual = p_func(a, b)\n        assert_allclose(expected, actual)\n\n    def 
test_one_by_one_mesh(self):\n\n        @parallelize(method=ShardParallel(devices=self.devices[0:1]))\n        def add_one(x):\n            x = x + 1\n            return x\n\n        a = jnp.ones((128, 128))\n        b = add_one(a)\n\n        assert_allclose(b, a + 1)\n\n    def test_dropout(self):\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x, deterministic):\n                x = nn.Dense(16, use_bias=False)(x)\n                x = nn.Dropout(0.1, deterministic=deterministic)(x)\n                x = nn.Dense(16, use_bias=False)(x)\n                return x\n\n        x = jnp.ones((32, 32, 16))\n        y = jnp.ones((32, 32, 16))\n\n        # Init model and optimizer\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x, True)\n        tx = optax.sgd(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        @parallelize(method=self.method)\n        def func(state, x, y, rngs):\n\n            def loss_func(params):\n                out = model.apply(params, x, False, rngs=rngs)\n                return jnp.mean((out - y)**2)\n\n            grad = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grad)\n\n        # Check sharding strategy (data-parallel)\n        executable = func.get_executable(state, x, y, {\"dropout\": rngkey})\n        assert executable.auto_sharding_objective < 1e6\n\n        hlo_ir = executable.get_hlo_text()\n        assert \"u64[1024]{0} iota()\" in hlo_ir  # 1024 = 32 * 32 * 16 / 4 / 4\n        n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir)\n        assert n_total == n_allreduce == 1\n\n    def test_gather(self):\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                x = nn.Dense(32, use_bias=False)(x)\n                idx = jnp.arange(16)\n                x = x[:, idx]\n          
      x = nn.Dense(16, use_bias=False)(x)\n                return x\n\n        x = jnp.ones((256, 32))\n        y = jnp.ones((256, 16))\n\n        # Init model and optimizer\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.sgd(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        @parallelize(method=self.method)\n        def func(state, x, y):\n\n            def loss_func(params):\n                out = model.apply(params, x)\n                return jnp.mean((out - y)**2)\n\n            grad = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grad)\n\n        executable = func.get_executable(state, x, y)\n        assert executable.auto_sharding_objective < 1e6\n\n        hlo_ir = executable.get_hlo_text()\n        assert \"gather(f32[64,32]\" in hlo_ir or \"gather(f32[32,64]\" in hlo_ir\n        assert \"scatter(f32[64,32]\" in hlo_ir or \"scatter(f32[32,64]\" in hlo_ir\n        n_total, n_allreduce, _, _, _ = count_communication_primitives(hlo_ir)\n        assert n_total == n_allreduce == 1\n\n    def test_reshape_uneven_partition(self):\n        # TODO(lmzheng): Support the uneven partition of reshape.\n        # But this seems too complicated.\n\n        @parallelize(method=self.method)\n        def func(a):\n            b = a.reshape((8, 18))\n            #b = a.reshape((9, 16))\n            return b\n\n        a = jnp.ones(144)\n        executable = func.get_executable(a)\n        assert_close(executable.auto_sharding_objective, 0)\n\n    def test_argmax(self):\n\n        @parallelize(method=self.method)\n        def func(a):\n            b = jnp.argmax(a, axis=0)\n            return b\n\n        a = jnp.ones((144, 144))\n        executable = func.get_executable(a)\n\n        assert_close(executable.auto_sharding_objective, 0)\n        hlo_ir = executable.get_hlo_text()\n        assert \"(param: 
f32[144,36])\" in hlo_ir\n\n    def test_sort(self):\n\n        @parallelize(method=self.method)\n        def func(a):\n            b = jnp.argsort(a)\n            return b\n\n        a = jnp.ones((1024,), dtype=jnp.int32)\n        executable = func.get_executable(a)\n\n    def test_gemv(self):\n\n        @parallelize(method=self.method)\n        def func(a, b):\n            return a @ b\n\n        a = jnp.ones((128,), dtype=jnp.float32)\n        b = jnp.ones((128, 256), dtype=jnp.float32)\n        executable = func.get_executable(a, b)\n\n        assert \"f32[128,64]\" in executable.get_hlo_text()\n\n    def test_fast_call(self):\n\n        @parallelize\n        def add_one(x, y):\n            return x + y\n\n        a = jnp.ones((32, 32))\n        b = jnp.ones((32, 32))\n        executable = add_one.get_executable(a, b)\n        c = executable(a, b)\n\n        assert isinstance(c, pxla.ShardedDeviceArray)\n\n        executable.dump_debug_info(\"tmp\")\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(AutoShardingBasicTest(\"test_donate_buffer\"))\n    suite.addTest(AutoShardingBasicTest(\"test_dot_reshape_transpose\"))\n    suite.addTest(AutoShardingBasicTest(\"test_one_by_one_mesh\"))\n    suite.addTest(AutoShardingBasicTest(\"test_dropout\"))\n    suite.addTest(AutoShardingBasicTest(\"test_gather\"))\n    suite.addTest(AutoShardingBasicTest(\"test_reshape_uneven_partition\"))\n    suite.addTest(AutoShardingBasicTest(\"test_argmax\"))\n    suite.addTest(AutoShardingBasicTest(\"test_sort\"))\n    suite.addTest(AutoShardingBasicTest(\"test_gemv\"))\n    suite.addTest(AutoShardingBasicTest(\"test_fast_call\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_bert.py",
    "content": "\"\"\"Test auto sharding on transformer layers and bert models.\"\"\"\n\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport optax\n\nfrom alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption\nfrom alpa.model.bert_model import (BertConfig, FlaxBertLayerCollection,\n                                   FlaxBertForMaskedLMModule)\nfrom alpa.util import count_communication_primitives\nfrom tests.shard_parallel.test_mlp import (\n    assert_all_replicated, assert_close, assert_column_partitioned,\n    assert_data_parallel_cost, assert_fully_sharded, assert_less_equal,\n    assert_sharded, assert_replicated_column_partitioned,\n    assert_replicated_row_partitioned, assert_row_partitioned, is_fully_sharded,\n    assert_sharding_zero_stage_3)\n\n\nclass AutoShardingAttentionTest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        self.as_option = AutoShardingOption()\n\n    def get_device_mesh(self, shape, mesh_alpha, mesh_beta):\n        return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta)\n\n    def run_bert_layers(self, batch_size, seq_len, num_layers, hidden_size,\n                        num_heads, deterministic, use_remat, device_mesh):\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=self.as_option))\n        def train_step(state, batch, deterministic):\n\n            def loss_func(params):\n                rngs = {\"dropout\": batch[\"rng\"]}\n                out = state.apply_fn(params,\n                                     batch[\"hidden_states\"],\n                                     batch[\"attention_mask\"],\n                                     deterministic,\n                     
                rngs=rngs)[0]\n                return jnp.mean((out - batch[\"label\"])**2)\n\n            grads = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        # Init model and optimizer\n        hidden_states = jnp.ones((batch_size, seq_len, hidden_size),\n                                 dtype=jnp.float32)\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        label = jnp.ones((batch_size, seq_len, hidden_size), dtype=jnp.float32)\n\n        model = FlaxBertLayerCollection(\n            BertConfig(num_hidden_layers=num_layers,\n                       hidden_size=hidden_size,\n                       intermediate_size=hidden_size * 4,\n                       num_attention_heads=num_heads,\n                       gradient_checkpointing=use_remat))\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, hidden_states, attention_mask)\n        tx = optax.adam(1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # JIT compile\n        state = train_step(\n            state, {\n                \"hidden_states\": hidden_states,\n                \"attention_mask\": attention_mask,\n                \"label\": label,\n                \"rng\": rngkey\n            }, deterministic)\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def run_bert_mlm(self, batch_size, seq_len, num_layers, hidden_size,\n                     num_heads, vocab_size, deterministic, device_mesh):\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=self.as_option))\n        def train_step(state, batch):\n\n            def loss_func(params):\n                rngs = {\"dropout\": batch[\"rng\"]}\n                logits = 
state.apply_fn(params,\n                                        batch[\"input_ids\"],\n                                        batch[\"attention_mask\"],\n                                        batch[\"token_type_ids\"],\n                                        batch[\"position_ids\"],\n                                        deterministic=deterministic,\n                                        rngs=rngs)[0]\n                label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n                labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n                loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                                axis=-1)\n                return (label_mask * loss).sum() / label_mask.sum() * 0.1234\n\n            grads = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        # Init model and optimizer\n        input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n\n        model = FlaxBertForMaskedLMModule(\n            BertConfig(\n                num_hidden_layers=num_layers,\n                hidden_size=hidden_size,\n                intermediate_size=hidden_size * 4,\n                num_attention_heads=num_heads,\n                vocab_size=vocab_size,\n                max_position_embeddings=seq_len,\n            ))\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, input_ids, attention_mask, token_type_ids,\n                            position_ids)\n        tx = optax.adam(1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # JIT compile\n        state = train_step(\n            state, {\n           
     \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"token_type_ids\": token_type_ids,\n                \"position_ids\": position_ids,\n                \"labels\": labels,\n                \"rng\": rngkey\n            })\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def test_bert_layer_data_parallel(self):\n        batch_size = 64\n        seq_len = 64\n        num_layers = 2\n        hidden_size = 32\n        num_heads = 8\n        deterministic = False\n        use_remat = False\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_bert_layers(\n                batch_size, seq_len, num_layers, hidden_size, num_heads,\n                deterministic, use_remat, device_mesh)\n\n            assert_data_parallel_cost(state, hlo_ir, objective, device_mesh,\n                                      self.as_option, i)\n\n    def test_bert_layer_model_parallel(self):\n        batch_size = 8\n        seq_len = 8\n        num_layers = 2\n        hidden_size = 128\n        num_heads = 8\n        deterministic = False\n        use_remat = False\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_bert_layers(\n                batch_size, seq_len, num_layers, hidden_size, num_heads,\n                deterministic, use_remat, device_mesh)\n\n            # Check communication cost\n            expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost(\n                batch_size * seq_len * hidden_size * 4, i)\n          
  assert_close(objective, expected)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir))\n            if self.as_option.prefer_reduce_scatter:\n                assert n_total == num_layers * 4 - 1\n                assert n_all_reduce == num_layers * 4 - 1\n                assert n_total == n_all_reduce\n            else:\n                assert n_total == num_layers * 4 - 1\n                assert n_all_reduce == num_layers * 4 - 1\n                assert n_total == n_all_reduce\n\n            # Check sharding specification\n            for k in range(num_layers):\n                params = state.params[\"params\"][str(k)]\n                weights = [\n                    params[\"attention\"][\"self\"][\"qvk_combined\"][\"kernel\"],\n                    params[\"attention\"][\"output\"][\"dense\"][\"kernel\"],\n                    params[\"intermediate\"][\"dense\"][\"kernel\"],\n                    params[\"output\"][\"dense\"][\"kernel\"],\n                ]\n\n                for j in range(len(weights)):\n                    if j % 2 == 0:\n                        assert_column_partitioned(weights[j], mesh_shape[i], i)\n                    else:\n                        assert_row_partitioned(weights[j], mesh_shape[i], i)\n\n    def test_bert_layer_2d_mesh(self):\n        batch_size = 8\n        seq_len = 8\n        num_layers = 2\n        hidden_size = 128\n        num_heads = 8\n        deterministic = False\n        use_remat = False\n\n        # Test on different logical mesh shapes\n        mesh_shape = [2, 2]\n        device_mesh = self.get_device_mesh(mesh_shape, [2, 2], [1, 0.1])\n        state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len,\n                                                        num_layers, hidden_size,\n                                                        num_heads,\n                                                        
deterministic,\n                                                        use_remat, device_mesh)\n\n        # Check communication cost\n        params = jax.tree_util.tree_leaves(state.params)\n        expected = (sum(\n            device_mesh.all_reduce_cost(\n                np.prod(x.shape) * 4 / mesh_shape[1], 0)\n            for x in params) + device_mesh.all_reduce_cost(\n                batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1) *\n                    (num_layers * 4 - 1))\n        assert_close(objective, expected)\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_ir,\n                                           ignore_scalar_all_reduce=True))\n        if self.as_option.prefer_reduce_scatter:\n            assert n_all_reduce == num_layers * 4 - 1\n            assert n_reduce_scatter == 2\n            assert n_all_gather <= 2\n            assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather\n        else:\n            assert n_all_reduce == num_layers * 4\n            assert n_total == n_all_reduce\n\n        # Check sharding specification\n        if self.as_option.prefer_reduce_scatter:\n            for weight in jax.tree_util.tree_leaves(state.opt_state):\n                if len(weight.shape) > 1:\n                    assert_fully_sharded(weight)\n        else:\n            for k in range(num_layers):\n                params = state.params[\"params\"][str(k)]\n                weights = [\n                    params[\"attention\"][\"self\"][\"qvk_combined\"][\"kernel\"],\n                    params[\"attention\"][\"output\"][\"dense\"][\"kernel\"],\n                    params[\"intermediate\"][\"dense\"][\"kernel\"],\n                    params[\"output\"][\"dense\"][\"kernel\"],\n                ]\n\n                for j in range(len(weights)):\n                    if j % 2 == 0:\n                        assert_replicated_column_partitioned(\n                   
         weights[j], mesh_shape)\n                    else:\n                        assert_replicated_row_partitioned(\n                            weights[j], mesh_shape)\n\n    def test_bert_layer_force_batch_dim_mapping(self):\n        batch_size = 64\n        seq_len = 64\n        num_layers = 2\n        hidden_size = 32\n        num_heads = 8\n        deterministic = False\n        use_remat = False\n        self.as_option.force_batch_dim_to_mesh_dim = 0\n\n        # data parallel\n        device_mesh = self.get_device_mesh([4, 1], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len,\n                                                        num_layers, hidden_size,\n                                                        num_heads,\n                                                        deterministic,\n                                                        use_remat, device_mesh)\n        assert_data_parallel_cost(state, hlo_ir, objective, device_mesh,\n                                  self.as_option, 0)\n\n        # model parallel (case 1)\n        device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len,\n                                                        num_layers, hidden_size,\n                                                        num_heads,\n                                                        deterministic,\n                                                        use_remat, device_mesh)\n        expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost(\n            batch_size * seq_len * hidden_size * 4, 1)\n        assert_close(objective, expected)\n\n        # model parallel (case 2)\n        batch_size = 1\n        device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_bert_layers(batch_size, seq_len,\n                                                        
num_layers, hidden_size,\n                                                        num_heads,\n                                                        deterministic,\n                                                        use_remat, device_mesh)\n        expected = (num_layers * 4 - 1) * device_mesh.all_reduce_cost(\n            batch_size * seq_len * hidden_size * 4, 1)\n        assert_close(objective, expected)\n\n    def test_embedding_2d_mesh(self):\n        vocab_size = 1024\n        hidden_size = 8\n        batch_size = 8\n        seq_len = 8\n        mesh_shape = [2, 2]\n\n        # Model and training step definition\n        class Model(nn.Module):\n            \"\"\"Tied input and output embedding.\"\"\"\n\n            def setup(self):\n                self.embed = nn.Embed(vocab_size, hidden_size)\n\n            def __call__(self, x):\n                x = self.embed(x)\n                embed = self.embed.variables[\"params\"][\"embedding\"]\n                x = x @ embed.T\n                return x\n\n        logical_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n\n        @parallelize(method=ShardParallel(devices=logical_mesh))\n        def func(state, x, y):\n\n            def loss_func(params):\n                out = state.apply_fn(params, x)\n                y_ = jax.nn.one_hot(y, out.shape[-1])\n                loss = -jnp.sum(y_ * jax.nn.log_softmax(out, axis=-1), axis=-1)\n                return loss.sum()\n\n            grads = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        # Init model and optimizer\n        x = jnp.ones((batch_size, seq_len), np.int32)\n        y = jnp.ones((batch_size, seq_len), np.int32)\n\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.adam(1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # JIT Compile\n        state = func(state, x, y)\n\n   
     # Check communication cost\n        executable = func.get_last_executable()\n        hlo_ir = executable.get_hlo_text()\n        objective = executable.auto_sharding_objective\n\n        expected = (\n            logical_mesh.all_reduce_cost(\n                vocab_size * hidden_size * 4 / mesh_shape[1], 0) +\n            logical_mesh.all_reduce_cost(\n                batch_size * seq_len * hidden_size * 4 / mesh_shape[0], 1) * 2 +\n            logical_mesh.all_reduce_cost(\n                batch_size * seq_len * 4 / mesh_shape[0], 1) * 2)\n\n        assert_close(objective, expected)\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_ir))\n        assert n_total == n_all_reduce\n\n    def test_bert_mlm_data_parallel(self):\n        batch_size = 32\n        seq_len = 32\n        num_layers = 2\n        hidden_size = 16\n        num_heads = 4\n        vocab_size = 128\n        deterministic = False\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_bert_mlm(\n                batch_size, seq_len, num_layers, hidden_size, num_heads,\n                vocab_size, deterministic, device_mesh)\n\n            if self.as_option.force_zero_stage_3:\n                # only the weight and opt_state of token_embed is not sharded\n                assert_sharding_zero_stage_3(state, 3)\n                continue\n\n            assert_data_parallel_cost(state, hlo_ir, objective, device_mesh,\n                                      self.as_option, i, 1)\n\n    @unittest.skip(\"This test is broken after we disallow some replicated iota\")\n    def test_bert_mlm_model_parallel(self):\n        batch_size = 16\n        seq_len = 16\n        num_layers = 2\n        hidden_size = 128\n        num_heads = 4\n        vocab_size = 512\n        
deterministic = False\n        self.as_option.allow_all_gather = False  # Temporary hack\n        self.as_option.allow_all_to_all = False  # Temporary hack\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_bert_mlm(\n                batch_size, seq_len, num_layers, hidden_size, num_heads,\n                vocab_size, deterministic, device_mesh)\n\n            # Check communication cost\n            # expected_cost = embed.forward (1) + embed.backward(2) +\n            #                 LM_head.forward (1) + LM_head.backward (1) +\n            #                 LM_head.weight.backward (1) +  log_softmax.forward (2) +\n            #                 transformer.forward (2 * num_layers) + transformer.backward (2 * num_layers)\n            #\n            # Note that the final cost is different from this estimated cost in ILP solver.\n            # The SPMD partitioner will eliminate some unnecessary communication in favor of\n            # redundant computation (e.g., it will elimiate the all-reduce in embed.backward).\n            expected = (\n                device_mesh.all_reduce_cost(\n                    batch_size * seq_len * hidden_size * 4, i) * 5 +\n                device_mesh.all_reduce_cost(hidden_size * hidden_size * 4, i) +\n                device_mesh.all_reduce_cost(batch_size * seq_len * 4, i) * 2 +\n                device_mesh.all_reduce_cost(\n                    batch_size * seq_len * hidden_size * 4, i) * num_layers * 4)\n            assert_close(objective, expected)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir))\n\n            # real number of all-reduce = transformers (4 * num_layers) + log_softmax (2) +\n            #                             embed.forward (1) + embad.backward 
(1)\n            assert n_all_reduce == num_layers * 4 + 4\n            assert n_total == n_all_reduce\n\n            # Check sharding specification\n            embed_weight = state.params[\"params\"][\"bert\"][\"embeddings\"][\n                \"word_embeddings\"][\"embedding\"]\n            lm_head = state.params[\"params\"][\"cls\"][\"predictions\"][\"transform\"][\n                \"dense\"][\"kernel\"]\n            assert_row_partitioned(embed_weight, mesh_shape[i], i)\n            assert_all_replicated(lm_head, np.prod(mesh_shape))\n\n            for k in range(num_layers):\n                params = state.params[\"params\"][\"bert\"][\"encoder\"][\"layer\"][str(\n                    k)]\n                weights = [\n                    params[\"attention\"][\"self\"][\"qvk_combined\"][\"kernel\"],\n                    params[\"attention\"][\"output\"][\"dense\"][\"kernel\"],\n                    params[\"intermediate\"][\"dense\"][\"kernel\"],\n                    params[\"output\"][\"dense\"][\"kernel\"],\n                ]\n\n                for j in range(len(weights)):\n                    if j % 2 == 0:\n                        assert_column_partitioned(weights[j], mesh_shape[i], i)\n                    else:\n                        assert_row_partitioned(weights[j], mesh_shape[i], i)\n\n    def test_bert_mlm_2d_mesh(self):\n        batch_size = 4\n        seq_len = 4\n        num_layers = 2\n        hidden_size = 512\n        num_heads = 4\n        vocab_size = 4096\n        deterministic = False\n        # To generate the desired strategy, we have to turn off mixed mesh shape and all-gather\n        # and enable recomputing heavy ops.\n        self.as_option.allow_recompute_heavy_op = True\n        self.as_option.allow_all_gather = False\n        self.as_option.allow_mixed_mesh_shape = False\n\n        mesh_shape = [2, 2]\n        device_mesh = self.get_device_mesh(mesh_shape, [2, 2], [1, 0.1])\n\n        state, hlo_ir, objective = 
self.run_bert_mlm(batch_size, seq_len,\n                                                     num_layers, hidden_size,\n                                                     num_heads, vocab_size,\n                                                     deterministic, device_mesh)\n\n        # Check communication cost.\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_ir,\n                                           ignore_scalar_all_reduce=True))\n        if self.as_option.prefer_reduce_scatter:\n            assert n_all_reduce == 4 * num_layers + 2 + 2\n            assert n_reduce_scatter <= 3  # The correct number should be 2,\n            # but GpuMultiOutputFusion can make\n            # some reduce-scatter unable to be combined\n            assert n_all_gather <= 2\n            assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter\n        else:\n            # real number of all-reduce = transformers (4 * num_layers) + log_softmax (2) +\n            #                             embed.forward (1) + embed.backward (1) + weights (1)\n            assert n_all_reduce == 4 * num_layers + 2 + 2 + 1\n            assert n_total == n_all_reduce\n\n        # Check sharding specification\n        assert \"s32[4,4,4096]{2,1,0} iota()\" not in hlo_ir\n        assert \"s32[2,4,2048]{2,1,0} iota()\" in hlo_ir\n\n        if self.as_option.prefer_reduce_scatter:\n            num_not_sharded = 0  # allow the token_type_embeddings not partitioned.\n            for weight in jax.tree_util.tree_leaves(state.opt_state):\n                if len(weight.shape) > 1:\n                    if not is_fully_sharded(weight):\n                        num_not_sharded += 1\n            assert num_not_sharded <= 2\n        else:\n            embed_weight = (state.params[\"params\"][\"bert\"][\"embeddings\"]\n                            [\"word_embeddings\"][\"embedding\"])\n            lm_head = 
(state.params[\"params\"][\"cls\"][\"predictions\"][\"transform\"]\n                       [\"dense\"][\"kernel\"])\n\n            assert_replicated_row_partitioned(embed_weight, mesh_shape)\n            assert_all_replicated(lm_head, np.prod(mesh_shape))\n\n            for k in range(num_layers):\n                params = state.params[\"params\"][\"bert\"][\"encoder\"][\"layer\"][str(\n                    k)]\n                weights = [\n                    params[\"attention\"][\"self\"][\"qvk_combined\"][\"kernel\"],\n                    params[\"attention\"][\"output\"][\"dense\"][\"kernel\"],\n                    params[\"intermediate\"][\"dense\"][\"kernel\"],\n                    params[\"output\"][\"dense\"][\"kernel\"],\n                ]\n\n                for j in range(len(weights)):\n                    if j % 2 == 0:\n                        assert_replicated_column_partitioned(\n                            weights[j], mesh_shape)\n                    else:\n                        assert_replicated_row_partitioned(\n                            weights[j], mesh_shape)\n\n    def test_bert_layer_data_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_layer_data_parallel()\n\n    def test_bert_layer_model_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_layer_model_parallel()\n\n    def test_bert_layer_2d_mesh_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_layer_2d_mesh()\n\n    def test_bert_mlm_data_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_mlm_data_parallel()\n\n    def test_bert_mlm_data_parallel_reduce_scatter_zero_3(self):\n        self.as_option.force_zero_stage_3 = True\n        self.as_option.force_zero_stage_3_all_gather_threshold = 1\n        self.test_bert_mlm_data_parallel()\n\n    @unittest.skip(\"This test 
is broken after we disallow some replicated iota.\"\n                  )\n    def test_bert_mlm_model_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_mlm_model_parallel()\n\n    def test_bert_mlm_2d_mesh_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_bert_mlm_2d_mesh()\n\n    def test_bert_layer_model_parallel_remat(self):\n        batch_size = 8\n        seq_len = 8\n        num_layers = 2\n        hidden_size = 128\n        num_heads = 8\n        deterministic = False\n        use_remat = True\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_bert_layers(\n                batch_size, seq_len, num_layers, hidden_size, num_heads,\n                deterministic, use_remat, device_mesh)\n\n            expected = (num_layers * 6 - 1) * device_mesh.all_reduce_cost(\n                batch_size * seq_len * hidden_size * 4, i)\n            assert_close(objective, expected)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir))\n            assert n_total == num_layers * 6 - 1\n            assert n_all_reduce == num_layers * 6 - 1\n            assert n_total == n_all_reduce\n\n\ndef suite():\n    suite = unittest.TestSuite()\n\n    def add(name):\n        suite.addTest(AutoShardingAttentionTest(name))\n\n    add(\"test_bert_layer_data_parallel\")\n    add(\"test_bert_layer_model_parallel\")\n    add(\"test_bert_layer_2d_mesh\")\n    add(\"test_bert_layer_force_batch_dim_mapping\")\n\n    add(\"test_embedding_2d_mesh\")\n\n    add(\"test_bert_mlm_data_parallel\")\n    add(\"test_bert_mlm_model_parallel\")\n    add(\"test_bert_mlm_2d_mesh\")\n\n    add(\"test_bert_layer_data_parallel_reduce_scatter\")\n    
add(\"test_bert_layer_model_parallel_reduce_scatter\")\n    add(\"test_bert_layer_2d_mesh_reduce_scatter\")\n\n    add(\"test_bert_mlm_data_parallel_reduce_scatter\")\n    add(\"test_bert_mlm_model_parallel_reduce_scatter\")\n    add(\"test_bert_mlm_2d_mesh_reduce_scatter\")\n    add(\"test_bert_mlm_data_parallel_reduce_scatter_zero_3\")\n\n    add(\"test_bert_layer_model_parallel_remat\")\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_conv.py",
    "content": "\"\"\"Test auto sharding with convolution nets.\"\"\"\n\nimport unittest\nfrom typing import Any\n\nfrom flax import linen as nn\nfrom flax.training import train_state, dynamic_scale as dynamic_scale_lib\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nfrom alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption\nfrom alpa.util import map_to_shape, count_communication_primitives\n\nfrom tests.shard_parallel.test_mlp import assert_close, assert_all_replicated, is_sharded\n\n\nclass TrainState(train_state.TrainState):\n    batch_stats: Any\n    dynamic_scale: dynamic_scale_lib.DynamicScale\n\n\ndef assert_data_parallel_cost(state,\n                              hlo_ir,\n                              objective,\n                              device_mesh,\n                              as_option,\n                              mesh_dim,\n                              allow_not_sharded_params=0):\n    params = jax.tree_util.tree_leaves(state.params)\n    opt_state = jax.tree_util.tree_leaves(state.opt_state)\n    batch_stats = jax.tree_util.tree_leaves(state.batch_stats)\n\n    # Check communication cost\n    replicated_penalty = int(\n        device_mesh.all_reduce_cost(1, 0) + device_mesh.all_reduce_cost(1, 1))\n    weight_sync = sum(\n        device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim) +\n        replicated_penalty for x in params)\n    num_batch_norm = len(batch_stats) // 2\n    batch_norm_sync = 2 * sum(\n        device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim) +\n        replicated_penalty for x in batch_stats)\n    expected = weight_sync + batch_norm_sync\n\n    assert_close(objective, expected, atol=0.05)\n\n    # Check numbers of communication primitives\n    n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n        count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True))\n\n    if as_option.prefer_reduce_scatter:\n        assert 
n_all_reduce == num_batch_norm * 2\n        assert n_reduce_scatter > 0\n        assert n_all_gather <= 2\n        assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather\n    else:\n        assert n_all_reduce == 1 + num_batch_norm * 2\n        assert n_total == n_all_reduce\n\n    if as_option.prefer_reduce_scatter:\n        num_not_sharded = 0\n        for weight in opt_state:\n            if not is_sharded(weight) and len(weight.shape) > 1:\n                num_not_sharded += 1\n        assert num_not_sharded == 0\n    else:\n        for weight in params:\n            assert_all_replicated(weight, np.prod(device_mesh.shape))\n\n\nclass AutoShardingConvTest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        self.as_option = AutoShardingOption()\n\n    def get_device_mesh(self, shape, mesh_alpha, mesh_beta):\n        return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta)\n\n    def run_n_layer_conv(self,\n                         num_layers,\n                         batch_size,\n                         image_size,\n                         channel,\n                         device_mesh,\n                         use_bias=False,\n                         is_depthwise=False):\n        if not is_depthwise:\n\n            class Model(nn.Module):\n\n                @nn.compact\n                def __call__(self, x, train=True):\n                    for i in range(num_layers):\n                        x = nn.Conv(features=channel,\n                                    kernel_size=(3, 3),\n                                    strides=(2, 2),\n                                    use_bias=use_bias)(x)\n                        x = nn.BatchNorm(use_running_average=not train)(x)\n                        x = nn.relu(x)\n                        x = nn.max_pool(x,\n                                        window_shape=(2, 2),\n 
                                       strides=(1, 1),\n                                        padding=\"SAME\")\n                    return x\n\n            x = jnp.ones((batch_size, image_size, image_size, channel))\n            out_image_size = image_size // (2**num_layers)\n            y = jnp.ones((batch_size, out_image_size, out_image_size, channel))\n        else:\n\n            class Model(nn.Module):\n\n                @nn.compact\n                def __call__(self, x, train=True):\n                    x = nn.Conv(features=8 * channel,\n                                kernel_size=(3, 3),\n                                strides=(1, 1),\n                                use_bias=use_bias)(x)\n                    x = nn.Conv(features=8 * channel,\n                                kernel_size=(3, 3),\n                                strides=(1, 1),\n                                feature_group_count=8 * channel,\n                                use_bias=use_bias)(x)\n                    x = nn.Conv(features=channel,\n                                kernel_size=(3, 3),\n                                strides=(1, 1),\n                                use_bias=use_bias)(x)\n                    x = nn.relu(x)\n                    x = nn.BatchNorm(use_running_average=not train)(x)\n                    return x\n\n            x = jnp.ones((batch_size, image_size, image_size, channel))\n            y = jnp.ones((batch_size, image_size, image_size, channel))\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=self.as_option))\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out, new_model_state = state.apply_fn(\n                    {\n                        \"params\": params,\n                        \"batch_stats\": state.batch_stats\n                    },\n                    batch[\"x\"],\n                    mutable=['batch_stats'])\n     
           loss = jnp.mean((out - batch[\"y\"])**2)\n                return loss, new_model_state\n\n            grads, new_model_state = jax.grad(loss_func,\n                                              has_aux=True)(state.params)\n            new_state = state.apply_gradients(\n                grads=grads, batch_stats=new_model_state['batch_stats'])\n            return new_state\n\n        # Init train state\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.sgd(0.1, momentum=0.9)\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params[\"params\"],\n                                  tx=tx,\n                                  batch_stats=params[\"batch_stats\"],\n                                  dynamic_scale=None)\n\n        # JIT compile\n        state = train_step(state, {\"x\": x, \"y\": y})\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def test_n_layer_conv_data_parallel(self):\n        batch_size = 16\n        image_size = 16\n        num_layers = 3\n        channel = 4\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_n_layer_conv(\n                num_layers, batch_size, image_size, channel, device_mesh)\n\n            assert_data_parallel_cost(state, hlo_ir, objective, device_mesh,\n                                      self.as_option, i)\n\n    def test_n_layer_conv_model_parallel(self):\n        batch_size = 8\n        image_size = 16\n        num_layers = 4\n        channel = 256\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            
device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_n_layer_conv(\n                num_layers, batch_size, image_size, channel, device_mesh)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir,\n                                               ignore_scalar_all_reduce=True))\n\n            assert n_all_reduce == num_layers - 1\n            assert n_total == n_all_reduce\n\n    def test_n_layer_conv_2d_mesh(self):\n        batch_size = 8\n        image_size = 32\n        num_layers = 4\n        channel = 8\n        self.as_option.allow_mixed_mesh_shape = False\n\n        device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 0.1])\n        state, hlo_ir, objective = self.run_n_layer_conv(\n            num_layers, batch_size, image_size, channel, device_mesh)\n\n        # Check numbers of communication primitives\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n            count_communication_primitives(hlo_ir,\n                                           ignore_scalar_all_reduce=True))\n        if self.as_option.prefer_reduce_scatter:\n            assert n_reduce_scatter > 0\n        if self.as_option.allow_mixed_mesh_shape:\n            assert n_all_to_all > 0\n\n    def test_n_layer_conv_2d_mesh_mixed_shape(self):\n        self.as_option.allow_mixed_mesh_shape = True\n        self.test_n_layer_conv_2d_mesh()\n\n    def test_n_layer_conv_data_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_n_layer_conv_data_parallel()\n\n    def test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter(self):\n        self.as_option.allow_mixed_mesh_shape = True\n        self.as_option.prefer_reduce_scatter = True\n        self.test_n_layer_conv_2d_mesh()\n\n    def test_n_layer_depthwise_conv_model_parallel(self):\n        batch_size = 4\n        image_size = 8\n    
    num_layers = 2\n        channel = 256\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_n_layer_conv(num_layers,\n                                                             batch_size,\n                                                             image_size,\n                                                             channel,\n                                                             device_mesh,\n                                                             is_depthwise=True)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir,\n                                               ignore_scalar_all_reduce=True))\n            assert n_all_reduce == 1\n            assert n_total == n_all_reduce\n\n\ndef suite():\n    suite = unittest.TestSuite()\n\n    def add(name):\n        suite.addTest(AutoShardingConvTest(name))\n\n    add(\"test_n_layer_conv_data_parallel\")\n    add(\"test_n_layer_conv_model_parallel\")\n    add(\"test_n_layer_conv_2d_mesh\")\n    add(\"test_n_layer_conv_2d_mesh_mixed_shape\")\n\n    add(\"test_n_layer_conv_data_parallel_reduce_scatter\")\n    add(\"test_n_layer_conv_2d_mesh_mixed_shape_reduce_scatter\")\n\n    add(\"test_n_layer_depthwise_conv_model_parallel\")\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_gradient_accumulation.py",
    "content": "\"\"\"\nTest the numerical correctness of shard parallel with gradient accumulation.\n\"\"\"\nimport os\nimport unittest\n\nimport numpy as np\n\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nimport ray\n\nfrom alpa import (init, shutdown, parallelize, grad, LocalPhysicalDeviceMesh,\n                  ShardParallel)\nfrom alpa.device_mesh import (get_global_cluster, get_global_physical_mesh,\n                              set_global_physical_mesh)\nfrom alpa.shard_parallel.auto_sharding import AutoShardingOption\nfrom alpa.util import count_communication_primitives\nfrom alpa.testing import assert_allclose\nfrom alpa.test_install import get_mlp_train_state_and_step\n\n\nclass GradAccumulationTest(unittest.TestCase):\n\n    def setUp(self):\n        os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\"\n        self.as_option = AutoShardingOption(allow_all_to_all=False)\n\n    def run_gradient_accumulation(self, cluster, use_2d_mesh):\n        if cluster == \"ray\":\n            physical_mesh = get_global_physical_mesh()\n            if physical_mesh is None:\n                init(cluster=\"ray\")\n                physical_mesh = get_global_cluster().get_physical_mesh()\n                set_global_physical_mesh(physical_mesh)\n            logical_mesh = physical_mesh.get_logical_mesh()\n        else:\n            physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n            if use_2d_mesh:\n                logical_mesh = physical_mesh.get_logical_mesh([2, 2], [1, 1],\n                                                              [1, 1])\n            else:\n                logical_mesh = physical_mesh.get_logical_mesh([1, 4], [1, 1],\n                                                              [1, 1])\n\n        state, batch, train_step = get_mlp_train_state_and_step(batch_size=256,\n                                                                hidden_size=16,\n                                              
                  num_layers=2)\n\n        # Serial execution\n        state_expected = train_step(state, batch)[0]\n\n        # Parallel execution\n        p_train_step = parallelize(train_step,\n                                   method=ShardParallel(\n                                       devices=logical_mesh,\n                                       num_micro_batches=2,\n                                       auto_sharding_option=self.as_option))\n        state_actual = p_train_step(state, batch)[0]\n\n        # Check results\n        assert_allclose(state_expected.params,\n                        state_actual.params,\n                        atol=5e-4,\n                        rtol=5e-4)\n\n        # Check sharding strategy\n        executable = p_train_step.get_last_executable()\n        hlo_text = executable.get_hlo_text()\n        if self.as_option.prefer_reduce_scatter:\n            _, accumulate_grad, apply_grad = hlo_text.split(\"HloModule\")\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(accumulate_grad))\n            assert n_total == n_all_reduce + n_reduce_scatter == 1\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(apply_grad))\n            assert n_total == n_all_gather == 1\n        else:\n            assert executable.grad_sync_channel_ids.count(\".\") == 2\n            _, accumulate_grad, apply_grad = hlo_text.split(\"HloModule\")\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(accumulate_grad))\n            if use_2d_mesh:\n                # TODO(lmzheng): investigate why n_total is 4 not 2\n                assert n_total == n_all_reduce\n            else:\n                assert n_total == n_all_reduce == 1\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                
count_communication_primitives(apply_grad))\n            assert n_total == 0\n\n        executable.dump_debug_info(\"tmp\")\n\n        if cluster == \"ray\":\n            shutdown()\n\n    def test_gradient_accumulation_single_host(self):\n        self.run_gradient_accumulation(\"local\", use_2d_mesh=False)\n\n    def test_gradient_accumulation_multi_host(self):\n        self.run_gradient_accumulation(\"ray\", use_2d_mesh=False)\n\n    def test_gradient_accumulation_2d_mesh(self):\n        self.run_gradient_accumulation(\"local\", use_2d_mesh=True)\n\n    def test_gradient_accumulation_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.run_gradient_accumulation(\"local\", use_2d_mesh=False)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(\n        GradAccumulationTest(\"test_gradient_accumulation_single_host\"))\n    suite.addTest(GradAccumulationTest(\"test_gradient_accumulation_multi_host\"))\n    suite.addTest(GradAccumulationTest(\"test_gradient_accumulation_2d_mesh\"))\n    suite.addTest(\n        GradAccumulationTest(\"test_gradient_accumulation_reduce_scatter\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_manual.py",
    "content": "\"\"\"\nTest the manual sharding spec.\n\"\"\"\nimport unittest\n\nimport jax\nfrom jax.experimental.pjit import PartitionSpec\nfrom jax.tree_util import tree_map\nimport jax.numpy as jnp\n\nimport alpa\nfrom alpa import (AutoShardingOption, LocalPhysicalDeviceMesh,\n                  ManualShardingOption, ShardParallel, parallelize)\nfrom alpa.testing import HloParser\n\n\nclass ManualShardingTest(unittest.TestCase):\n\n    def setUp(self):\n        self.as_option = AutoShardingOption(enable_auto_sharding=False)\n        self.devices = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        self.devices = self.devices.get_logical_mesh((2, 2), (1, 1), (1, 1))\n        self.mesh_axis_names = (\"data\", \"model\")\n\n    def _get_fn_manual_sharding_with(self,\n                                     fn,\n                                     ms_option,\n                                     *args,\n                                     num_microbatches=None,\n                                     batch_argnums=(1,)):\n        method = ShardParallel(\n            devices=self.devices,\n            num_micro_batches=num_microbatches,\n            auto_sharding_option=self.as_option,\n            manual_sharding_option=ms_option,\n        )\n        parallelized = parallelize(fn,\n                                   method=method,\n                                   batch_argnums=batch_argnums)\n        return parallelized.get_executable(*args).get_hlo_text()\n\n    def test_set_input(self):\n\n        def fn(a, b):\n            return a + b\n\n        a = jnp.ones((6, 4))\n        b = jnp.ones((6, 4))\n        in_axis_resources = (PartitionSpec(None, \"model\"),\n                             PartitionSpec(None, \"model\"))\n        ms_option = ManualShardingOption(self.mesh_axis_names,\n                                         in_axis_resources=in_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn, ms_option, a, b)\n        text = 
HloParser.get_param_line(text)\n        assert \"param: f32[6,2]\" in text and \"param.1: f32[6,2]\" in text\n        in_axis_resources = (PartitionSpec(\"data\", None),\n                             PartitionSpec(\"data\", \"model\"))\n        ms_option = ManualShardingOption(self.mesh_axis_names,\n                                         in_axis_resources=in_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn, ms_option, a, b)\n        text = HloParser.get_param_line(text)\n        assert \"param: f32[3,4]\" in text and \"param.1: f32[3,2]\" in text\n        in_axis_resources = (None, PartitionSpec(\"data\", None))\n        ms_option = ManualShardingOption(self.mesh_axis_names,\n                                         in_axis_resources=in_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn, ms_option, a, b)\n        text = HloParser.get_param_line(text)\n        assert \"param: f32[6,4]\" in text and \"param.1: f32[3,4]\" in text\n\n    def test_set_output(self):\n\n        def fn(a):\n            return a**2, a + 1, a * 2, a / 2\n\n        a = jnp.ones((6, 4))\n        out_axis_resources = (PartitionSpec(\"data\", None), None,\n                              PartitionSpec(None, \"model\"),\n                              PartitionSpec(\"data\", \"model\"))\n        ms_option = ManualShardingOption(self.mesh_axis_names,\n                                         out_axis_resources=out_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn, ms_option, a)\n        text = HloParser.get_root_line(text)\n        assert (\"(f32[3,4]{1,0}, f32[6,4]{1,0}, f32[6,2]{1,0}, f32[3,2]{1,0}\"\n                in text)\n\n    def test_grad_acc(self):\n\n        def fn(params, batch):\n            x, tgt = batch\n\n            def loss_fn(params):\n                w1, b1, w2, b2 = params\n                y = jax.nn.relu(x @ w1 + b1)\n                z = jax.nn.softmax(y @ w2 + b2)\n                return jnp.mean((z - tgt)**2)\n\n     
       grads = alpa.grad(loss_fn)(params)\n            new_params = tree_map(lambda p, g: p - g, params, grads)\n            return new_params\n\n        batch_size = 64\n        x = jnp.ones((batch_size, 6))\n        tgt = jnp.ones((batch_size, 10))\n        params = (jnp.ones((6, 8)), jnp.ones((8,)), jnp.ones(\n            (8, 10)), jnp.ones((10,)))\n        batch = (x, tgt)\n        in_axis_resources = ((PartitionSpec(None,\n                                            \"model\"), PartitionSpec(\"model\"),\n                              PartitionSpec(\"model\",\n                                            None), PartitionSpec(None)),\n                             (PartitionSpec(\"data\",\n                                            None), PartitionSpec(\"data\", None)))\n        ms_option = ManualShardingOption(self.mesh_axis_names,\n                                         in_axis_resources=in_axis_resources)\n        text = self._get_fn_manual_sharding_with(fn,\n                                                 ms_option,\n                                                 params,\n                                                 batch,\n                                                 num_microbatches=2)\n        apply_grad_start = text.find(\"HloModule\", 1)\n        acc_grad_text = text[:apply_grad_start]\n        apply_grad_text = text[apply_grad_start:]\n        # 1. 
Accumulate grad:\n        acc_grad_params = HloParser.get_param_line(acc_grad_text)\n        acc_grad_param_shapes = HloParser.parse_param_shapes(acc_grad_params)\n        acc_grad_root = HloParser.get_root_line(acc_grad_text)\n        acc_grad_root_shapes = HloParser.parse_root_shapes(acc_grad_root)\n\n        param_shape = (\"f32[6,4]\", \"f32[4]\", \"f32[4,10]\", \"f32[10]\")\n        # batch_size / num_microbatches / data_parallel\n        batch_shape = (\"f32[16,6]\", \"f32[16,10]\")\n        assert acc_grad_param_shapes == param_shape + batch_shape + param_shape\n        assert acc_grad_root_shapes == param_shape\n        # 2. Apply grad:\n        apply_grad_params = HloParser.get_param_line(apply_grad_text)\n        apply_grad_param_shapes = HloParser.parse_param_shapes(\n            apply_grad_params)\n        apply_grad_root = HloParser.get_root_line(apply_grad_text)\n        apply_grad_root_shapes = HloParser.parse_root_shapes(apply_grad_root)\n        assert apply_grad_param_shapes == param_shape + param_shape\n        assert apply_grad_root_shapes == param_shape\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(ManualShardingTest(\"test_set_input\"))\n    suite.addTest(ManualShardingTest(\"test_set_output\"))\n    suite.addTest(ManualShardingTest(\"test_grad_acc\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_mixed_2d.py",
    "content": "\"\"\"Test auto sharding with mixed mesh shape.\"\"\"\n\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nfrom jax.interpreters.pxla import Chunked, NoSharding, Replicated, ShardedAxis\nimport optax\n\nfrom alpa import parallelize, LocalPhysicalDeviceMesh, ShardParallel, AutoShardingOption\nfrom alpa.util import map_to_shape, count_communication_primitives\n\n\nclass AutoShardingMixedTest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n\n    def get_device_mesh(self, shape, mesh_alpha, mesh_beta):\n        return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta)\n\n    def test_dot_all_to_all(self):\n        device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 0.1])\n\n        as_option = AutoShardingOption(allow_mixed_mesh_shape=True,\n                                       allow_all_gather=False)\n\n        use_bias = False\n        B = 256\n        E = 4\n        M = 16\n        M_ = M // E\n        H = M * 8\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                wi = self.param(\"wi\", jax.nn.initializers.zeros, (\n                    E,\n                    M_,\n                    H,\n                ))\n                wo = self.param(\"wo\", jax.nn.initializers.zeros, (\n                    E,\n                    H,\n                    M_,\n                ))\n\n                x = nn.Dense(features=M, use_bias=use_bias)(x)\n                x = nn.Dense(features=M, use_bias=use_bias)(x)\n                x = x.reshape((B, E, M_))\n\n                x = jnp.einsum(\"BEM,EMH->EBH\", x, wi)\n                x = jnp.einsum(\"EBH,EHM->BEM\", x, wo)\n\n                x = x.reshape((B, M))\n                x = nn.Dense(features=M, 
use_bias=use_bias)(x)\n                x = nn.Dense(features=M, use_bias=use_bias)(x)\n                return x\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=as_option))\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = jax.grad(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        x = jnp.ones((B, M))\n        y = jnp.ones((B, M))\n\n        # Init train state\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.sgd(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # JIT compile\n        executable = train_step.get_executable(state, {\"x\": x, \"y\": y})\n        hlo_ir = executable.get_hlo_text()\n\n        # Check sharding specs\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n            count_communication_primitives(hlo_ir))\n        assert n_all_to_all > 0\n        assert n_total == n_all_reduce + n_all_to_all\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(AutoShardingMixedTest(\"test_dot_all_to_all\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_mlp.py",
    "content": "\"\"\"Test auto sharding with MLP.\"\"\"\n\nimport unittest\nfrom itertools import chain\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nfrom jax.interpreters.pxla import Chunked, NoSharding, Replicated, ShardedAxis\nimport optax\n\nfrom alpa import (parallelize, LocalPhysicalDeviceMesh, AutoShardingOption,\n                  ShardParallel, Zero2Parallel, Zero3Parallel)\nfrom alpa.util import count_communication_primitives\n\n\ndef assert_close(x, y, atol=0.01):\n    assert abs((x + 1e-9) / (y + 1e-9) - 1) <= atol, f\"{x} vs. {y}\"\n\n\ndef assert_less_equal(x, y):\n    assert abs((x + 1e-9) / (y + 1e-9)) <= 1.01, f\"{x} vs. {y}\"\n\n\ndef assert_column_partitioned(x, num_chunks, mesh_dim):\n    assert x.sharding_spec.sharding == (NoSharding(), Chunked([num_chunks]))\n    assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),)\n\n\ndef assert_row_partitioned(x, num_chunks, mesh_dim):\n    assert x.sharding_spec.sharding == (Chunked([num_chunks]), NoSharding())\n    assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),)\n\n\ndef assert_expert_partitioned(x, num_chunks, mesh_dim):\n    assert x.sharding_spec.sharding == (Chunked([num_chunks]), NoSharding(),\n                                        NoSharding())\n    assert x.sharding_spec.mesh_mapping == (ShardedAxis(0),)\n\n\ndef assert_replicated_column_partitioned(x, mesh_shape):\n    assert x.sharding_spec.sharding == (NoSharding(), Chunked([mesh_shape[1]]))\n    assert x.sharding_spec.mesh_mapping[0] == Replicated(mesh_shape[0])\n    assert x.sharding_spec.mesh_mapping[1] == ShardedAxis(0)\n\n\ndef assert_replicated_row_partitioned(x, mesh_shape):\n    assert x.sharding_spec.sharding == (Chunked([mesh_shape[1]]), NoSharding())\n    assert x.sharding_spec.mesh_mapping[0] == Replicated(mesh_shape[0])\n    assert x.sharding_spec.mesh_mapping[1] == ShardedAxis(0)\n\n\ndef assert_all_replicated(x, 
num_replicas):\n    for axis_shard in x.sharding_spec.sharding:\n        assert axis_shard == NoSharding()\n    assert x.sharding_spec.mesh_mapping[0] == Replicated(num_replicas)\n\n\ndef is_sharded(x):\n    for axis in x.sharding_spec.mesh_mapping:\n        if isinstance(axis, ShardedAxis):\n            return True\n    return False\n\n\ndef assert_sharded(x):\n    assert is_sharded(x), f\"Not sharded: {str(x.sharding_spec)}\"\n\n\ndef is_fully_sharded(x):\n    for axis in x.sharding_spec.mesh_mapping:\n        if not isinstance(axis, ShardedAxis):\n            return False\n    return True\n\n\ndef assert_fully_sharded(x):\n    assert is_fully_sharded(x), f\"Not fully sharded: {str(x.sharding_spec)}\"\n\n\ndef assert_sharding_zero_stage_3(state, allow_not_sharded_params=0):\n    params = jax.tree_util.tree_leaves(state.params)\n    opt_state = jax.tree_util.tree_leaves(state.opt_state)\n\n    num_not_sharded = 0\n    for weight in chain(params, opt_state):\n        if not is_sharded(weight) and len(weight.shape) > 1:\n            num_not_sharded += 1\n    assert num_not_sharded <= allow_not_sharded_params\n\n\ndef assert_data_parallel_cost(state,\n                              hlo_ir,\n                              objective,\n                              device_mesh,\n                              as_option,\n                              mesh_dim,\n                              allow_not_sharded_params=0,\n                              optimizer_type=None):\n    params = jax.tree_util.tree_leaves(state.params)\n    opt_state = jax.tree_util.tree_leaves(state.opt_state)\n\n    # Check communication cost\n    replicated_penalty = int(\n        device_mesh.all_reduce_cost(1, 0) + device_mesh.all_reduce_cost(1, 1))\n    expected = sum(\n        device_mesh.all_reduce_cost(np.prod(x.shape) * 4, mesh_dim)\n        for x in params)\n    expected += replicated_penalty * (len(params) + len(opt_state))\n    assert_close(objective, expected)\n\n    # Check numbers of 
communication primitives\n    n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n        count_communication_primitives(hlo_ir, ignore_scalar_all_reduce=True))\n\n    # Special case 1 : adafactor\n    if optimizer_type == \"adafactor\" and as_option.prefer_reduce_scatter:\n        assert n_reduce_scatter == 1\n        assert n_all_gather <= 2\n        assert n_all_reduce <= 2\n        return\n\n    # Special case 2 : force zero stage 3\n    if as_option.force_zero_stage_3:\n        assert n_all_reduce == 0\n        assert n_all_gather == 2\n        assert n_reduce_scatter == 1\n        assert_sharding_zero_stage_3(state)\n        return\n\n    # Normal case\n    if as_option.prefer_reduce_scatter:\n        assert n_reduce_scatter == 1\n        assert n_all_gather == 1\n        if allow_not_sharded_params:\n            assert n_all_reduce == 1\n        else:\n            assert n_all_reduce == 0\n        assert n_total == n_reduce_scatter + n_all_gather + n_all_reduce\n    else:\n        assert n_all_reduce == 1\n        assert n_total == n_all_reduce\n\n    # Check sharding specification\n    if as_option.prefer_reduce_scatter:\n        num_not_sharded = 0\n        for weight in opt_state:\n            if not is_sharded(weight) and len(weight.shape) > 0:\n                num_not_sharded += 1\n        assert num_not_sharded <= allow_not_sharded_params * 2\n    else:\n        for weight in params:\n            assert_all_replicated(weight, np.prod(device_mesh.shape))\n\n\nclass AutoShardingMLPTest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        self.method = ShardParallel(auto_sharding_option=AutoShardingOption())\n        self.optimizer_type = \"adam\"\n\n    def get_device_mesh(self, shape, mesh_alpha, mesh_beta):\n        return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta)\n\n    def 
run_n_layer_mlp(self,\n                        num_layers,\n                        batch_size,\n                        input_dim,\n                        output_dim,\n                        hidden_dim,\n                        device_mesh,\n                        use_bias=True):\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                for i in range(num_layers - 1):\n                    x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                    x = nn.relu(x)\n                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)\n                return x\n\n        self.method.devices = device_mesh\n\n        @parallelize(method=self.method)\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = jax.grad(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        x = jnp.ones((batch_size, input_dim))\n        y = jnp.ones((batch_size, output_dim))\n\n        # Init train state\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        if self.optimizer_type == \"adam\":\n            tx = optax.adam(learning_rate=1e-2)\n        elif self.optimizer_type == \"adafactor\":\n            tx = optax.adafactor(learning_rate=1e-2, min_dim_size_to_factor=4)\n        else:\n            raise ValueError(f\"Invalid optimizer_type: {self.optimizer_type}\")\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # JIT compile\n        state = train_step(state, {\"x\": x, \"y\": y})\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def 
test_n_layer_mlp_data_parallel(self):\n        num_layers = 6\n        batch_size = 256\n        hidden_dim = 32\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_n_layer_mlp(\n                num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim,\n                device_mesh)\n\n            assert_data_parallel_cost(state,\n                                      hlo_ir,\n                                      objective,\n                                      device_mesh,\n                                      self.method.as_option,\n                                      i,\n                                      optimizer_type=self.optimizer_type)\n\n    def test_n_layer_mlp_model_parallel(self):\n        num_layers = 6\n        batch_size = 32\n        hidden_dim = 256\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_n_layer_mlp(\n                num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim,\n                device_mesh)\n\n            # Check communication cost\n            expected = (\n                (num_layers - 1) *\n                device_mesh.all_reduce_cost(batch_size * hidden_dim * 4, i))\n            assert_close(objective, expected)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n                count_communication_primitives(hlo_ir))\n            if self.method.as_option.prefer_reduce_scatter:\n                assert n_all_reduce + n_reduce_scatter == num_layers - 1\n                assert n_reduce_scatter == n_all_gather\n                assert n_total == n_all_reduce + n_reduce_scatter + n_all_gather\n            else:\n                assert 
n_all_reduce == num_layers - 1\n                assert n_total == n_all_reduce\n\n            # Check sharding specification\n            for k in range(num_layers):\n                weight = state.params[\"params\"][f\"Dense_{k}\"][\"kernel\"]\n                if k % 2 == 0:\n                    assert_column_partitioned(weight, mesh_shape[i], i)\n                else:\n                    assert_row_partitioned(weight, mesh_shape[i], i)\n\n    def test_n_layer_mlp_2d_mesh(self):\n        num_layers = 6\n        batch_size = 256\n        hidden_dim = 32\n\n        # Test on different device meshes\n        mesh_shape = [2, 2]\n        device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 0.1])\n        state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size,\n                                                        hidden_dim, hidden_dim,\n                                                        hidden_dim, device_mesh)\n\n        # Check communication cost\n        expected = (num_layers *\n                    (device_mesh.all_reduce_cost(\n                        hidden_dim * hidden_dim * 4 / mesh_shape[1], 0) +\n                     device_mesh.all_reduce_cost(hidden_dim * 4, 0)) +\n                    (num_layers - 1) * device_mesh.all_reduce_cost(\n                        batch_size * hidden_dim * 4 / mesh_shape[0], 1))\n        assert_close(objective, expected)\n\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, _ = (\n            count_communication_primitives(hlo_ir))\n        if self.method.as_option.prefer_reduce_scatter:\n            assert n_all_reduce == num_layers - 1\n            # two reduce-scatter for two tensor dimensions\n            assert n_reduce_scatter == 2\n            # two for two tensor dimensions, although we can merge them\n            assert n_all_gather <= 2\n            assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter\n        else:\n            assert n_all_reduce == num_layers\n    
        assert n_total == n_all_reduce\n\n        # Check sharding specification\n        if self.method.as_option.prefer_reduce_scatter:\n            for weight in jax.tree_util.tree_leaves(state.opt_state):\n                if len(weight.shape) > 1:\n                    assert_fully_sharded(weight)\n        else:\n            for k in range(num_layers):\n                weight = state.params[\"params\"][f\"Dense_{k}\"][\"kernel\"]\n                if k % 2 == 0:\n                    assert_replicated_column_partitioned(weight, mesh_shape)\n                else:\n                    assert_replicated_row_partitioned(weight, mesh_shape)\n\n    def test_n_layer_mlp_force_data_parallel(self):\n        num_layers = 6\n        batch_size = 32\n        hidden_dim = 256\n\n        # Test on different device meshes\n        for i, mesh_shape in enumerate([(4, 1), (2, 2)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            self.method.as_option.force_data_parallel = True\n            state, hlo_ir, objective = self.run_n_layer_mlp(\n                num_layers, batch_size, hidden_dim, hidden_dim, hidden_dim,\n                device_mesh)\n\n            assert_data_parallel_cost(state, hlo_ir, objective,\n                                      device_mesh.flatten(),\n                                      self.method.as_option, 0)\n\n    def test_n_layer_mlp_force_batch_dim_mapping(self):\n        num_layers = 6\n        batch_size = 32\n        hidden_dim = 256\n        self.method.as_option.force_batch_dim_to_mesh_dim = 0\n\n        # Data parallel\n        device_mesh = self.get_device_mesh([4, 1], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size,\n                                                        hidden_dim, hidden_dim,\n                                                        hidden_dim, device_mesh)\n        assert_data_parallel_cost(state, hlo_ir, objective, device_mesh,\n         
                         self.method.as_option, 0)\n\n        # Model parallel\n        device_mesh = self.get_device_mesh([1, 4], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_n_layer_mlp(num_layers, batch_size,\n                                                        hidden_dim, hidden_dim,\n                                                        hidden_dim, device_mesh)\n        expected = ((num_layers - 1) *\n                    device_mesh.all_reduce_cost(batch_size * hidden_dim * 4, 1))\n        assert_close(objective, expected)\n\n    def test_n_layer_mlp_data_parallel_reduce_scatter(self):\n        self.method = Zero2Parallel()\n        self.test_n_layer_mlp_data_parallel()\n\n    def test_n_layer_mlp_model_parallel_reduce_scatter(self):\n        self.method.as_option.prefer_reduce_scatter = True\n        self.test_n_layer_mlp_model_parallel()\n\n    def test_n_layer_mlp_2d_mesh_reduce_scatter(self):\n        self.method.as_option.prefer_reduce_scatter = True\n        self.test_n_layer_mlp_2d_mesh()\n\n    def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self):\n        self.method.as_option.prefer_reduce_scatter = True\n        self.optimizer_type = \"adafactor\"\n        self.test_n_layer_mlp_data_parallel()\n\n    def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self):\n        self.method = Zero3Parallel()\n        self.method.as_option.force_zero_stage_3_all_gather_threshold = (\n            (32 * 32 + 32) * 6 * 4)\n        self.test_n_layer_mlp_data_parallel()\n\n    def test_weight_init(self):\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x, deterministic):\n                x = nn.Dense(16)(x)\n                x = nn.Dense(16)(x)\n                return x\n\n        x = jnp.ones((64, 16))\n        y = jnp.ones((64, 16))\n\n        # Init model and optimizer\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n\n        
@parallelize(method=ShardParallel(devices=self.physical_mesh))\n        def init_weight(rngkey):\n            params = model.init(rngkey, x, True)\n            tx = optax.adam(learning_rate=1e-2)\n            state = TrainState.create(apply_fn=model.apply,\n                                      params=params,\n                                      tx=tx)\n            return state\n\n        state = init_weight(rngkey)\n\n        # Check sharding specification\n        assert_all_replicated(state.step, self.physical_mesh.num_devices)\n        assert_sharded(state.params[\"params\"][\"Dense_0\"][\"kernel\"])\n        assert_sharded(state.params[\"params\"][\"Dense_1\"][\"kernel\"])\n        assert_sharded(state.opt_state[0].mu[\"params\"][\"Dense_0\"][\"kernel\"])\n        assert_sharded(state.opt_state[0].nu[\"params\"][\"Dense_1\"][\"kernel\"])\n\n\ndef suite():\n    suite = unittest.TestSuite()\n\n    def add(name):\n        suite.addTest(AutoShardingMLPTest(name))\n\n    add(\"test_n_layer_mlp_data_parallel\")\n    add(\"test_n_layer_mlp_model_parallel\")\n    add(\"test_n_layer_mlp_2d_mesh\")\n    add(\"test_n_layer_mlp_force_data_parallel\")\n    add(\"test_n_layer_mlp_force_batch_dim_mapping\")\n\n    add(\"test_n_layer_mlp_data_parallel_reduce_scatter\")\n    add(\"test_n_layer_mlp_model_parallel_reduce_scatter\")\n    add(\"test_n_layer_mlp_2d_mesh_reduce_scatter\")\n\n    add(\"test_n_layer_mlp_data_parallel_reduce_scatter_adafactor\")\n\n    add(\"test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3\")\n\n    add(\"test_weight_init\")\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_moe.py",
    "content": "\"\"\"Test auto sharding with MoE.\"\"\"\n\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\nfrom alpa import parallelize, ShardParallel, LocalPhysicalDeviceMesh, AutoShardingOption\nfrom alpa.util import count_communication_primitives\nfrom alpa.model.moe import FlaxMoELayer, FlaxMoEForLMModule, MoEConfig, TrainState\n\nfrom tests.shard_parallel.test_mlp import (assert_all_replicated, assert_close,\n                                           assert_expert_partitioned,\n                                           assert_sharding_zero_stage_3)\n\n\nclass AutoShardingMoETest(unittest.TestCase):\n\n    def setUp(self):\n        assert len(jax.local_devices()) >= 4\n        self.physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        self.as_option = AutoShardingOption()\n\n    def get_device_mesh(self, shape, mesh_alpha, mesh_beta):\n        return self.physical_mesh.get_logical_mesh(shape, mesh_alpha, mesh_beta)\n\n    def run_moe_layer(self, batch_size, seq_len, hidden_size, num_heads, S, E,\n                      deterministic, device_mesh):\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=self.as_option))\n        def train_step(state, batch, deterministic):\n\n            def loss_func(params):\n                rngs = {\"dropout\": batch[\"rng\"]}\n                out = state.apply_fn(params,\n                                     batch[\"hidden_states\"],\n                                     batch[\"attention_mask\"],\n                                     deterministic,\n                                     rngs=rngs)[0]\n                return jnp.mean((out - batch[\"labels\"])**2)\n\n            grads = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        dtype = jnp.float32\n        hidden_states = jnp.ones((batch_size, seq_len, hidden_size),\n                
                 dtype=dtype)\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        labels = jnp.ones((batch_size, seq_len, hidden_size), dtype=dtype)\n\n        # Init model and optimizer\n        model = FlaxMoELayer(MoEConfig(\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 4,\n            num_attention_heads=num_heads,\n            expert_group_size=S,\n            expert_number=E,\n        ),\n                             dtype=dtype)\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, hidden_states, attention_mask)\n        tx = optax.adam(1e-2)\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params,\n                                  tx=tx,\n                                  dynamic_scale=None)\n\n        # JIT compile\n        state = train_step(\n            state, {\n                \"hidden_states\": hidden_states,\n                \"attention_mask\": attention_mask,\n                \"labels\": labels,\n                \"rng\": rngkey\n            }, deterministic)\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def run_moe_lm(self, batch_size, seq_len, num_layers, hidden_size,\n                   num_heads, vocab_size, S, E, deterministic, device_mesh):\n\n        @parallelize(method=ShardParallel(devices=device_mesh,\n                                          auto_sharding_option=self.as_option))\n        def train_step(state, batch, deterministic, rng_key):\n\n            def loss_func(params):\n                rngs = {\"dropout\": rng_key}\n                logits = state.apply_fn(params,\n                                        batch[\"input_ids\"],\n                                        batch[\"attention_mask\"],\n                                        
batch[\"token_type_ids\"],\n                                        batch[\"position_ids\"],\n                                        deterministic=deterministic,\n                                        rngs=rngs)[0]\n                label_mask = jnp.where(batch[\"labels\"] > 0, 1.0, 0.0)\n                labels = jax.nn.one_hot(batch[\"labels\"], logits.shape[-1])\n                loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),\n                                axis=-1)\n                loss = (label_mask * loss).sum() / label_mask.sum()\n                return loss\n\n            grads = jax.grad(loss_func)(state.params)\n            return state.apply_gradients(grads=grads)\n\n        # Init model and optimizer\n        input_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        token_type_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        position_ids = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        labels = jnp.ones((batch_size, seq_len), dtype=jnp.int32)\n        dtype = jnp.float32\n\n        model = FlaxMoEForLMModule(MoEConfig(\n            num_hidden_layers=num_layers,\n            hidden_size=hidden_size,\n            intermediate_size=hidden_size * 4,\n            num_attention_heads=num_heads,\n            max_position_embeddings=seq_len,\n            vocab_size=vocab_size,\n            expert_group_size=S,\n            expert_number=E,\n        ),\n                                   dtype=dtype)\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, input_ids, attention_mask, token_type_ids,\n                            position_ids)\n\n        def weight_decay_mask(pytree):\n            # do not use weight decay on layer norm and bias.\n            return jax.tree_map(lambda x: x.ndim > 1, pytree)\n\n        tx = optax.adafactor(\n            learning_rate=1e-2,\n            
weight_decay_mask=weight_decay_mask,\n            min_dim_size_to_factor=4,\n        )\n        state = TrainState.create(apply_fn=model.apply,\n                                  params=params,\n                                  tx=tx,\n                                  dynamic_scale=None,\n                                  use_master_copy=(dtype == jnp.float16))\n\n        # JIT compile\n        state = train_step(\n            state, {\n                \"input_ids\": input_ids,\n                \"attention_mask\": attention_mask,\n                \"token_type_ids\": token_type_ids,\n                \"position_ids\": position_ids,\n                \"labels\": labels,\n            }, deterministic, rngkey)\n\n        # Get optimized HLO IR\n        executable = train_step.get_last_executable()\n        return (state, executable.get_hlo_text(),\n                executable.auto_sharding_objective)\n\n    def test_moe_layer(self):\n        batch_size = 64\n        seq_len = 16\n        hidden_size = 64\n        num_heads = 16\n        S = 32\n        E = 16\n        deterministic = True\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_moe_layer(\n                batch_size, seq_len, hidden_size, num_heads, S, E,\n                deterministic, device_mesh)\n\n            # Check communication cost\n            # all-to-all + data-parallel on attention_w_i, attention_w_o, layer_norm, moe_w_g\n            expected = (\n                device_mesh.all_to_all_cost(\n                    batch_size * seq_len * hidden_size * 2 * 4, i) * 4 +\n                device_mesh.all_reduce_cost(hidden_size * hidden_size * 3 * 4,\n                                            i) +\n                device_mesh.all_reduce_cost(hidden_size * 3 * 4, i) +\n                
device_mesh.all_reduce_cost(hidden_size * hidden_size * 4, i) +\n                device_mesh.all_reduce_cost(hidden_size * 4, i) +\n                device_mesh.all_reduce_cost(hidden_size * 4, i) * 4 +\n                device_mesh.all_reduce_cost(hidden_size * E * 4, i))\n            assert_close(expected, objective)\n\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n                count_communication_primitives(hlo_ir))\n            assert n_all_reduce == 1\n            assert n_all_to_all == 4\n            assert n_total == n_all_reduce + n_all_to_all\n\n            # Check sharding specification\n            num_devices = np.prod(device_mesh.shape)\n            assert_all_replicated(\n                state.params[\"params\"][\"attention\"][\"output\"][\"dense\"]\n                [\"kernel\"], num_devices)\n            assert_all_replicated(\n                state.params[\"params\"][\"attention\"][\"self\"][\"qvk_combined\"]\n                [\"kernel\"], num_devices)\n            assert_all_replicated(state.params[\"params\"][\"moe\"][\"wg\"],\n                                  num_devices)\n            assert_expert_partitioned(state.params[\"params\"][\"moe\"][\"wi\"],\n                                      num_devices, i)\n            assert_expert_partitioned(state.params[\"params\"][\"moe\"][\"wo\"],\n                                      num_devices, i)\n\n    def test_moe_layer_2d(self):\n        batch_size = 64\n        seq_len = 16\n        hidden_size = 64\n        num_heads = 16\n        S = 32\n        E = 16\n        deterministic = True\n        self.as_option.allow_mixed_mesh_shape = True\n        self.as_option.allow_all_gather = False\n\n        # Test on different logical mesh shapes\n        device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_moe_layer(batch_size, seq_len,\n                                                      hidden_size, num_heads, S,\n 
                                                     E, deterministic,\n                                                      device_mesh)\n\n        # Check communication cost\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n            count_communication_primitives(hlo_ir))\n        assert n_all_reduce == 2  # one data-parallel for experts weights,\n        # one data-parallel for normal weights\n        assert n_all_to_all > 0\n        assert n_total == n_all_reduce + n_all_to_all\n\n    def test_moe_layer_2d_reduce_scatter(self):\n        batch_size = 64\n        seq_len = 16\n        hidden_size = 64\n        num_heads = 16\n        S = 32\n        E = 16\n        deterministic = True\n        self.as_option.allow_mixed_mesh_shape = True\n        self.as_option.allow_all_gather = False\n        self.as_option.prefer_reduce_scatter = True\n\n        # Test on different logical mesh shapes\n        device_mesh = self.get_device_mesh([2, 2], [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_moe_layer(batch_size, seq_len,\n                                                      hidden_size, num_heads, S,\n                                                      E, deterministic,\n                                                      device_mesh)\n\n        # Check communication cost\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n            count_communication_primitives(hlo_ir))\n        assert n_all_to_all > 0\n        assert n_reduce_scatter > 0\n        assert n_all_reduce == 0\n        assert n_total == n_all_reduce + n_reduce_scatter + n_all_to_all + n_all_gather\n\n    def test_moe_lm(self):\n        num_layers = 2\n        batch_size = 64\n        seq_len = 16\n        hidden_size = 64\n        num_heads = 16\n        vocab_size = 32\n        S = 32\n        E = 16\n        deterministic = True\n\n        # Test on different logical mesh shapes\n        for i, mesh_shape in 
enumerate([(4, 1), (1, 4)]):\n            device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n            state, hlo_ir, objective = self.run_moe_lm(batch_size, seq_len,\n                                                       num_layers, hidden_size,\n                                                       num_heads, vocab_size, S,\n                                                       E, deterministic,\n                                                       device_mesh)\n\n            # Check communication cost\n            # all-to-all + data-parallel on attention_w_i, attention_w_o, layer_norm, moe_w_g\n            n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n                count_communication_primitives(hlo_ir,\n                                               ignore_scalar_all_reduce=True))\n\n            # Special case: zero stage 3\n            if self.as_option.force_zero_stage_3:\n                assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all\n                assert_sharding_zero_stage_3(state, 4)\n                continue\n\n            # Normal cases\n            if self.as_option.prefer_reduce_scatter:\n                if self.as_option.force_data_parallel:\n                    assert 0 < n_reduce_scatter <= 2\n                    assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter\n                else:\n                    assert n_reduce_scatter == 1\n                    assert n_all_to_all == 4\n                    assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all\n            else:\n                if self.as_option.force_data_parallel:\n                    assert n_all_reduce == 1\n                    assert n_total == n_all_reduce\n                else:\n                    assert n_all_reduce <= 4\n                    assert n_all_to_all == 4\n                    assert n_total == n_all_reduce + n_all_to_all\n\n    def 
test_moe_lm_2d(self):\n        num_layers = 2\n        batch_size = 64\n        seq_len = 16\n        hidden_size = 64\n        num_heads = 16\n        vocab_size = 32\n        S = 32\n        E = 16\n        deterministic = True\n        self.as_option.allow_mixed_mesh_shape = True\n\n        mesh_shape = (2, 2)\n        device_mesh = self.get_device_mesh(mesh_shape, [1, 1], [1, 1])\n        state, hlo_ir, objective = self.run_moe_lm(batch_size, seq_len,\n                                                   num_layers, hidden_size,\n                                                   num_heads, vocab_size, S, E,\n                                                   deterministic, device_mesh)\n\n        # Check communication cost\n        n_total, n_all_reduce, n_all_gather, n_reduce_scatter, n_all_to_all = (\n            count_communication_primitives(hlo_ir))\n        if self.as_option.prefer_reduce_scatter:\n            assert n_reduce_scatter > 0\n            assert n_total == n_all_reduce + n_all_gather + n_reduce_scatter + n_all_to_all\n        else:\n            assert n_all_to_all == 4\n            assert n_total == n_all_reduce + n_all_to_all\n\n    def test_moe_lm_data_parallel(self):\n        self.as_option.force_data_parallel = True\n        self.test_moe_lm()\n\n    def test_moe_lm_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_moe_lm()\n\n    def test_moe_lm_2d_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.test_moe_lm_2d()\n\n    def test_moe_lm_data_parallel_reduce_scatter(self):\n        self.as_option.prefer_reduce_scatter = True\n        self.as_option.force_data_parallel = True\n        self.test_moe_lm()\n\n    def test_moe_lm_data_parallel_reduce_scatter_zero_3(self):\n        self.as_option.force_zero_stage_3 = True\n        self.as_option.force_zero_stage_3_all_gather_threshold = 1\n        self.test_moe_lm()\n\n\ndef suite():\n    suite = 
unittest.TestSuite()\n\n    def add(name):\n        suite.addTest(AutoShardingMoETest(name))\n\n    add(\"test_moe_layer\")\n    add(\"test_moe_layer_2d\")\n    add(\"test_moe_layer_2d_reduce_scatter\")\n\n    add(\"test_moe_lm\")\n    add(\"test_moe_lm_2d\")\n    add(\"test_moe_lm_data_parallel\")\n\n    add(\"test_moe_lm_reduce_scatter\")\n    add(\"test_moe_lm_2d_reduce_scatter\")\n    add(\"test_moe_lm_data_parallel_reduce_scatter\")\n    add(\"test_moe_lm_data_parallel_reduce_scatter_zero_3\")\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/shard_parallel/test_numerical_correctness.py",
    "content": "\"\"\"Test the numerical correctness of shard parallel.\"\"\"\nimport unittest\n\nfrom flax import linen as nn\nimport jax\nimport jax.numpy as jnp\nimport optax\nimport ray\n\nimport alpa\nfrom alpa import parallelize, LocalPhysicalDeviceMesh\nfrom alpa.model.bert_model import BertConfig, FlaxBertLayer, TrainState\nfrom alpa.testing import (assert_allclose, create_train_state,\n                          get_bert_layer_train_state_and_step)\n\n\nclass AutoShardingCorrectnessTest(unittest.TestCase):\n\n    def test_2_layer_bert_shard_parallel(self):\n        physical_mesh = LocalPhysicalDeviceMesh(jax.local_devices()[:4])\n        logical_mesh = physical_mesh.get_logical_mesh([2, 2])\n\n        # Init model\n        state, batch, train_step = get_bert_layer_train_state_and_step(\n            batch_size=16,\n            seq_len=8,\n            num_layers=2,\n            hidden_size=256,\n            num_heads=8,\n            clip_by_global_norm=False,\n            use_dynamic_scale=False,\n            add_manual_pipeline_marker=False)\n\n        # Train one step\n        p_train_step = parallelize(train_step)\n        expected_state, expected_grads = train_step(state, batch)\n        actual_state, actual_grads = p_train_step(state, batch)\n\n        #print(expected_state)\n        #print(actual_state)\n\n        # print(\"group 1:\")\n        # print(\"expected param example: \", jax.tree_util.tree_flatten(expected_params.params)[0][0][0:10])\n        # print(\"actual param example: \", jax.tree_util.tree_flatten(actual_params.params)[0][0]._value[0:10])\n        # print(\"expected grad example: \", jax.tree_util.tree_flatten(expected_grads)[0][0][0:10])\n        # print(\"actual grad example: \", jax.tree_util.tree_flatten(actual_grads)[0][0]._value[0:10])\n\n        # print(\"group 2:\")\n        # print(\"expected param example: \", jax.tree_util.tree_flatten(expected_params.params)[0][-1][0:100])\n        # print(\"actual param example: \", 
jax.tree_util.tree_flatten(actual_params.params)[0][-1]._value[0:100])\n        # print(\"expected grad example: \", jax.tree_util.tree_flatten(expected_grads)[0][-1][0:100])\n        # print(\"actual grad example: \", jax.tree_util.tree_flatten(actual_grads)[0][-1]._value[0:100])\n\n        assert_allclose(expected_state, actual_state, rtol=5e-4, atol=5e-4)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(\n        AutoShardingCorrectnessTest(\"test_2_layer_bert_shard_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/torch_frontend/test_dict_input.py",
    "content": "import unittest\n\nimport torch\nimport alpa.torch.optim as torchoptim\nimport alpa\nfrom alpa.torch.trainer import train_torch_module\n\n\nclass MyModule(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(16, 16)\n        self.linear2 = torch.nn.Linear(16, 16)\n        self.linear3 = torch.nn.Linear(16, 16)\n        self.linear4 = torch.nn.Linear(16, 16)\n\n    def forward(self, input_dict):\n        x = input_dict[\"x\"]\n        y = input_dict[\"dict2\"][\"y\"]\n        x = self.linear1(x) + y\n        # do some debugging when in local mode\n        if getattr(torch, \"local_mode\", True):\n            print(x)\n        x = self.linear2(x)\n        x = self.linear3(x)\n        x = self.linear4(x)\n        return x\n\n\ndef weight_init_func(pt_module, name_map, params, bufs):\n    for k, m in pt_module.named_modules():\n        if isinstance(m, torch.nn.Linear):\n            params[name_map[f\"{k}.weight\"]] = torch.nn.init.xavier_uniform(\n                params[name_map[f\"{k}.weight\"]])\n            params[name_map[f\"{k}.bias\"]] = torch.nn.init.normal(\n                params[name_map[f\"{k}.bias\"]], std=1e-6)\n    return params, bufs\n\n\nclass TorchDictInputTest(unittest.TestCase):\n\n    def setUp(self):\n        torch.manual_seed(123)\n        alpa.set_seed(123)\n\n    def test_dict_input(self):\n        pt_module_gen = lambda: MyModule()\n\n        dataloader = [\n            ({\n                \"x\": torch.randn(8, 16),\n                \"dict2\": {\n                    \"y\": torch.randn(8, 16)\n                }\n            }, torch.randn(8, 16)),\n            ({\n                \"x\": torch.randn(8, 16),\n                \"dict2\": {\n                    \"y\": torch.randn(8, 16)\n                }\n            }, torch.randn(8, 16)),\n        ]\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen 
= torchoptim.adam(lr=1e-3)\n        parallel_method = alpa.ShardParallel()\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(TorchDictInputTest(\"test_dict_input\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/torch_frontend/test_reshape.py",
    "content": "import unittest\n\nimport torch\nimport alpa.torch.optim as torchoptim\nimport alpa\nfrom alpa.torch.trainer import train_torch_module\n\n\nclass MyModule(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(16, 16)\n        self.linear2 = torch.nn.Linear(16, 16)\n\n    def forward(self, x):\n        x = self.linear1(x)\n        x = self.linear2(x)\n        x = x.reshape(x.shape[0], 2, -1)\n        x = x.reshape(x.shape[0], -1, 2)\n        x = x.reshape(x.shape[0], 16)\n        return x\n\n\ndef weight_init_func(pt_module, name_map, params, bufs):\n    # for k, m in pt_module.named_modules():\n    #     if isinstance(m, torch.nn.Linear):\n    #         params[name_map[f\"{k}.weight\"]] = torch.nn.init.xavier_uniform(params[name_map[f\"{k}.weight\"]])\n    #         params[name_map[f\"{k}.bias\"]] = torch.nn.init.normal(params[name_map[f\"{k}.bias\"]], std=1e-6)\n    return params, bufs\n\n\nclass TorchReshapeTest(unittest.TestCase):\n\n    def setUp(self):\n        torch.manual_seed(123)\n        alpa.set_seed(123)\n\n    def test_reshape(self):\n        B = 64\n\n        pt_module_gen = lambda: MyModule()\n\n        dataloader = [\n            (torch.randn(B, 16), torch.randn(B, 16)),\n            (torch.randn(B, 16), torch.randn(B, 16)),\n        ]\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen = torchoptim.adam(lr=1e-3)\n        parallel_method = alpa.ShardParallel()\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(TorchReshapeTest(\"test_reshape\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/torch_frontend/test_simple.py",
    "content": "import unittest\n\nimport torch\nimport alpa.torch.optim as torchoptim\nimport alpa\nfrom alpa.torch.trainer import train_torch_module\n\n\nclass MyModule(torch.nn.Module):\n\n    def __init__(self):\n        super().__init__()\n        self.linear1 = torch.nn.Linear(16, 16)\n        self.linear2 = torch.nn.Linear(16, 16)\n        self.linear3 = torch.nn.Linear(16, 16)\n        self.linear4 = torch.nn.Linear(16, 16)\n\n    def forward(self, x):\n        x = self.linear1(x)\n        # do some debugging when in local mode\n        if getattr(torch, \"local_mode\", True):\n            print(x)\n        x = self.linear2(x)\n        x = self.linear3(x)\n        x = self.linear4(x)\n        return x\n\n\ndef weight_init_func(pt_module, name_map, params, bufs):\n    for k, m in pt_module.named_modules():\n        if isinstance(m, torch.nn.Linear):\n            params[name_map[f\"{k}.weight\"]] = torch.nn.init.xavier_uniform(\n                params[name_map[f\"{k}.weight\"]])\n            params[name_map[f\"{k}.bias\"]] = torch.nn.init.normal(\n                params[name_map[f\"{k}.bias\"]], std=1e-6)\n    return params, bufs\n\n\nclass TorchSimpleTest(unittest.TestCase):\n\n    def setUp(self):\n        torch.manual_seed(123)\n        alpa.set_seed(123)\n\n    def test_simple_shard(self):\n        pt_module_gen = lambda: MyModule()\n\n        dataloader = [\n            (torch.randn(128, 16), torch.randn(128, 16)),\n            (torch.randn(128, 16), torch.randn(128, 16)),\n        ]\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen = torchoptim.adam(lr=1e-3)\n        parallel_method = alpa.ShardParallel()\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n    def test_simple_pipeshard(self):\n        pt_module_gen = lambda: MyModule()\n\n        dataloader = [\n            
(torch.randn(128, 16), torch.randn(128, 16)),\n            (torch.randn(128, 16), torch.randn(128, 16)),\n        ]\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen = torchoptim.adam(lr=1e-3)\n        num_micro_batches = 2\n        parallel_method = alpa.PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            layer_option=alpa.AutoLayerOption(layer_num=2),\n            stage_option=\"auto\")\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(TorchSimpleTest(\"test_simple_shard\"))\n    suite.addTest(TorchSimpleTest(\"test_simple_pipeshard\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/torch_frontend/test_zhen.py",
    "content": "import unittest\nfrom enum import Enum\nfrom typing import List, Optional, Tuple, Union, Callable\n\nimport torch\nimport torch.nn as nn\nfrom torch import Tensor, embedding\nimport alpa.torch.optim as torchoptim\nfrom alpa.torch.trainer import train_torch_module\nimport alpa\n\n\n# Copied from timm\n# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py\nclass Attention(nn.Module):\n\n    def __init__(self,\n                 dim,\n                 num_heads=8,\n                 qkv_bias=False,\n                 attn_drop=0.,\n                 proj_drop=0.):\n        super().__init__()\n        assert dim % num_heads == 0, 'dim should be divisible by num_heads'\n        self.num_heads = num_heads\n        head_dim = dim // num_heads\n        self.scale = head_dim**-0.5\n\n        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n        self.attn_drop = nn.Dropout(attn_drop)\n        self.proj = nn.Linear(dim, dim)\n        self.proj_drop = nn.Dropout(proj_drop)\n\n    def forward(self, x):\n        B, N, C = x.shape\n        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,\n                                  C // self.num_heads).permute(2, 0, 3, 1, 4)\n        q, k, v = qkv.unbind(\n            0)  # make torchscript happy (cannot use tensor as tuple)\n\n        attn = (q @ k.transpose(-2, -1)) * self.scale\n        attn = attn.softmax(dim=-1)\n        attn = self.attn_drop(attn)\n\n        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n        x = self.proj(x)\n        x = self.proj_drop(x)\n        return x\n\n\ndef _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]:\n    if activation == \"relu\":\n        return torch.nn.functional.relu\n    elif activation == \"gelu\":\n        return torch.nn.functional.gelu\n\n    raise RuntimeError(\n        \"activation should be relu/gelu, not {}\".format(activation))\n\n\n# Adapted from torch/nn/modules/transformer.py\nclass 
TransformerEncoderLayer(nn.Module):\n    r\"\"\"TransformerEncoderLayer is made up of self-attn and feedforward network.\n    This standard encoder layer is based on the paper \"Attention Is All You Need\".\n    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,\n    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in\n    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement\n    in a different way during application.\n\n    Args:\n        d_model: the number of expected features in the input (required).\n        nhead: the number of heads in the multiheadattention models (required).\n        dim_feedforward: the dimension of the feedforward network model (default=2048).\n        dropout: the dropout value (default=0.1).\n        activation: the activation function of the intermediate layer, can be a string\n            (\"relu\" or \"gelu\") or a unary callable. Default: relu\n        layer_norm_eps: the eps value in layer normalization components (default=1e-5).\n        batch_first: If ``True``, then the input and output tensors are provided\n            as (batch, seq, feature). Default: ``False`` (seq, batch, feature).\n        norm_first: if ``True``, layer norm is done prior to attention and feedforward\n            operations, respectively. Otherwise it's done after. 
Default: ``False`` (after).\n\n    Examples::\n        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)\n        >>> src = torch.rand(10, 32, 512)\n        >>> out = encoder_layer(src)\n\n    Alternatively, when ``batch_first`` is ``True``:\n        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)\n        >>> src = torch.rand(32, 10, 512)\n        >>> out = encoder_layer(src)\n    \"\"\"\n    __constants__ = ['batch_first', 'norm_first']\n\n    def __init__(self,\n                 d_model: int,\n                 nhead: int,\n                 dim_feedforward: int = 2048,\n                 dropout: float = 0.1,\n                 activation: Union[str, Callable[[Tensor], Tensor]] = \"relu\",\n                 layer_norm_eps: float = 1e-5,\n                 batch_first: bool = False,\n                 norm_first: bool = False,\n                 device=None,\n                 dtype=None) -> None:\n        factory_kwargs = {'device': device, 'dtype': dtype}\n        super(TransformerEncoderLayer, self).__init__()\n        self.self_attn = Attention(d_model, num_heads=nhead, attn_drop=dropout)\n        # Implementation of Feedforward model\n        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)\n        self.dropout = nn.Dropout(dropout)\n        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)\n\n        self.norm_first = norm_first\n        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)\n        self.dropout1 = nn.Dropout(dropout)\n        self.dropout2 = nn.Dropout(dropout)\n\n        # Legacy string support for activation function.\n        if isinstance(activation, str):\n            self.activation = _get_activation_fn(activation)\n        else:\n            self.activation = activation\n\n    def __setstate__(self, state):\n        if 'activation' not in 
state:\n            state['activation'] = torch.nn.functional.relu\n        super(TransformerEncoderLayer, self).__setstate__(state)\n\n    def forward(self,\n                src: Tensor,\n                src_mask: Optional[Tensor] = None,\n                src_key_padding_mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Pass the input through the encoder layer.\n\n        Args:\n            src: the sequence to the encoder layer (required).\n            src_mask: the mask for the src sequence (optional).\n            src_key_padding_mask: the mask for the src keys per batch (optional).\n\n        Shape:\n            see the docs in Transformer class.\n        \"\"\"\n\n        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf\n\n        x = src\n        if self.norm_first:\n            x = x + self._sa_block(self.norm1(x), src_mask,\n                                   src_key_padding_mask)\n            x = x + self._ff_block(self.norm2(x))\n        else:\n            x = self.norm1(x +\n                           self._sa_block(x, src_mask, src_key_padding_mask))\n            x = self.norm2(x + self._ff_block(x))\n\n        return x\n\n    # self-attention block\n    def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor],\n                  key_padding_mask: Optional[Tensor]) -> Tensor:\n        # x = self.self_attn(x, x, x,\n        #                    attn_mask=attn_mask,\n        #                    key_padding_mask=key_padding_mask,\n        #                    need_weights=False)[0]\n        # TODO: add support for `attn_mask` / `key_padding_mask` if needed.\n        x = self.self_attn(x)\n        return self.dropout1(x)\n\n    # feed forward block\n    def _ff_block(self, x: Tensor) -> Tensor:\n        x = self.linear2(self.dropout(self.activation(self.linear1(x))))\n        return self.dropout2(x)\n\n\nclass TokenMixer(Enum):\n    DOT = 1\n    LINEAR = 2\n    ATTENTION = 3\n    CONVOLUTION = 4\n\n\n# util for generating a weight and a 
bias based on a size, and initializing them\ndef construct_w_b_pair(\n        shape: List[int],\n        uniform_const: float) -> Tuple[nn.Parameter, nn.Parameter]:\n    assert len(shape) == 2\n    w = nn.Parameter(\n        torch.empty(shape).uniform_(-1 * uniform_const, uniform_const))\n    b = nn.Parameter(\n        torch.empty([shape[0]]).uniform_(-1 * uniform_const,\n                                         uniform_const))  # UniformFill\n\n    return w, b\n\n\n# The implementation of ZHEN layer is based on the paper: https://arxiv.org/pdf/2203.11014.pdf\n#\n# This is a single ZHEN layer. It:\n# - receives an input from the previous layer, or the embedding (first layer)\n# - receives the skip connection, which is the input to the previous layer (or nothing, in the case of first ZHEN layer)\n# - adds input and skip connection together, and treats it as the new input\n# and runs the new input through the different modules in token_mixer_list one by one, and concatenates them together as the ensemble.\n# It outputs the ensemble result and the new input\n# see https://bit.ly/3wNuqfz for a visualization.\nclass ZHENLayer(nn.Module):\n\n    def __init__(\n        self,\n        layer_index: int,\n        emb_dim: int,\n        token_mixer_list: List[\n            TokenMixer],  # determines this layer's output features\n        previous_n_embs:\n        int = 369,  # previous layer's output dim, may not be inferrable if token_mixer is different per layer. If 0th layer, this is original_n_embs.\n        previous_input_embs:\n        int = 369,  # skip connection's num embs. This is previous layer's input num embs.\n        output_embs_per_mixer: int = 50,  # each module outputs 50 embeddings\n        original_n_embs:\n        int = 369,  # whatever overarch gives us for the 0th zhen layer . 
the rest, is whatever output previous layer is\n    ):\n        super().__init__()\n        self.layer_index = layer_index\n        self.emb_dim = emb_dim\n        self.token_mixer_list = token_mixer_list\n        self.mismatched_skip_and_input_shape = previous_n_embs != previous_input_embs\n        if token_mixer_list is not None:\n            self.token_mixer_list = token_mixer_list\n        # self.sum_for_skip = sum_for_skip\n        zhen_n_embs = len(token_mixer_list) * output_embs_per_mixer\n        self.n_embs = zhen_n_embs\n        if self.layer_index != 0:\n            if self.mismatched_skip_and_input_shape:\n                self.match_w, self.match_b = construct_w_b_pair(\n                    [previous_n_embs, previous_input_embs], 0.0)\n\n        self.layer_norm_w = nn.Parameter(torch.empty(\n            [emb_dim]).fill_(0.0))  # ConstantFill\n        self.layer_norm_b = nn.Parameter(torch.empty(\n            [emb_dim]).fill_(0.0))  # ConstantFill\n        for token_mixer in self.token_mixer_list:\n            if token_mixer == TokenMixer.DOT:\n                self.ffn_w, self.ffn_b = construct_w_b_pair(\n                    [\n                        512,\n                        original_n_embs**2\n                        if self.layer_index == 0 else previous_n_embs**2,\n                    ],\n                    0.03125,\n                )\n                self.pool_w, self.pool_b = construct_w_b_pair(\n                    [\n                        output_embs_per_mixer * emb_dim,\n                        512,\n                    ],\n                    0.3125,\n                )\n            elif token_mixer == TokenMixer.LINEAR:  # n = 50\n                self.w_linear, self.b_linear = construct_w_b_pair(\n                    [output_embs_per_mixer, previous_n_embs], 0.0)\n\n            elif token_mixer == TokenMixer.ATTENTION:  # n = 50\n                self.encoder_layer = TransformerEncoderLayer(d_model=emb_dim,\n                              
                               nhead=1,\n                                                             batch_first=True)\n\n                self.w_attention, self.b_attention = construct_w_b_pair(\n                    [output_embs_per_mixer, previous_n_embs], 0.0)\n\n            elif token_mixer == TokenMixer.CONVOLUTION:\n                self.conv = nn.Conv2d(1, 1, 5, stride=1, padding=(2, 2))\n                self.w_conv, self.b_conv = construct_w_b_pair(\n                    [\n                        output_embs_per_mixer,\n                        original_n_embs\n                        if self.layer_index == 0 else previous_n_embs,\n                    ],\n                    0.0,\n                )\n\n    def get_dense_params(self) -> List[nn.Parameter]:\n        # do not save because this may turn into FSDP\n        return list(self.parameters())\n\n    def forward(\n            self,\n            skip_connection: Optional[\n                torch.\n                Tensor],  # the skip connection, i.e., previous layer's input\n            input: torch.Tensor,  # this is previous layer's ensemble output\n            # B, D, F\n    ):\n        B = input.shape[0]\n        # process orig embs\n        # token mixer not None\n        if self.layer_index != 0:\n            if self.mismatched_skip_and_input_shape:\n                skip_connection = torch.nn.functional.linear(skip_connection,\n                                                             self.match_w,\n                                                             bias=self.match_b)\n            input_feature = skip_connection + input\n        else:\n            # 0th layer, no skip\n            input_feature = input\n\n        output = []  # do not call cat N times. 
Call it once.\n        for token_mixer in self.token_mixer_list:\n            if token_mixer == TokenMixer.DOT:  # num_dot_emb = 50\n                # B,D,F\n                input_feature_t = input_feature.permute(0, 2, 1)\n                # B,F,D\n                dot_products = torch.bmm(input_feature_t, input_feature)\n                # B,F,F\n                flattened_dot_products = torch.flatten(dot_products,\n                                                       start_dim=-2)  # Flatten\n                # B,F**2\n                r = torch.addmm(self.ffn_b, flattened_dot_products,\n                                self.ffn_w.t())  # FC\n                r_act = torch.relu(r)  # Relu\n                r_pooled = torch.nn.functional.linear(\n                    r_act,\n                    self.pool_w,\n                    bias=self.pool_b,\n                )\n                output.append(r_pooled.view(B, -1, self.emb_dim))\n\n            elif token_mixer == TokenMixer.LINEAR:\n                linear_emb_list = torch.nn.functional.linear(input_feature,\n                                                             self.w_linear,\n                                                             bias=self.b_linear)\n                flat_linear_emb_list = linear_emb_list.permute(0, 2, 1)\n                output.append(flat_linear_emb_list)\n\n            elif token_mixer == TokenMixer.ATTENTION:\n                # input: B,D,F\n                compress_list = torch.nn.functional.linear(\n                    input_feature, self.w_attention, bias=self.b_attention)\n                # B,D,O\n                compress_list_t = compress_list.permute(0, 2, 1)  # (B,O,D)\n                attention_emb_list = self.encoder_layer(compress_list_t)\n                output.append(attention_emb_list)\n\n            elif token_mixer == TokenMixer.CONVOLUTION:\n                reshape_input_feature = input_feature.reshape(\n                    B, 1, self.emb_dim, -1)\n                r_conv 
= self.conv(reshape_input_feature)\n                reshape_r_conv = r_conv.reshape(B, self.emb_dim, -1)\n                compress_list = torch.nn.functional.linear(\n                    reshape_r_conv, self.w_conv, bias=self.b_conv)  # B,output,D\n                flat_compress_list = compress_list.permute(0, 2, 1)\n                output.append(flat_compress_list)\n            else:\n                assert 0, f\"unknown module: {token_mixer}\"\n\n        # each output should be B,F,D\n        output = torch.cat(output, dim=1)\n        output_embs = torch.nn.functional.layer_norm(\n            output,\n            output.size()[2:],\n            weight=self.layer_norm_w,\n            bias=self.layer_norm_b,\n        )\n        return output_embs.permute(0, 2, 1), input_feature\n\n\n# ZHEN collection is different ZHEN layers\nclass ZHENCollection(nn.Module):\n\n    def __init__(\n        self,\n        num_layers: int,\n        emb_dim: int,\n        token_mixer_list: Union[List[TokenMixer], List[List[TokenMixer]]],\n        original_emb_num: int,\n        output_emb_per_ensemble_module: int,\n    ):\n        super().__init__()\n        self.num_layers = num_layers\n        self.emb_dim = emb_dim\n        self.token_mixer_list = token_mixer_list\n        self.layers: nn.ModuleList = nn.ModuleList([])\n\n        assert len(token_mixer_list) > 0\n        if type(token_mixer_list[0]) == list:\n            # this is a heterogeneous ZHEN\n            assert num_layers == len(\n                token_mixer_list\n            ), \"if token_mixer_list is a list of list of modules, ensure num_layers = len(token_mixer_list)\"  # noqa\n        else:\n            # this is a homogeneous ZHEN. 
Convert it to heterogeneous ZHEN\n            # pyre-ignore\n            token_mixer_list = [token_mixer_list] * num_layers\n\n        for i in range(num_layers):\n            layer = ZHENLayer(\n                layer_index=i,\n                emb_dim=emb_dim,\n                # pyre-ignore[6]\n                token_mixer_list=token_mixer_list[i],\n                previous_n_embs=(\n                    original_emb_num if i == 0\n                    # pyre-ignore[6]\n                    else len(token_mixer_list[i - 1]) *\n                    output_emb_per_ensemble_module),\n                previous_input_embs=(\n                    original_emb_num if i <= 1\n                    # pyre-ignore[6]\n                    else len(token_mixer_list[i - 2]) *\n                    output_emb_per_ensemble_module),\n                output_embs_per_mixer=output_emb_per_ensemble_module,\n                original_n_embs=original_emb_num,\n            )\n            self.layers.append(layer)\n\n    def forward(\n        self,\n        input: torch.Tensor,\n        skip_connection: Optional[torch.Tensor] = None,\n    ):\n        skip_connection = None  # previous layer's input\n        for layer in self.layers:\n            input, skip_connection = layer(skip_connection, input)\n\n        output = input.reshape(input.shape[0], -1)\n        return output\n\n    def get_dense_params(self) -> List[nn.Parameter]:\n        return list(self.parameters())\n\n\ndef weight_init_func(pt_module, name_map, params, bufs):\n    # for k, m in pt_module.named_modules():\n    #     if isinstance(m, torch.nn.Linear):\n    #         params[name_map[f\"{k}.weight\"]] = torch.nn.init.xavier_uniform(params[name_map[f\"{k}.weight\"]])\n    #         params[name_map[f\"{k}.bias\"]] = torch.nn.init.normal(params[name_map[f\"{k}.bias\"]], std=1e-6)\n    return params, bufs\n\n\nclass TorchZHENTest(unittest.TestCase):\n\n    def setUp(self):\n        torch.manual_seed(123)\n        alpa.set_seed(123)\n\n  
  def test_zhen_homogeneous(self):\n        B = 64  # made multiples of 8\n        F = 48  # made multiples of 8\n        D = 64\n        LAYERS = 5\n        OUTPUT_PER_ENSEMBLE = 48  # made multiples of 8\n        TOKENS = [\n            TokenMixer.ATTENTION, TokenMixer.LINEAR, TokenMixer.ATTENTION,\n            TokenMixer.CONVOLUTION, TokenMixer.DOT\n        ]\n\n        pt_module_gen = lambda: ZHENCollection(LAYERS, D, TOKENS, F,\n                                               OUTPUT_PER_ENSEMBLE)\n\n        dataloader = [(torch.empty(\n            B, D, F), torch.empty(B, D * LAYERS * OUTPUT_PER_ENSEMBLE))] * 2\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen = torchoptim.adam(lr=1e-3)\n        num_micro_batches = 2\n        parallel_method = alpa.PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            layer_option=alpa.AutoLayerOption(layer_num=2),\n            stage_option=\"auto\")\n\n        _xla_client_mem_fraction_orig_value = alpa.global_config.xla_client_mem_fraction\n        alpa.global_config.xla_client_mem_fraction = 0.7\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n        alpa.global_config.xla_client_mem_fraction = _xla_client_mem_fraction_orig_value\n\n    def test_zhen_heterogeneous(self):\n        B = 64\n        F = 37\n        D = 64\n        OUTPUT_PER_ENSEMBLE = 48  # 50  # made multiples of 8\n        TOKENS = [[TokenMixer.ATTENTION, TokenMixer.LINEAR],\n                  [\n                      TokenMixer.ATTENTION, TokenMixer.CONVOLUTION,\n                      TokenMixer.DOT\n                  ], [TokenMixer.LINEAR, TokenMixer.DOT]]  # 3-layer ZHEN\n\n        pt_module_gen = lambda: ZHENCollection(len(TOKENS), D, TOKENS, F,\n                                               OUTPUT_PER_ENSEMBLE)\n\n        dataloader = [(torch.empty(\n         
   B, D, F), torch.empty(B,\n                                  D * len(TOKENS[-1]) * OUTPUT_PER_ENSEMBLE))\n                     ] * 2\n        loss_func = lambda *args, **kwargs: torch.nn.functional.mse_loss(\n            *args, **kwargs)\n        optim_gen = torchoptim.adam(lr=1e-3)\n        num_micro_batches = 2\n        parallel_method = alpa.PipeshardParallel(\n            num_micro_batches=num_micro_batches,\n            layer_option=alpa.AutoLayerOption(layer_num=2),\n            stage_option=\"auto\")\n\n        _xla_client_mem_fraction_orig_value = alpa.global_config.xla_client_mem_fraction\n        alpa.global_config.xla_client_mem_fraction = 0.7\n\n        train_torch_module(pt_module_gen, weight_init_func, dataloader,\n                           loss_func, optim_gen, parallel_method)\n\n        alpa.global_config.xla_client_mem_fraction = _xla_client_mem_fraction_orig_value\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(TorchZHENTest(\"test_zhen_homogeneous\"))\n    suite.addTest(TorchZHENTest(\"test_zhen_heterogeneous\"))\n    return suite\n\n\nif __name__ == '__main__':\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/tpu/test_create_state_parallel.py",
    "content": "\"\"\"Test CreateStateParallel on TPU.\"\"\"\nimport unittest\n\nfrom alpa import global_config\n\nimport tests.runtime.test_create_state as test_create_state\nfrom tests.tpu.test_shard_parallel import has_tpu\n\n\nclass TpuCreateStateTest(test_create_state.CreateStateTest):\n\n    def setUp(self):\n        global_config.backend = \"tpu\"\n\n    def tearDown(self):\n        return\n\n    @unittest.skip(\"unsupported yet.\")\n    def test_shard_parallel_grad_acc(self):\n        super().test_shard_parallel_grad_acc()\n\n    @unittest.skip(\"unsupported yet.\")\n    def test_pipeshard_parallel(self):\n        super().test_pipeshard_parallel()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    if not has_tpu():\n        return suite\n\n    suite.addTest(TpuCreateStateTest(\"test_shard_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())"
  },
  {
    "path": "tests/tpu/test_follow_parallel.py",
    "content": "\"\"\"Test FollowParallel on TPU.\"\"\"\nimport unittest\n\nfrom alpa import global_config\n\nimport tests.runtime.test_follow_parallel as test_follow_parallel\nfrom tests.tpu.test_shard_parallel import has_tpu\n\n\nclass TpuFollowParallelTest(test_follow_parallel.FollowParallelTest):\n\n    def setUp(self):\n        global_config.backend = \"tpu\"\n\n    def tearDown(self):\n        return\n\n    @unittest.skip(\"unsupported yet.\")\n    def test_shard_parallel_grad_acc(self):\n        super().test_shard_parallel_grad_acc()\n\n    @unittest.skip(\"unsupported yet.\")\n    def test_pipeshard_parallel(self):\n        super().test_pipeshard_parallel()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    if not has_tpu():\n        return suite\n\n    suite.addTest(TpuFollowParallelTest(\"test_shard_parallel\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())"
  },
  {
    "path": "tests/tpu/test_shard_parallel.py",
    "content": "\"\"\"Test auto sharding with MLP and MoE on TPU.\"\"\"\nimport unittest\n\nimport jax\n\nfrom alpa import global_config\n\nimport tests.shard_parallel.test_mlp as test_mlp\nimport tests.shard_parallel.test_moe as test_moe\n\nwith_device = {}\n\n\ndef has_device(name):\n    global with_device\n    if name in with_device:\n        return with_device[name]\n    try:\n        jax.devices(name)\n        with_device[name] = True\n    except RuntimeError:\n        with_device[name] = False\n    return with_device[name]\n\n\ndef has_tpu():\n    return has_device(\"tpu\")\n\n\ndef has_gpu():\n    return has_device(\"gpu\")\n\n\nclass AutoShardingTpuMlpTest(test_mlp.AutoShardingMLPTest):\n\n    def setUp(self):\n        global_config.backend = \"tpu\"\n        super().setUp()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_n_layer_mlp_data_parallel_reduce_scatter(self):\n        super().test_n_layer_mlp_data_parallel_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_n_layer_mlp_model_parallel_reduce_scatter(self):\n        super().test_n_layer_mlp_model_parallel_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_n_layer_mlp_2d_mesh_reduce_scatter(self):\n        super().test_n_layer_mlp_2d_mesh_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_n_layer_mlp_data_parallel_reduce_scatter_adafactor(self):\n        super().test_n_layer_mlp_data_parallel_reduce_scatter_adafactor()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3(self):\n        super().test_n_layer_mlp_data_parallel_reduce_scatter_zero_stage_3()\n\n\nclass AutoShardingTpuMoeTest(test_moe.AutoShardingMoETest):\n\n    def setUp(self):\n        global_config.backend = \"tpu\"\n        super().setUp()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_moe_layer_2d_reduce_scatter(self):\n        super().test_moe_layer_2d_reduce_scatter()\n\n    
@unittest.skip(\"unsupported yet\")\n    def test_moe_lm_reduce_scatter(self):\n        super().test_moe_lm_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_moe_lm_2d_reduce_scatter(self):\n        super().test_moe_lm_2d_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_moe_lm_data_parallel_reduce_scatter(self):\n        super().test_moe_lm_data_parallel_reduce_scatter()\n\n    @unittest.skip(\"unsupported yet\")\n    def test_moe_lm_data_parallel_reduce_scatter_zero_3(self):\n        super().test_moe_lm_data_parallel_reduce_scatter_zero_3()\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    if not has_tpu():\n        return suite\n\n    def add_mlp(name):\n        suite.addTest(AutoShardingTpuMlpTest(name))\n\n    def add_moe(name):\n        suite.addTest(AutoShardingTpuMoeTest(name))\n\n    add_mlp(\"test_n_layer_mlp_data_parallel\")\n    add_mlp(\"test_n_layer_mlp_model_parallel\")\n    add_mlp(\"test_n_layer_mlp_2d_mesh\")\n    add_mlp(\"test_n_layer_mlp_force_data_parallel\")\n    add_mlp(\"test_n_layer_mlp_force_batch_dim_mapping\")\n    add_mlp(\"test_weight_init\")\n\n    add_moe(\"test_moe_layer\")\n    add_moe(\"test_moe_layer_2d\")\n    add_moe(\"test_moe_lm\")\n    add_moe(\"test_moe_lm_2d\")\n    add_moe(\"test_moe_lm_data_parallel\")\n\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())"
  },
  {
    "path": "tests/util/test_hlo_cost_model.py",
    "content": "\"\"\"Test HLO cost model.\"\"\"\nimport pickle\nimport unittest\n\nimport jax\nimport jax.numpy as jnp\nfrom flax import linen as nn\nfrom flax.training.train_state import TrainState\nimport optax\nimport ray\n\nfrom alpa import (init, parallelize, global_config, ShardParallel,\n                  LocalPhysicalDeviceMesh, ProfilingResultDatabase)\nfrom alpa.device_mesh import get_global_cluster\nfrom alpa.mesh_profiling import estimate_hlo_module_cost\nfrom alpa.util import map_to_shape\n\n\nclass HloCostModelTest(unittest.TestCase):\n\n    def run_n_layer_mlp(self,\n                        num_layers,\n                        batch_size,\n                        input_dim,\n                        output_dim,\n                        hidden_dim,\n                        device_mesh,\n                        use_bias=True):\n\n        class Model(nn.Module):\n\n            @nn.compact\n            def __call__(self, x):\n                for i in range(num_layers - 1):\n                    x = nn.Dense(features=hidden_dim, use_bias=use_bias)(x)\n                    x = nn.relu(x)\n                x = nn.Dense(features=output_dim, use_bias=use_bias)(x)\n                return x\n\n        @parallelize(method=ShardParallel(devices=device_mesh))\n        def train_step(state, batch):\n\n            def loss_func(params):\n                out = state.apply_fn(params, batch[\"x\"])\n                return jnp.mean((out - batch[\"y\"])**2)\n\n            grads = jax.grad(loss_func)(state.params)\n            new_state = state.apply_gradients(grads=grads)\n            return new_state\n\n        x = jnp.ones((batch_size, input_dim))\n        y = jnp.ones((batch_size, output_dim))\n\n        # Init train state\n        model = Model()\n        rngkey = jax.random.PRNGKey(0)\n        params = model.init(rngkey, x)\n        tx = optax.adam(learning_rate=1e-2)\n        state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)\n\n        # Get 
optimized HLO IR\n        executable = train_step.get_executable(state, {\"x\": x, \"y\": y})\n        return executable.compiled.hlo_modules()[0]\n\n    def test_cluster_profling(self):\n        init(cluster=\"ray\")\n        cluster = get_global_cluster()\n        manually_specified_submeshes = [\n            (1, 1),\n            cluster.get_virtual_physical_mesh().shape,\n        ]\n\n        prof_database = cluster.profile_all(\n            \"p3.16\",\n            2,\n            2,\n            max_fail_retry=5,\n            cache_filename=\"tmp_cache.pkl\",\n            dot_range=(0, 1),\n            mesh_size_choices=manually_specified_submeshes)\n        prof_database.save(\"tmp_prof_database.pkl\")\n\n    @unittest.skip(\"Temporary disabled due to being flaky\")\n    def test_n_layer_mlp(self):\n        num_layers = 2\n        batch_size = 32\n        hidden_dim = 16\n\n        prof_database = ProfilingResultDatabase()\n        prof_database.load(\"tmp_prof_database.pkl\")\n\n        device_mesh = LocalPhysicalDeviceMesh()\n        hlo_module = self.run_n_layer_mlp(num_layers, batch_size, hidden_dim,\n                                          hidden_dim, hidden_dim, device_mesh)\n        mesh_result = prof_database.query(\"p3.16\", device_mesh.shape)\n        cost = estimate_hlo_module_cost(hlo_module, mesh_result)\n        # assert cost > 0\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(HloCostModelTest(\"test_cluster_profling\"))\n    suite.addTest(HloCostModelTest(\"test_n_layer_mlp\"))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "tests/util/test_ordered_set.py",
    "content": "\"\"\"Test OrderedSet.\"\"\"\n\nimport os\nimport unittest\n\nfrom alpa.util import OrderedSet\n\n\nclass OrderedSetTest(unittest.TestCase):\n    \"\"\"Test OrderedSet.\"\"\"\n\n    def test_init(self):\n        \"\"\"Test OrderedSet.__init__.\"\"\"\n        oset = OrderedSet()\n        self.assertEqual(len(oset), 0)\n\n        oset = OrderedSet([1, 2, 3])\n        self.assertEqual(len(oset), 3)\n\n    def test_add(self):\n        \"\"\"Test OrderedSet.add.\"\"\"\n        oset = OrderedSet()\n        oset.add(1)\n        self.assertEqual(len(oset), 1)\n\n        oset.add(2)\n        self.assertEqual(len(oset), 2)\n\n    def test_update(self):\n        \"\"\"Test OrderedSet.update.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.update([4, 5])\n        self.assertEqual(len(oset), 5)\n        self.assertEqual(oset, OrderedSet([1, 2, 3, 4, 5]))\n\n    def test_union(self):\n        \"\"\"Test OrderedSet.union.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        self.assertEqual(oset.union([4, 5]), OrderedSet([1, 2, 3, 4, 5]))\n\n    def test_intersection_update(self):\n        \"\"\"Test OrderedSet.intersection_update.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.intersection_update([2, 3, 4])\n        self.assertEqual(len(oset), 2)\n        self.assertEqual(oset, OrderedSet([2, 3]))\n\n        oset = OrderedSet([1, 2, 3])\n        oset.intersection_update([2, 3, 4])\n        self.assertEqual(len(oset), 2)\n        self.assertEqual(oset, OrderedSet([2, 3]))\n\n    def test_intersection(self):\n        \"\"\"Test OrderedSet.intersection.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        result = oset.intersection([2, 3, 4])\n        self.assertEqual(len(result), 2)\n        self.assertEqual(result, OrderedSet([2, 3]))\n\n    def test_remove(self):\n        \"\"\"Test OrderedSet.remove.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.remove(2)\n        self.assertEqual(len(oset), 2)\n        self.assertEqual(oset, 
OrderedSet([1, 3]))\n\n    def test_discard(self):\n        \"\"\"Test OrderedSet.discard.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.discard(2)\n        self.assertEqual(len(oset), 2)\n        self.assertEqual(oset, OrderedSet([1, 3]))\n\n        oset.discard(4)\n        self.assertEqual(len(oset), 2)\n        self.assertEqual(oset, OrderedSet([1, 3]))\n\n    def test_clear(self):\n        \"\"\"Test OrderedSet.clear.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.clear()\n        self.assertEqual(len(oset), 0)\n\n    def test_difference(self):\n        \"\"\"Test OrderedSet.difference.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        result = oset.difference([2, 3, 4])\n        self.assertEqual(len(result), 1)\n        self.assertEqual(result, OrderedSet([1]))\n\n    def test_difference_update(self):\n        \"\"\"Test OrderedSet.difference_update.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        oset.difference_update([2, 3, 4])\n        self.assertEqual(len(oset), 1)\n        self.assertEqual(oset, OrderedSet([1]))\n\n    def test_symmetric_difference(self):\n        \"\"\"Test OrderedSet.symmetric_difference.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        result = oset.symmetric_difference([2, 3, 4])\n        self.assertEqual(len(result), 2)\n        self.assertEqual(result, OrderedSet([1, 4]))\n\n    def test_repr(self):\n        \"\"\"Test OrderedSet.__repr__.\"\"\"\n        oset = OrderedSet([1, 2, 3])\n        self.assertEqual(repr(oset), 'OrderedSet([1, 2, 3])')\n\n\ndef suite():\n    suite = unittest.TestSuite()\n    suite.addTest(unittest.makeSuite(OrderedSetTest))\n    return suite\n\n\nif __name__ == \"__main__\":\n    runner = unittest.TextTestRunner()\n    runner.run(suite())\n"
  },
  {
    "path": "update_version.py",
    "content": "# Licensed to the Apache Software Foundation (ASF) under one\n# or more contributor license agreements.  See the NOTICE file\n# distributed with this work for additional information\n# regarding copyright ownership.  The ASF licenses this file\n# to you under the Apache License, Version 2.0 (the\n# \"License\"); you may not use this file except in compliance\n# with the License.  You may obtain a copy of the License at\n#\n#   http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing,\n# software distributed under the License is distributed on an\n# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n# KIND, either express or implied.  See the License for the\n# specific language governing permissions and limitations\n# under the License.\n\n\"\"\"\nThis is the global script that sets the version information of Alpa.\nThis script runs and updates all the locations that are related to versions.\n\nList of affected files:\n- alpa/version.py\n\"\"\"\nimport os\nimport re\nimport argparse\nimport logging\nimport subprocess\n\n# Modify the following value during release\n# ---------------------------------------------------\n# Current version:\n# We use the version of the incoming release for code\n# that is under development.\n#\n# It is also the fallback version to be used when --git-describe\n# is not invoked, or when the repository does not present the\n# git tags in a format that this script can use.\n#\n# Two tag formats are supported:\n# - vMAJ.MIN.PATCH (e.g. v0.8.0) or\n# - vMAJ.MIN.devN (e.g. 
v0.8.dev0)\n__version__ = \"v0.2.dev0\"\n\n# ---------------------------------------------------\n\nPROJ_ROOT = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))\n\n\ndef py_str(cstr):\n    return cstr.decode(\"utf-8\")\n\n\ndef git_describe_version():\n    \"\"\"Get PEP-440 compatible public and local version using git describe.\n\n    Returns\n    -------\n    pub_ver: str\n        Public version.\n\n    local_ver: str\n        Local version (with additional label appended to pub_ver).\n\n    Notes\n    -----\n    - We follow PEP 440's convention of public version\n      and local versions.\n    - Only tags conforming to vMAJOR.MINOR.REV (e.g. \"v0.7.0\")\n      are considered in order to generate the version string.\n      See the use of `--match` in the `git` command below.\n\n    Here are some examples:\n\n    - pub_ver = '0.7.0', local_ver = '0.7.0':\n      We are at the 0.7.0 release.\n    - pub_ver =  '0.8.dev94', local_ver = '0.8.dev94+g0d07a329e':\n      We are at the the 0.8 development cycle.\n      The current source contains 94 additional commits\n      after the most recent tag(v0.7.0),\n      the git short hash tag of the current commit is 0d07a329e.\n    \"\"\"\n    cmd = [\n        \"git\",\n        \"describe\",\n        \"--tags\",\n        \"--match\",\n        \"v[0-9]*.[0-9]*.[0-9]*\",\n        \"--match\",\n        \"v[0-9]*.[0-9]*.dev[0-9]*\",\n    ]\n    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=PROJ_ROOT)\n    (out, _) = proc.communicate()\n\n    if proc.returncode != 0:\n        msg = py_str(out)\n        if msg.find(\"not a git repository\") != -1:\n            return __version__, __version__\n        logging.warning(\"git describe: %s, use %s\", msg, __version__)\n        return __version__, __version__\n    describe = py_str(out).strip()\n    arr_info = describe.split(\"-\")\n\n    # Remove the v prefix, mainly to be robust\n    # to the case where v is not presented as well.\n    
if arr_info[0].startswith(\"v\"):\n        arr_info[0] = arr_info[0][1:]\n\n    # hit the exact tag\n    if len(arr_info) == 1:\n        return arr_info[0], arr_info[0]\n\n    if len(arr_info) != 3:\n        logging.warning(\"Invalid output from git describe %s\", describe)\n        return __version__, __version__\n\n    dev_pos = arr_info[0].find(\".dev\")\n\n    # Development versions:\n    # The code will reach this point in case it can't match a full release version, such as v0.7.0.\n    #\n    # 1. in case the last known label looks like vMAJ.MIN.devN e.g. v0.8.dev0, we use\n    # the current behaviour of just using vMAJ.MIN.devNNNN+gGIT_REV\n    if dev_pos != -1:\n        dev_version = arr_info[0][: arr_info[0].find(\".dev\")]\n    # 2. in case the last known label looks like vMAJ.MIN.PATCH e.g. v0.8.0\n    # then we just carry on with a similar version to what git describe provides, which is\n    # vMAJ.MIN.PATCH.devNNNN+gGIT_REV\n    else:\n        dev_version = arr_info[0]\n\n    pub_ver = \"%s.dev%s\" % (dev_version, arr_info[1])\n    local_ver = \"%s+%s\" % (pub_ver, arr_info[2])\n    return pub_ver, local_ver\n\n\n# Implementations\ndef update(file_name, pattern, repl, dry_run=False):\n    update = []\n    hit_counter = 0\n    need_update = False\n    with open(file_name) as file:\n        for l in file:\n            result = re.findall(pattern, l)\n            if result:\n                assert len(result) == 1\n                hit_counter += 1\n                if result[0] != repl:\n                    l = re.sub(pattern, repl, l)\n                    need_update = True\n                    print(\"%s: %s -> %s\" % (file_name, result[0], repl))\n                else:\n                    print(\"%s: version is already %s\" % (file_name, repl))\n\n            update.append(l)\n    if hit_counter != 1:\n        raise RuntimeError(\"Cannot find version in %s\" % file_name)\n\n    if need_update and not dry_run:\n        with open(file_name, \"w\") as 
output_file:\n            for l in update:\n                output_file.write(l)\n\n\ndef sync_version(pub_ver, local_ver, dry_run):\n    \"\"\"Synchronize version.\"\"\"\n    # python uses the PEP-440: local version\n    update(\n        os.path.join(PROJ_ROOT, \"alpa\", \"version.py\"),\n        r\"(?<=__version__ = \\\")[.0-9a-z\\+]+\",\n        local_ver,\n        dry_run,\n    )\n\n\ndef main():\n    logging.basicConfig(level=logging.INFO)\n    parser = argparse.ArgumentParser(description=\"Detect and synchronize version.\")\n    parser.add_argument(\n        \"--print-version\",\n        action=\"store_true\",\n        help=\"Print version to the command line. No changes is applied to files.\",\n    )\n    parser.add_argument(\n        \"--git-describe\",\n        action=\"store_true\",\n        help=\"Use git describe to generate development version.\",\n    )\n    parser.add_argument(\"--dry-run\", action=\"store_true\")\n\n    opt = parser.parse_args()\n    pub_ver, local_ver = __version__, __version__\n    if opt.git_describe:\n        pub_ver, local_ver = git_describe_version()\n    if opt.print_version:\n        print(local_ver)\n    else:\n        sync_version(pub_ver, local_ver, opt.dry_run)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]